[Mlir-commits] [mlir] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for vector.gather/scatter to SPIR-V (PR #193422)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed May 6 15:46:45 PDT 2026


================
@@ -853,6 +853,175 @@ struct VectorStoreOpConverter final
   }
 };
 
+struct VectorGatherOpConverter final
+    : public OpConversionPattern<vector::GatherOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(vector::GatherOp gatherOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only support 1-D result vectors.
+    auto vectorType = gatherOp.getVectorType();
+    if (vectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "only 1-D vectors supported");
+
+    // Only support memref base (not tensor).
+    auto memrefType = dyn_cast<MemRefType>(gatherOp.getBaseType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "only memref base supported");
+
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "expected spirv.storage_class");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    if (!typeConverter.getTargetEnv().allows(
+            spirv::Extension::SPV_INTEL_masked_gather_scatter))
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "target environment does not enable "
+                                         "SPV_INTEL_masked_gather_scatter");
+    auto loc = gatherOp.getLoc();
+
+    // Compute base element pointer from memref + offsets.
+    Value basePtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getOffsets(), loc, rewriter);
+    if (!basePtr)
+      return rewriter.notifyMatchFailure(gatherOp,
+                                         "failed to get element pointer");
+
+    // Convert element type and construct pointer vector type.
+    auto storageClass = attr.getValue();
+    Type elementType = typeConverter.convertType(memrefType.getElementType());
+    if (!elementType)
+      return rewriter.notifyMatchFailure(gatherOp, "unsupported element type");
+    auto ptrType = spirv::PointerType::get(elementType, storageClass);
+    int64_t numElements = vectorType.getDimSize(0);
+    auto ptrVectorType = VectorType::get({numElements}, ptrType);
+
+    // Build pointer vector: for each index, compute ptr via PtrAccessChain.
+    auto indexType = typeConverter.getIndexType();
----------------
krzysz00 wrote:

... why are we extracting the indices one element at a time and building the pointer vector lane by lane? That looks overcomplicated.

https://github.com/llvm/llvm-project/pull/193422


More information about the Mlir-commits mailing list