[Mlir-commits] [mlir] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for vector.gather/scatter to SPIR-V (PR #193422)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Wed Apr 22 00:08:45 PDT 2026
https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/193422
Add VectorGatherOpConverter and VectorScatterOpConverter, which lower vector.gather and vector.scatter to spirv.INTEL.MaskedGather and spirv.INTEL.MaskedScatter, respectively (SPV_INTEL_masked_gather_scatter extension).
Extend the CompositeConstruct ODS result type constraint to SPIRV_CompositeOrPtrVector, which adds SPIRV_PtrVector alongside the existing SPIRV_Composite types.
>From 0d2c61ad32012fb827847d0eb324c999a7340106 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 22 Apr 2026 09:06:38 +0200
Subject: [PATCH] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for
vector.gather/scatter to SPIR-V
Add VectorGatherOpConverter and VectorScatterOpConverter that lower vector.gather and vector.scatter to spirv.INTEL.MaskedGather and spirv.INTEL.MaskedScatter respectively (SPV_INTEL_masked_gather_scatter extension)
Extend CompositeConstruct ODS result type constraint with SPIRV_CompositeOrPtrVector, which adds SPIRV_PtrVector alongside the existing SPIRV_Composite types
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 8 +
.../Dialect/SPIRV/IR/SPIRVCompositeOps.td | 2 +-
.../VectorToSPIRV/VectorToSPIRV.cpp | 164 +++++++++++++++++-
.../VectorToSPIRV/vector-to-spirv.mlir | 105 +++++++++++
4 files changed, 276 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 27abd13b8ddb1..b2c2484552be4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4342,6 +4342,14 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
class SPIRV_VectorOf<Type type> :
FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>;
+// Vector whose elements are SPIR-V pointers (SPV_INTEL_masked_gather_scatter extension).
+def SPIRV_PtrVector : SPIRV_VectorOf<SPIRV_AnyPtr>;
+// Any of the listed vector/array/struct/matrix/tensor types or a pointer vector (SPV_INTEL_masked_gather_scatter).
+def SPIRV_CompositeOrPtrVector :
+ AnyTypeOf<[SPIRV_Vector, SPIRV_PtrVector,
+ SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
+
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index 981131484498d..111668ccbf774 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -63,7 +63,7 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
);
let results = (outs
- SPIRV_Composite:$result
+ SPIRV_CompositeOrPtrVector:$result
);
let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 921075736e97b..202173bf4c947 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -853,6 +853,165 @@ struct VectorStoreOpConverter final
}
};
+struct VectorGatherOpConverter final
+ : public OpConversionPattern<vector::GatherOp> {
+ using Base::Base;
+ /// Rewrites a 1-D vector.gather on a memref as spirv.INTEL.MaskedGather.
+ LogicalResult
+ matchAndRewrite(vector::GatherOp gatherOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only 1-D result vectors. NOTE(review): confirm single-element vectors (which may convert to scalars) are handled.
+ auto vectorType = gatherOp.getVectorType();
+ if (vectorType.getRank() != 1)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "only 1-D vectors supported");
+
+ // Only support memref base (not tensor).
+ auto memrefType = dyn_cast<MemRefType>(gatherOp.getBaseType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "only memref base supported");
+ // The memref's memory space must be a spirv::StorageClassAttr; reject otherwise.
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "expected spirv.storage_class");
+
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = gatherOp.getLoc();
+
+ // Compute base element pointer from memref + offsets.
+ Value basePtr =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getOffsets(), loc, rewriter);
+ if (!basePtr)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "failed to get element pointer");
+
+ // Convert the element type and build the type of a vector holding one element pointer per result lane.
+ auto storageClass = attr.getValue();
+ Type elementType = typeConverter.convertType(memrefType.getElementType());
+ if (!elementType)
+ return rewriter.notifyMatchFailure(gatherOp, "unsupported element type");
+ auto ptrType = spirv::PointerType::get(elementType, storageClass);
+ int64_t numElements = vectorType.getDimSize(0);
+ auto ptrVectorType = VectorType::get({numElements}, ptrType);
+
+ // Build pointer vector: one PtrAccessChain off basePtr per lane, indexed by that lane's extracted index.
+ auto indexType = typeConverter.getIndexType();
+ SmallVector<Value> pointers;
+ for (int64_t i = 0; i < numElements; ++i) {
+ auto i32Type = rewriter.getI32Type();
+ Value idx = spirv::ConstantOp::create(rewriter, loc, i32Type,
+ rewriter.getI32IntegerAttr(i));
+ Value scalarIndex = spirv::VectorExtractDynamicOp::create(
+ rewriter, loc, adaptor.getIndices(), idx);
+ // Convert with spirv.SConvert when the extracted index is not already the SPIR-V index type.
+ if (scalarIndex.getType() != indexType)
+ scalarIndex =
+ spirv::SConvertOp::create(rewriter, loc, indexType, scalarIndex);
+ Value ptr = spirv::PtrAccessChainOp::create(rewriter, loc, basePtr,
+ scalarIndex, /*indices=*/{});
+ pointers.push_back(ptr);
+ }
+ Value ptrVector = spirv::CompositeConstructOp::create(
+ rewriter, loc, ptrVectorType, pointers);
+
+ // Alignment operand: an i32 constant, 0 when the op carries no alignment.
+ auto i32Type = rewriter.getI32Type();
+ uint32_t align = gatherOp.getAlignment().value_or(0);
+ Value alignmentVal = spirv::ConstantOp::create(
+ rewriter, loc, i32Type, rewriter.getI32IntegerAttr(align));
+
+ // Create spirv.INTEL.MaskedGather. NOTE(review): resultType is used without a null check.
+ auto resultType = typeConverter.convertType(vectorType);
+ rewriter.replaceOpWithNewOp<spirv::INTELMaskedGatherOp>(
+ gatherOp, resultType, ptrVector, alignmentVal, adaptor.getMask(),
+ adaptor.getPassThru());
+ return success();
+ }
+};
+
+struct VectorScatterOpConverter final
+ : public OpConversionPattern<vector::ScatterOp> {
+ using Base::Base;
+ /// Rewrites a 1-D vector.scatter on a memref as spirv.INTEL.MaskedScatter.
+ LogicalResult
+ matchAndRewrite(vector::ScatterOp scatterOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only 1-D vectors. NOTE(review): confirm single-element vectors (which may convert to scalars) are handled.
+ auto vectorType = scatterOp.getVectorType();
+ if (vectorType.getRank() != 1)
+ return rewriter.notifyMatchFailure(scatterOp,
+ "only 1-D vectors supported");
+
+ // Only support memref base (not tensor).
+ auto memrefType = dyn_cast<MemRefType>(scatterOp.getBaseType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(scatterOp,
+ "only memref base supported");
+ // The memref's memory space must be a spirv::StorageClassAttr; reject otherwise.
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return rewriter.notifyMatchFailure(scatterOp,
+ "expected spirv.storage_class");
+
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = scatterOp.getLoc();
+
+ // Compute base element pointer from memref + offsets.
+ Value basePtr =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getOffsets(), loc, rewriter);
+ if (!basePtr)
+ return rewriter.notifyMatchFailure(scatterOp,
+ "failed to get element pointer");
+
+ // Convert the element type and build the type of a vector holding one element pointer per stored lane.
+ auto storageClass = attr.getValue();
+ Type elementType = typeConverter.convertType(memrefType.getElementType());
+ if (!elementType)
+ return rewriter.notifyMatchFailure(scatterOp, "unsupported element type");
+ auto ptrType = spirv::PointerType::get(elementType, storageClass);
+ int64_t numElements = vectorType.getDimSize(0);
+ auto ptrVectorType = VectorType::get({numElements}, ptrType);
+
+ // Build pointer vector: one PtrAccessChain off basePtr per lane, indexed by that lane's extracted index (converted with spirv.SConvert if its type differs from the index type).
+ auto indexType = typeConverter.getIndexType();
+ SmallVector<Value> pointers;
+ for (int64_t i = 0; i < numElements; ++i) {
+ auto i32Type = rewriter.getI32Type();
+ Value idx = spirv::ConstantOp::create(rewriter, loc, i32Type,
+ rewriter.getI32IntegerAttr(i));
+ Value scalarIndex = spirv::VectorExtractDynamicOp::create(
+ rewriter, loc, adaptor.getIndices(), idx);
+ if (scalarIndex.getType() != indexType)
+ scalarIndex =
+ spirv::SConvertOp::create(rewriter, loc, indexType, scalarIndex);
+ Value ptr = spirv::PtrAccessChainOp::create(rewriter, loc, basePtr,
+ scalarIndex, /*indices=*/{});
+ pointers.push_back(ptr);
+ }
+ Value ptrVector = spirv::CompositeConstructOp::create(
+ rewriter, loc, ptrVectorType, pointers);
+
+ // Alignment operand: an i32 constant, 0 when the op carries no alignment.
+ auto i32Type = rewriter.getI32Type();
+ uint32_t align = scatterOp.getAlignment().value_or(0);
+ Value alignmentVal = spirv::ConstantOp::create(
+ rewriter, loc, i32Type, rewriter.getI32IntegerAttr(align));
+
+ // Create spirv.INTEL.MaskedScatter, then erase the original scatter op.
+ spirv::INTELMaskedScatterOp::create(rewriter, loc, ptrVector, alignmentVal,
+ adaptor.getMask(),
+ adaptor.getValueToStore());
+ rewriter.eraseOp(scatterOp);
+ return success();
+ }
+};
+
struct VectorReductionToIntDotProd final
: OpRewritePattern<vector::ReductionOp> {
using Base::Base;
@@ -1112,8 +1271,9 @@ void mlir::populateVectorToSPIRVPatterns(
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorScalarBroadcastPattern, VectorLoadOpConverter,
- VectorStoreOpConverter, VectorStepOpConvert>(
- typeConverter, patterns.getContext(), PatternBenefit(1));
+ VectorStoreOpConverter, VectorGatherOpConverter, VectorScatterOpConverter,
+ VectorStepOpConvert>(typeConverter, patterns.getContext(),
+ PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 48a1298bc4877..9f275eb39a56a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1224,3 +1224,108 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
%0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return %0: vector<4xf32>
}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Kernel, Addresses, MaskedGatherScatterINTEL],
+ [SPV_INTEL_masked_gather_scatter]>,
+ #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @vector_gather
+// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+// CHECK-SAME: %[[PT:.+]]: vector<4xf32>
+// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[RES:.+]] = spirv.INTEL.MaskedGather %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[PT]]
+// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+// CHECK: return %[[RES]]
+func.func @vector_gather(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_scatter
+// CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+// CHECK-SAME: %[[VALUES:.+]]: vector<4xf32>
+// CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+// CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+// CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+// CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+// CHECK: spirv.INTEL.MaskedScatter %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[VALUES]]
+// CHECK-SAME: : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+// CHECK: return
+func.func @vector_scatter(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %values: vector<4xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.scatter %base[%c0][%indices], %mask, %values
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @vector_gather_i32
+// CHECK-SAME: %[[BASE:.+]]: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: spirv.INTEL.MaskedGather
+// CHECK-SAME: vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+func.func @vector_gather_i32(%base: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xi32>) -> vector<4xi32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+ return %result : vector<4xi32>
+}
+
+// CHECK-LABEL: @vector_gather_with_alignment
+// CHECK: %[[ALIGN:.+]] = spirv.Constant 4 : i32
+// CHECK: spirv.INTEL.MaskedGather %{{.+}}, %[[ALIGN]]
+func.func @vector_gather_with_alignment(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ %indices: vector<4xindex>,
+ %mask: vector<4xi1>,
+ %pass_thru: vector<4xf32>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 4 : i64 }
+ : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+ vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %result : vector<4xf32>
+}
+
+} // end module
More information about the Mlir-commits
mailing list