[Mlir-commits] [mlir] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for vector.gather/scatter to SPIR-V (PR #193422)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Wed Apr 22 22:48:11 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/193422
>From 0d2c61ad32012fb827847d0eb324c999a7340106 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 22 Apr 2026 09:06:38 +0200
Subject: [PATCH 1/2] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for
vector.gather/scatter to SPIR-V
Add VectorGatherOpConverter and VectorScatterOpConverter, which lower vector.gather and vector.scatter to spirv.INTEL.MaskedGather and spirv.INTEL.MaskedScatter respectively (SPV_INTEL_masked_gather_scatter extension).
Relax the CompositeConstruct ODS result type constraint from SPIRV_Composite to SPIRV_CompositeOrPtrVector, which adds SPIRV_PtrVector alongside the existing SPIRV_Composite types.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 8 +
.../Dialect/SPIRV/IR/SPIRVCompositeOps.td | 2 +-
.../VectorToSPIRV/VectorToSPIRV.cpp | 164 +++++++++++++++++-
.../VectorToSPIRV/vector-to-spirv.mlir | 105 +++++++++++
4 files changed, 276 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 27abd13b8ddb1..b2c2484552be4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4342,6 +4342,14 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
class SPIRV_VectorOf<Type type> :
FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>;
+// Vector of SPIR-V pointers with 2, 3, 4, 8 or 16 elements, as used by the
+// SPV_INTEL_masked_gather_scatter extension ops.
+def SPIRV_PtrVector : SPIRV_VectorOf<SPIRV_AnyPtr>;
+// Vectors, arrays, runtime arrays, structs, cooperative matrices, matrices and
+// TensorArm types, plus pointer vectors (SPIRV_PtrVector) so that
+// CompositeConstruct can build the pointer operand of the
+// SPV_INTEL_masked_gather_scatter ops.
+def SPIRV_CompositeOrPtrVector :
+ AnyTypeOf<[SPIRV_Vector, SPIRV_PtrVector,
+ SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
+
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index 981131484498d..111668ccbf774 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -63,7 +63,7 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
);
let results = (outs
- SPIRV_Composite:$result
+ SPIRV_CompositeOrPtrVector:$result
);
let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 921075736e97b..202173bf4c947 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -853,6 +853,165 @@ struct VectorStoreOpConverter final
}
};
+/// Lowers a 1-D `vector.gather` whose base is a memref in a SPIR-V storage
+/// class to `spirv.INTEL.MaskedGather`. Lane i's address is formed with
+/// `spirv.PtrAccessChain` from the element pointer at `base[offsets]`, stepped
+/// by `indices[i]`. The lane pointers are packed into a pointer vector with
+/// `spirv.CompositeConstruct` and passed, together with the alignment, mask
+/// and pass-through value, to the masked gather.
+struct VectorGatherOpConverter final
+    : public OpConversionPattern<vector::GatherOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(vector::GatherOp gatherOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only support 1-D result vectors.
+    auto vectorType = gatherOp.getVectorType();
+    if (vectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "only 1-D vectors supported");
+
+    // Only support memref base (not tensor).
+    auto memrefType = dyn_cast<MemRefType>(gatherOp.getBaseType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "only memref base supported");
+
+    // PtrAccessChain advances the element pointer by the raw index in units
+    // of one element, so lane addresses are only correct when the innermost
+    // memref stride is 1. Reject other layouts instead of miscompiling them.
+    if (!memrefType.isLastDimUnitStride())
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "innermost memref stride must be 1");
+
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "expected spirv.storage_class");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = gatherOp.getLoc();
+
+    // Compute base element pointer from memref + offsets.
+    Value basePtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getOffsets(), loc, rewriter);
+    if (!basePtr)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "failed to get element pointer");
+
+    // Every lane pointer is a PtrAccessChain off `basePtr` and has exactly its
+    // type, so use that type for the pointer-vector elements.
+    auto ptrType = dyn_cast<spirv::PointerType>(basePtr.getType());
+    if (!ptrType)
+      return rewriter.notifyMatchFailure(gatherOp, "expected pointer type");
+
+    // The converted result must be a vector of the pointee type. This rejects
+    // unconvertible results and results the converter does not map to a
+    // vector, for which no matching pointer vector can be built.
+    auto resultType =
+        dyn_cast_or_null<VectorType>(typeConverter.convertType(vectorType));
+    if (!resultType || resultType.getElementType() != ptrType.getPointeeType())
+      return rewriter.notifyMatchFailure(gatherOp, "unsupported result type");
+    int64_t numElements = vectorType.getDimSize(0);
+    auto ptrVectorType = VectorType::get({numElements}, ptrType);
+
+    // Build the pointer vector: lane i is basePtr advanced by indices[i].
+    Type indexType = typeConverter.getIndexType();
+    Type i32Type = rewriter.getI32Type();
+    SmallVector<Value> pointers;
+    pointers.reserve(numElements);
+    for (int64_t i = 0; i < numElements; ++i) {
+      Value lane = spirv::ConstantOp::create(rewriter, loc, i32Type,
+                                             rewriter.getI32IntegerAttr(i));
+      Value scalarIndex = spirv::VectorExtractDynamicOp::create(
+          rewriter, loc, adaptor.getIndices(), lane);
+      // Cast index to the SPIR-V index type if needed.
+      if (scalarIndex.getType() != indexType)
+        scalarIndex =
+            spirv::SConvertOp::create(rewriter, loc, indexType, scalarIndex);
+      Value ptr = spirv::PtrAccessChainOp::create(rewriter, loc, basePtr,
+                                                  scalarIndex, /*indices=*/{});
+      pointers.push_back(ptr);
+    }
+    Value ptrVector = spirv::CompositeConstructOp::create(
+        rewriter, loc, ptrVectorType, pointers);
+
+    // Alignment operand; 0 when the op carries no alignment attribute.
+    uint32_t align = gatherOp.getAlignment().value_or(0);
+    Value alignmentVal = spirv::ConstantOp::create(
+        rewriter, loc, i32Type, rewriter.getI32IntegerAttr(align));
+
+    // Create spirv.INTEL.MaskedGather.
+    rewriter.replaceOpWithNewOp<spirv::INTELMaskedGatherOp>(
+        gatherOp, resultType, ptrVector, alignmentVal, adaptor.getMask(),
+        adaptor.getPassThru());
+    return success();
+  }
+};
+
+/// Lowers a 1-D `vector.scatter` whose base is a memref in a SPIR-V storage
+/// class to `spirv.INTEL.MaskedScatter`. Lane i's address is formed with
+/// `spirv.PtrAccessChain` from the element pointer at `base[offsets]`, stepped
+/// by `indices[i]`. The lane pointers are packed into a pointer vector with
+/// `spirv.CompositeConstruct` and passed, together with the alignment, mask
+/// and value to store, to the masked scatter. The original op is then erased.
+struct VectorScatterOpConverter final
+    : public OpConversionPattern<vector::ScatterOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(vector::ScatterOp scatterOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only support 1-D vectors.
+    auto vectorType = scatterOp.getVectorType();
+    if (vectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "only 1-D vectors supported");
+
+    // Only support memref base (not tensor).
+    auto memrefType = dyn_cast<MemRefType>(scatterOp.getBaseType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "only memref base supported");
+
+    // PtrAccessChain advances the element pointer by the raw index in units
+    // of one element, so lane addresses are only correct when the innermost
+    // memref stride is 1. Reject other layouts instead of miscompiling them.
+    if (!memrefType.isLastDimUnitStride())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "innermost memref stride must be 1");
+
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "expected spirv.storage_class");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = scatterOp.getLoc();
+
+    // Compute base element pointer from memref + offsets.
+    Value basePtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getOffsets(), loc, rewriter);
+    if (!basePtr)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "failed to get element pointer");
+
+    // Every lane pointer is a PtrAccessChain off `basePtr` and has exactly its
+    // type, so use that type for the pointer-vector elements.
+    auto ptrType = dyn_cast<spirv::PointerType>(basePtr.getType());
+    if (!ptrType)
+      return rewriter.notifyMatchFailure(scatterOp, "expected pointer type");
+
+    // The (converted) value to store must be a vector of the pointee type.
+    // This also rejects values not converted to a vector, for which no
+    // matching pointer vector can be built.
+    Value valueToStore = adaptor.getValueToStore();
+    auto valueType = dyn_cast<VectorType>(valueToStore.getType());
+    if (!valueType || valueType.getElementType() != ptrType.getPointeeType())
+      return rewriter.notifyMatchFailure(scatterOp, "unsupported value type");
+    int64_t numElements = vectorType.getDimSize(0);
+    auto ptrVectorType = VectorType::get({numElements}, ptrType);
+
+    // Build the pointer vector: lane i is basePtr advanced by indices[i].
+    Type indexType = typeConverter.getIndexType();
+    Type i32Type = rewriter.getI32Type();
+    SmallVector<Value> pointers;
+    pointers.reserve(numElements);
+    for (int64_t i = 0; i < numElements; ++i) {
+      Value lane = spirv::ConstantOp::create(rewriter, loc, i32Type,
+                                             rewriter.getI32IntegerAttr(i));
+      Value scalarIndex = spirv::VectorExtractDynamicOp::create(
+          rewriter, loc, adaptor.getIndices(), lane);
+      // Cast index to the SPIR-V index type if needed.
+      if (scalarIndex.getType() != indexType)
+        scalarIndex =
+            spirv::SConvertOp::create(rewriter, loc, indexType, scalarIndex);
+      Value ptr = spirv::PtrAccessChainOp::create(rewriter, loc, basePtr,
+                                                  scalarIndex, /*indices=*/{});
+      pointers.push_back(ptr);
+    }
+    Value ptrVector = spirv::CompositeConstructOp::create(
+        rewriter, loc, ptrVectorType, pointers);
+
+    // Alignment operand; 0 when the op carries no alignment attribute.
+    uint32_t align = scatterOp.getAlignment().value_or(0);
+    Value alignmentVal = spirv::ConstantOp::create(
+        rewriter, loc, i32Type, rewriter.getI32IntegerAttr(align));
+
+    // Create spirv.INTEL.MaskedScatter; the scatter has no results to replace.
+    spirv::INTELMaskedScatterOp::create(rewriter, loc, ptrVector, alignmentVal,
+                                        adaptor.getMask(), valueToStore);
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
struct VectorReductionToIntDotProd final
: OpRewritePattern<vector::ReductionOp> {
using Base::Base;
@@ -1112,8 +1271,9 @@ void mlir::populateVectorToSPIRVPatterns(
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorScalarBroadcastPattern, VectorLoadOpConverter,
- VectorStoreOpConverter, VectorStepOpConvert>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorStoreOpConverter, VectorGatherOpConverter, VectorScatterOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 48a1298bc4877..9f275eb39a56a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1224,3 +1224,108 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
%0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return %0: vector<4xf32>
}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Kernel, Addresses, MaskedGatherScatterINTEL],
+ [SPV_INTEL_masked_gather_scatter]>,
+ #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @vector_gather
+// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+// CHECK-SAME: %[[PT:.+]]: vector<4xf32>
+// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[RES:.+]] = spirv.INTEL.MaskedGather %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[PT]]
+// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+// CHECK: return %[[RES]]
+func.func @vector_gather(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_scatter
+// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+// CHECK-SAME: %[[VALUES:.+]]: vector<4xf32>
+// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+// CHECK: spirv.INTEL.MaskedScatter %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[VALUES]]
+// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+// CHECK: return
+func.func @vector_scatter(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %values: vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.scatter %base[%c0][%indices], %mask, %values
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @vector_gather_i32
+// CHECK-SAME: %[[BASE:.+]]: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: spirv.INTEL.MaskedGather
+// CHECK-SAME: vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+func.func @vector_gather_i32(%base: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xi32>) -> vector<4xi32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+ return %result : vector<4xi32>
+}
+
+// CHECK-LABEL: @vector_gather_with_alignment
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 4 : i32
+// CHECK: spirv.INTEL.MaskedGather %{{.+}}, %[[ALIGN]]
+func.func @vector_gather_with_alignment(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 4 : i64 }
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %result : vector<4xf32>
+}
+
+} // end module
>From d7e12949730adfafbc884430625a557ce6df3eb2 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Thu, 23 Apr 2026 07:45:27 +0200
Subject: [PATCH 2/2] Lower gather/scatter conditionally
---
.../Conversion/VectorToSPIRV/VectorToSPIRV.h | 5 +
.../VectorToSPIRV/VectorToSPIRV.cpp | 21 +++-
.../vector-gather-scatter-to-spirv-intel.mlir | 104 +++++++++++++++++
.../VectorToSPIRV/vector-to-spirv.mlir | 105 ------------------
.../Conversion/VectorToSPIRV/CMakeLists.txt | 4 +
.../TestVectorGatherScatterToSPIRV.cpp | 71 ++++++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
7 files changed, 204 insertions(+), 108 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToSPIRV/vector-gather-scatter-to-spirv-intel.mlir
create mode 100644 mlir/test/lib/Conversion/VectorToSPIRV/TestVectorGatherScatterToSPIRV.cpp
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index b64be4f733ec0..a39ef19c99271 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -32,6 +32,11 @@ void populateVectorToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
void populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns);
+/// Appends patterns that lower 1-D vector.gather/vector.scatter on memrefs
+/// whose memory space is a SPIR-V storage class to the
+/// SPV_INTEL_masked_gather_scatter extension ops (spirv.INTEL.MaskedGather /
+/// spirv.INTEL.MaskedScatter). These patterns are not added by
+/// populateVectorToSPIRVPatterns, and they only match when the target
+/// environment of `typeConverter` allows SPV_INTEL_masked_gather_scatter.
+void populateVectorGatherScatterToSPIRVPatterns(
+ const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOSPIRV_VECTORTOSPIRV_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 202173bf4c947..48dbb4363682a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -879,6 +879,11 @@ struct VectorGatherOpConverter final
"expected spirv.storage_class");
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ if (!typeConverter.getTargetEnv().allows(
+ spirv::Extension::SPV_INTEL_masked_gather_scatter))
+ return rewriter.notifyMatchFailure(gatherOp,
+ "target environment does not enable "
+ "SPV_INTEL_masked_gather_scatter");
auto loc = gatherOp.getLoc();
// Compute base element pointer from memref + offsets.
@@ -959,6 +964,11 @@ struct VectorScatterOpConverter final
"expected spirv.storage_class");
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ if (!typeConverter.getTargetEnv().allows(
+ spirv::Extension::SPV_INTEL_masked_gather_scatter))
+ return rewriter.notifyMatchFailure(scatterOp,
+ "target environment does not enable "
+ "SPV_INTEL_masked_gather_scatter");
auto loc = scatterOp.getLoc();
// Compute base element pointer from memref + offsets.
@@ -1271,9 +1281,8 @@ void mlir::populateVectorToSPIRVPatterns(
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorScalarBroadcastPattern, VectorLoadOpConverter,
- VectorStoreOpConverter, VectorGatherOpConverter, VectorScatterOpConverter,
- VectorStepOpConvert>(typeConverter, patterns.getContext(),
- PatternBenefit(1));
+ VectorStoreOpConverter, VectorStepOpConvert>(
+ typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
@@ -1285,3 +1294,9 @@ void mlir::populateVectorReductionToSPIRVDotProductPatterns(
RewritePatternSet &patterns) {
patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
}
+
+void mlir::populateVectorGatherScatterToSPIRVPatterns(
+    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
+  // Both converters share the SPIR-V type converter and use benefit 1; each
+  // one checks for the SPV_INTEL_masked_gather_scatter extension itself.
+  auto *ctx = patterns.getContext();
+  patterns.add<VectorGatherOpConverter, VectorScatterOpConverter>(
+      typeConverter, ctx, PatternBenefit(1));
+}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-gather-scatter-to-spirv-intel.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-gather-scatter-to-spirv-intel.mlir
new file mode 100644
index 0000000000000..a0ffdeec8e3da
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-gather-scatter-to-spirv-intel.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt -split-input-file -test-vector-gather-scatter-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// Lowering of 1-D vector.gather/vector.scatter on CrossWorkgroup memrefs to
+// the SPV_INTEL_masked_gather_scatter ops. The target environment below
+// enables that extension and the MaskedGatherScatterINTEL capability.
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Kernel, Addresses, MaskedGatherScatterINTEL],
+ [SPV_INTEL_masked_gather_scatter]>,
+ #spirv.resource_limits<>>
+} {
+
+// f32 gather without an alignment attribute: one VectorExtractDynamic plus
+// PtrAccessChain per lane off the element pointer, packed by
+// CompositeConstruct, then MaskedGather with alignment 0.
+// CHECK-LABEL: @vector_gather
+// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+// CHECK-SAME: %[[PT:.+]]: vector<4xf32>
+// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[RES:.+]] = spirv.INTEL.MaskedGather %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[PT]]
+// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+// CHECK: return %[[RES]]
+func.func @vector_gather(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %result : vector<4xf32>
+}
+
+// f32 scatter: the same pointer-vector construction as the gather, followed
+// by MaskedScatter of the stored values; the op has no results.
+// CHECK-LABEL: @vector_scatter
+// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+// CHECK-SAME: %[[VALUES:.+]]: vector<4xf32>
+// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+// CHECK: spirv.INTEL.MaskedScatter %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[VALUES]]
+// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+// CHECK: return
+func.func @vector_scatter(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %values: vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.scatter %base[%c0][%indices], %mask, %values
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// i32 element type: the pointer-vector elements and the gathered result both
+// use i32.
+// CHECK-LABEL: @vector_gather_i32
+// CHECK-SAME: %[[BASE:.+]]: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: spirv.INTEL.MaskedGather
+// CHECK-SAME: vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+func.func @vector_gather_i32(%base: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xi32>) -> vector<4xi32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+ return %result : vector<4xi32>
+}
+
+// An explicit alignment attribute becomes the MaskedGather alignment operand.
+// CHECK-LABEL: @vector_gather_with_alignment
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 4 : i32
+// CHECK: spirv.INTEL.MaskedGather %{{.+}}, %[[ALIGN]]
+func.func @vector_gather_with_alignment(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 4 : i64 }
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %result : vector<4xf32>
+}
+
+} // end module
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 9f275eb39a56a..48a1298bc4877 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1224,108 +1224,3 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
%0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return %0: vector<4xf32>
}
-
-// -----
-
-module attributes {
- spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.0, [Kernel, Addresses, MaskedGatherScatterINTEL],
- [SPV_INTEL_masked_gather_scatter]>,
- #spirv.resource_limits<>>
-} {
-
-// CHECK-LABEL: @vector_gather
-// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
-// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
-// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
-// CHECK-SAME: %[[PT:.+]]: vector<4xf32>
-// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
-// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
-// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
-// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
-// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
-// CHECK: %[[RES:.+]] = spirv.INTEL.MaskedGather %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[PT]]
-// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
-// CHECK: return %[[RES]]
-func.func @vector_gather(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
- %indices: vector<4xindex>,
- %mask: vector<4xi1>,
- %pass_thru: vector<4xf32>) -> vector<4xf32> {
- %c0 = arith.constant 0 : index
- %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
- : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
- vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
- return %result : vector<4xf32>
-}
-
-// CHECK-LABEL: @vector_scatter
-// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
-// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
-// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
-// CHECK-SAME: %[[VALUES:.+]]: vector<4xf32>
-// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
-// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
-// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
-// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
-// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
-// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
-// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
-// CHECK: spirv.INTEL.MaskedScatter %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[VALUES]]
-// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
-// CHECK: return
-func.func @vector_scatter(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
- %indices: vector<4xindex>,
- %mask: vector<4xi1>,
- %values: vector<4xf32>) {
- %c0 = arith.constant 0 : index
- vector.scatter %base[%c0][%indices], %mask, %values
- : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
- vector<4xindex>, vector<4xi1>, vector<4xf32>
- return
-}
-
-// CHECK-LABEL: @vector_gather_i32
-// CHECK-SAME: %[[BASE:.+]]: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>
-// CHECK: spirv.INTEL.MaskedGather
-// CHECK-SAME: vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
-func.func @vector_gather_i32(%base: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
- %indices: vector<4xindex>,
- %mask: vector<4xi1>,
- %pass_thru: vector<4xi32>) -> vector<4xi32> {
- %c0 = arith.constant 0 : index
- %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
- : memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
- vector<4xindex>, vector<4xi1>, vector<4xi32> into vector<4xi32>
- return %result : vector<4xi32>
-}
-
-// CHECK-LABEL: @vector_gather_with_alignment
-// CHECK: %[[ALIGN:.+]] = spirv.Constant 4 : i32
-// CHECK: spirv.INTEL.MaskedGather %{{.+}}, %[[ALIGN]]
-func.func @vector_gather_with_alignment(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
- %indices: vector<4xindex>,
- %mask: vector<4xi1>,
- %pass_thru: vector<4xf32>) -> vector<4xf32> {
- %c0 = arith.constant 0 : index
- %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
- { alignment = 4 : i64 }
- : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
- vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
- return %result : vector<4xf32>
-}
-
-} // end module
diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index cb5ecc933ad82..0f339745632a5 100644
--- a/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -1,10 +1,14 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestVectorToSPIRV
+ TestVectorGatherScatterToSPIRV.cpp
TestVectorReductionToSPIRVDotProd.cpp
EXCLUDE_FROM_LIBMLIR
)
mlir_target_link_libraries(MLIRTestVectorToSPIRV PUBLIC
+ MLIRSPIRVConversion
+ MLIRUBToSPIRV
+ MLIRUBDialect
MLIRVectorToSPIRV
MLIRArithDialect
MLIRFuncDialect
diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorGatherScatterToSPIRV.cpp b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorGatherScatterToSPIRV.cpp
new file mode 100644
index 0000000000000..fc06383144ee2
--- /dev/null
+++ b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorGatherScatterToSPIRV.cpp
@@ -0,0 +1,86 @@
+//===- TestVectorGatherScatterToSPIRV.cpp - Test gather/scatter -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass that lowers vector.gather/vector.scatter to
+// spirv.INTEL.MaskedGather/spirv.INTEL.MaskedScatter
+// (SPV_INTEL_masked_gather_scatter).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace {
+
+/// Applies the vector-to-SPIR-V, vector.gather/vector.scatter-to-SPIR-V and
+/// ub-to-SPIR-V conversion patterns to the anchor operation as a partial
+/// dialect conversion.
+struct TestVectorGatherScatterToSPIRV
+    : PassWrapper<TestVectorGatherScatterToSPIRV, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherScatterToSPIRV)
+
+  StringRef getArgument() const final {
+    return "test-vector-gather-scatter-to-spirv";
+  }
+
+  StringRef getDescription() const final {
+    return "Test lowering of vector.gather/vector.scatter to "
+           "spirv.INTEL.MaskedGather/MaskedScatter "
+           "(SPV_INTEL_masked_gather_scatter)";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
+                    ub::UBDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    Operation *op = getOperation();
+
+    // The conversion target and the type converter are both configured from
+    // the SPIR-V target environment looked up for `op` (or the default
+    // environment when none is found).
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+    SPIRVTypeConverter typeConverter(targetAttr);
+
+    // Keep casts materialized between original and converted types legal so
+    // they do not cause the partial conversion to fail.
+    target->addLegalOp<UnrealizedConversionCastOp>();
+
+    RewritePatternSet patterns(context);
+    populateVectorToSPIRVPatterns(typeConverter, patterns);
+    populateVectorGatherScatterToSPIRVPatterns(typeConverter, patterns);
+    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace test {
+/// Registers the `test-vector-gather-scatter-to-spirv` test pass.
+void registerTestVectorGatherScatterToSPIRV() {
+  PassRegistration<TestVectorGatherScatterToSPIRV>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 48b8c179bd1b0..02956206f2dc9 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -157,6 +157,7 @@ void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
void registerTestPassStateExtensionCommunication();
+void registerTestVectorGatherScatterToSPIRV();
void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestVulkanRunnerPipeline();
@@ -305,6 +306,7 @@ static void registerTestPasses() {
mlir::test::registerTestTopologicalSortAnalysisPass();
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestPassStateExtensionCommunication();
+ mlir::test::registerTestVectorGatherScatterToSPIRV();
mlir::test::registerTestVectorLowerings();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestVulkanRunnerPipeline();
More information about the Mlir-commits
mailing list