[Mlir-commits] [mlir] [mlir][SPIR-V][VectorToSPIRV] Add conversion patterns for vector.gather/scatter to SPIR-V (PR #193422)
Krzysztof Drewniak
llvmlistbot at llvm.org
Wed May 6 15:46:45 PDT 2026
================
@@ -853,6 +853,175 @@ struct VectorStoreOpConverter final
}
};
+struct VectorGatherOpConverter final
+ : public OpConversionPattern<vector::GatherOp> {
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(vector::GatherOp gatherOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only support 1-D result vectors.
+ auto vectorType = gatherOp.getVectorType();
+ if (vectorType.getRank() != 1)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "only 1-D vectors supported");
+
+ // Only support memref base (not tensor).
+ auto memrefType = dyn_cast<MemRefType>(gatherOp.getBaseType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "only memref base supported");
+
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!attr)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "expected spirv.storage_class");
+
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ if (!typeConverter.getTargetEnv().allows(
+ spirv::Extension::SPV_INTEL_masked_gather_scatter))
+ return rewriter.notifyMatchFailure(gatherOp,
+ "target environment does not enable "
+ "SPV_INTEL_masked_gather_scatter");
+ auto loc = gatherOp.getLoc();
+
+ // Compute base element pointer from memref + offsets.
+ Value basePtr =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getOffsets(), loc, rewriter);
+ if (!basePtr)
+ return rewriter.notifyMatchFailure(gatherOp,
+ "failed to get element pointer");
+
+ // Convert element type and construct pointer vector type.
+ auto storageClass = attr.getValue();
+ Type elementType = typeConverter.convertType(memrefType.getElementType());
+ if (!elementType)
+ return rewriter.notifyMatchFailure(gatherOp, "unsupported element type");
+ auto ptrType = spirv::PointerType::get(elementType, storageClass);
+ int64_t numElements = vectorType.getDimSize(0);
+ auto ptrVectorType = VectorType::get({numElements}, ptrType);
+
+ // Build pointer vector: for each index, compute ptr via PtrAccessChain.
+ auto indexType = typeConverter.getIndexType();
----------------
krzysz00 wrote:
... why are we extracting the elements one at a time? That looks overcomplicated.
https://github.com/llvm/llvm-project/pull/193422
More information about the Mlir-commits mailing list