[Mlir-commits] [mlir] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for vector.gather/scatter to SPIR-V (PR #193422)

Arseniy Obolenskiy llvmlistbot at llvm.org
Wed Apr 22 00:08:45 PDT 2026


https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/193422

Add VectorGatherOpConverter and VectorScatterOpConverter, which lower vector.gather and vector.scatter to spirv.INTEL.MaskedGather and spirv.INTEL.MaskedScatter, respectively (SPV_INTEL_masked_gather_scatter extension).

Relax the ODS result type constraint of CompositeConstruct to SPIRV_CompositeOrPtrVector, which accepts SPIRV_PtrVector in addition to the existing SPIRV_Composite types.

>From 0d2c61ad32012fb827847d0eb324c999a7340106 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 22 Apr 2026 09:06:38 +0200
Subject: [PATCH] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for
 vector.gather/scatter to SPIR-V

Add VectorGatherOpConverter and VectorScatterOpConverter, which lower vector.gather and vector.scatter to spirv.INTEL.MaskedGather and spirv.INTEL.MaskedScatter, respectively (SPV_INTEL_masked_gather_scatter extension).

Relax the ODS result type constraint of CompositeConstruct to SPIRV_CompositeOrPtrVector, which accepts SPIRV_PtrVector in addition to the existing SPIRV_Composite types.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |   8 +
 .../Dialect/SPIRV/IR/SPIRVCompositeOps.td     |   2 +-
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 164 +++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 105 +++++++++++
 4 files changed, 276 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 27abd13b8ddb1..b2c2484552be4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4342,6 +4342,14 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
 class SPIRV_VectorOf<Type type> :
     FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>;
 
+// Vector of 2/3/4/8/16 SPIR-V pointers (SPV_INTEL_masked_gather_scatter).
+def SPIRV_PtrVector : SPIRV_VectorOf<SPIRV_AnyPtr>;
+// Mirrors SPIRV_Composite's type list, plus SPIRV_PtrVector; keep in sync.
+def SPIRV_CompositeOrPtrVector :
+    AnyTypeOf<[SPIRV_Vector, SPIRV_PtrVector,
+               SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
+               SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
+
 class SPIRV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index 981131484498d..111668ccbf774 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -63,7 +63,7 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
   );
 
   let results = (outs
-    SPIRV_Composite:$result
+    SPIRV_CompositeOrPtrVector:$result
   );
 
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 921075736e97b..202173bf4c947 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -853,6 +853,165 @@ struct VectorStoreOpConverter final
   }
 };
 
+/// Builds the pointer vector for spirv.INTEL.MaskedGather/MaskedScatter.
+/// Lane `i` is `spirv.PtrAccessChain %basePtr[%indices[i]]`; each lane index
+/// is read with spirv.VectorExtractDynamic and spirv.SConvert-ed to the SPIR-V
+/// index type if needed. Every lane has the type of `basePtr` (`ptrType`).
+static Value buildGatherScatterPtrVector(
+    const SPIRVTypeConverter &typeConverter, spirv::PointerType ptrType,
+    Value basePtr, Value indices, int64_t numElements, Location loc,
+    OpBuilder &builder) {
+  Type indexType = typeConverter.getIndexType();
+  Type i32Type = builder.getI32Type();
+  SmallVector<Value> pointers;
+  pointers.reserve(numElements);
+  for (int64_t i = 0; i < numElements; ++i) {
+    Value lane = spirv::ConstantOp::create(builder, loc, i32Type,
+                                           builder.getI32IntegerAttr(i));
+    Value scalarIndex =
+        spirv::VectorExtractDynamicOp::create(builder, loc, indices, lane);
+    if (scalarIndex.getType() != indexType)
+      scalarIndex =
+          spirv::SConvertOp::create(builder, loc, indexType, scalarIndex);
+    pointers.push_back(spirv::PtrAccessChainOp::create(
+        builder, loc, basePtr, scalarIndex, /*indices=*/{}));
+  }
+  auto ptrVectorType = VectorType::get({numElements}, ptrType);
+  return spirv::CompositeConstructOp::create(builder, loc, ptrVectorType,
+                                             pointers);
+}
+
+/// Lowers a 1-D vector.gather from a memref in a SPIR-V storage class to
+/// spirv.INTEL.MaskedGather (SPV_INTEL_masked_gather_scatter).
+struct VectorGatherOpConverter final
+    : public OpConversionPattern<vector::GatherOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(vector::GatherOp gatherOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only support 1-D result vectors.
+    auto vectorType = gatherOp.getVectorType();
+    if (vectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "only 1-D vectors supported");
+
+    // Only support memref base (not tensor).
+    auto memrefType = dyn_cast<MemRefType>(gatherOp.getBaseType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "only memref base supported");
+    if (!isa_and_nonnull<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "expected spirv.storage_class");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    // The type converter turns single-element vectors into scalars and
+    // rejects unsupported sizes; the pointer vector needs a real vector.
+    auto resultType =
+        dyn_cast_or_null<VectorType>(typeConverter.convertType(vectorType));
+    if (!resultType)
+      return rewriter.notifyMatchFailure(gatherOp, "unsupported vector type");
+
+    // Compute base element pointer from memref + offsets.
+    auto loc = gatherOp.getLoc();
+    Value basePtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getOffsets(), loc, rewriter);
+    if (!basePtr)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "failed to get element pointer");
+
+    // Lanes are loaded through pointers of the base pointer's type, so its
+    // pointee must be the converted result element type.
+    auto ptrType = dyn_cast<spirv::PointerType>(basePtr.getType());
+    if (!ptrType || ptrType.getPointeeType() != resultType.getElementType())
+      return rewriter.notifyMatchFailure(gatherOp, "unsupported element type");
+
+    Value ptrVector = buildGatherScatterPtrVector(
+        typeConverter, ptrType, basePtr, adaptor.getIndices(),
+        resultType.getNumElements(), loc, rewriter);
+
+    // Alignment operand: 0 when the op carries no alignment attribute.
+    uint32_t align = gatherOp.getAlignment().value_or(0);
+    Value alignmentVal =
+        spirv::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                  rewriter.getI32IntegerAttr(align));
+
+    // Create spirv.INTEL.MaskedGather.
+    rewriter.replaceOpWithNewOp<spirv::INTELMaskedGatherOp>(
+        gatherOp, resultType, ptrVector, alignmentVal, adaptor.getMask(),
+        adaptor.getPassThru());
+    return success();
+  }
+};
+
+/// Lowers a 1-D vector.scatter to a memref in a SPIR-V storage class to
+/// spirv.INTEL.MaskedScatter (SPV_INTEL_masked_gather_scatter).
+struct VectorScatterOpConverter final
+    : public OpConversionPattern<vector::ScatterOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(vector::ScatterOp scatterOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only support 1-D vectors.
+    auto vectorType = scatterOp.getVectorType();
+    if (vectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "only 1-D vectors supported");
+
+    // Only support memref base (not tensor).
+    auto memrefType = dyn_cast<MemRefType>(scatterOp.getBaseType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "only memref base supported");
+    if (!isa_and_nonnull<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "expected spirv.storage_class");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    // The type converter turns single-element vectors into scalars and
+    // rejects unsupported sizes; the pointer vector needs a real vector.
+    auto valueType =
+        dyn_cast_or_null<VectorType>(typeConverter.convertType(vectorType));
+    if (!valueType)
+      return rewriter.notifyMatchFailure(scatterOp, "unsupported vector type");
+
+    // Compute base element pointer from memref + offsets.
+    auto loc = scatterOp.getLoc();
+    Value basePtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getOffsets(), loc, rewriter);
+    if (!basePtr)
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "failed to get element pointer");
+
+    // Lanes are stored through pointers of the base pointer's type, so its
+    // pointee must be the converted element type of the stored vector.
+    auto ptrType = dyn_cast<spirv::PointerType>(basePtr.getType());
+    if (!ptrType || ptrType.getPointeeType() != valueType.getElementType())
+      return rewriter.notifyMatchFailure(scatterOp, "unsupported element type");
+
+    Value ptrVector = buildGatherScatterPtrVector(
+        typeConverter, ptrType, basePtr, adaptor.getIndices(),
+        valueType.getNumElements(), loc, rewriter);
+
+    // Alignment operand: 0 when the op carries no alignment attribute.
+    uint32_t align = scatterOp.getAlignment().value_or(0);
+    Value alignmentVal =
+        spirv::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                  rewriter.getI32IntegerAttr(align));
+
+    // Create spirv.INTEL.MaskedScatter.
+    spirv::INTELMaskedScatterOp::create(rewriter, loc, ptrVector, alignmentVal,
+                                        adaptor.getMask(),
+                                        adaptor.getValueToStore());
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct VectorReductionToIntDotProd final
     : OpRewritePattern<vector::ReductionOp> {
   using Base::Base;
@@ -1112,8 +1271,9 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
       VectorScalarBroadcastPattern, VectorLoadOpConverter,
-      VectorStoreOpConverter, VectorStepOpConvert>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorStoreOpConverter, VectorGatherOpConverter, VectorScatterOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 48a1298bc4877..9f275eb39a56a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1224,3 +1224,108 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
   %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
   return %0: vector<4xf32>
 }
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Kernel, Addresses, MaskedGatherScatterINTEL],
+               [SPV_INTEL_masked_gather_scatter]>,
+    #spirv.resource_limits<>>
+} { // vector.gather/scatter -> spirv.INTEL.MaskedGather/MaskedScatter tests.
+
+// CHECK-LABEL: @vector_gather
+//  CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+//  CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+//  CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+//  CHECK-SAME: %[[PT:.+]]: vector<4xf32>
+//   CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+//       CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+//       CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+//  CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+//       CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+//       CHECK: %[[RES:.+]] = spirv.INTEL.MaskedGather %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[PT]]
+//  CHECK-SAME:   : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+//       CHECK: return %[[RES]]
+func.func @vector_gather(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+                          %indices: vector<4xindex>,
+                          %mask: vector<4xi1>,
+                          %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+      vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_scatter
+//  CHECK-SAME: %[[BASE:.+]]: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>
+//  CHECK-SAME: %[[INDICES:.+]]: vector<4xindex>
+//  CHECK-SAME: %[[MASK:.+]]: vector<4xi1>
+//  CHECK-SAME: %[[VALUES:.+]]: vector<4xf32>
+//   CHECK-DAG: %[[IDXVEC:.+]] = builtin.unrealized_conversion_cast %[[INDICES]] : vector<4xindex> to vector<4xi32>
+//       CHECK: %[[EPTR:.+]] = spirv.AccessChain %{{.+}}[%{{.+}}] : !spirv.ptr<!spirv.array<16 x f32>, CrossWorkgroup>, i32
+//       CHECK: %[[IDX0:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P0:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX0]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[IDX1:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P1:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[IDX2:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P2:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX2]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[IDX3:.+]] = spirv.VectorExtractDynamic %[[IDXVEC]][%{{.+}}] : vector<4xi32>, i32
+//       CHECK: %[[P3:.+]] = spirv.PtrAccessChain %[[EPTR]][%[[IDX3]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK: %[[PTRVEC:.+]] = spirv.CompositeConstruct %[[P0]], %[[P1]], %[[P2]], %[[P3]]
+//  CHECK-SAME: -> vector<4x!spirv.ptr<f32, CrossWorkgroup>>
+//       CHECK: %[[ALIGN:.+]] = spirv.Constant 0 : i32
+//       CHECK: spirv.INTEL.MaskedScatter %[[PTRVEC]], %[[ALIGN]], %[[MASK]], %[[VALUES]]
+//  CHECK-SAME:   : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+//       CHECK: return
+func.func @vector_scatter(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+                           %indices: vector<4xindex>,
+                           %mask: vector<4xi1>,
+                           %values: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.scatter %base[%c0][%indices], %mask, %values
+    : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+      vector<4xindex>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_gather_i32
+//  CHECK-SAME: %[[BASE:.+]]: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK: spirv.INTEL.MaskedGather
+//  CHECK-SAME: vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+func.func @vector_gather_i32(%base: memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+                              %indices: vector<4xindex>,
+                              %mask: vector<4xi1>,
+                              %pass_thru: vector<4xi32>) -> vector<4xi32> {
+  %c0 = arith.constant 0 : index
+  %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : memref<16xi32, #spirv.storage_class<CrossWorkgroup>>,
+      vector<4xindex>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+  return %result : vector<4xi32>
+}
+
+// CHECK-LABEL: @vector_gather_with_alignment
+//       CHECK: %[[ALIGN:.+]] = spirv.Constant 4 : i32
+//       CHECK: spirv.INTEL.MaskedGather %{{.+}}, %[[ALIGN]]
+func.func @vector_gather_with_alignment(%base: memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+                                         %indices: vector<4xindex>,
+                                         %mask: vector<4xi1>,
+                                         %pass_thru: vector<4xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %result = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    { alignment = 4 : i64 }
+    : memref<16xf32, #spirv.storage_class<CrossWorkgroup>>,
+      vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+} // end module



More information about the Mlir-commits mailing list