[Mlir-commits] [mlir] [mlir][vector][spirv] Lower vector.transfer_read and vector.transfer_write to SPIR-V (PR #69708)
Lei Zhang
llvmlistbot at llvm.org
Fri Oct 27 09:36:47 PDT 2023
================
@@ -509,6 +509,87 @@ struct VectorShuffleOpConvert final
}
};
+/// Converts a `vector.transfer_read` from a memref into a single `spirv.Load`
+/// of the whole (converted) vector through a pointer to the first element.
+///
+/// A plain load has none of transfer_read's masking, padding, or permutation
+/// semantics, so only transfer_reads equivalent to a contiguous in-bounds load
+/// are matched:
+///   * no mask,
+///   * a minor-identity permutation map (no transposition or broadcast),
+///   * every dimension in-bounds (so the padding value is never observed),
+///   * a memref source whose innermost dimension has unit stride,
+///   * a memref memory space that is a `spirv::StorageClassAttr`,
+///   * a result vector type that the SPIR-V type converter accepts.
+/// These checks all run before any IR is created.
+struct VectorTransferReadOpConverter final
+    : public OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp transferReadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (transferReadOp.getMask())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "unsupported transfer_read with mask");
+    // A non-minor-identity map permutes or broadcasts dimensions, which a
+    // single contiguous load cannot express.
+    if (!transferReadOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "unsupported non-minor-identity permutation map");
+    // Out-of-bounds dims must yield the padding value; a raw load would read
+    // past the end of the memref instead.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "unsupported transfer_read with out-of-bounds dims");
+
+    auto memrefType =
+        dyn_cast<MemRefType>(transferReadOp.getSource().getType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+    // Loading a whole vector assumes consecutive elements along the last dim.
+    if (!isLastMemrefDimUnitStride(memrefType))
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "unsupported non-unit innermost memref stride");
+
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "expected spirv.storage_class memory space");
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    // Load with the SPIR-V form of the vector type (e.g. index elements or
+    // single-element vectors are rewritten by the converter).
+    Type vectorType = typeConverter.convertType(transferReadOp.getVectorType());
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "unsupported result vector type");
+
+    Location loc = transferReadOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "failed to get memref element pointer");
+
+    auto vectorPtrType = spirv::PointerType::get(vectorType, attr.getValue());
+    // spirv.Bitcast rejects identical operand/result types, so only cast when
+    // the element pointer does not already point at the loaded type.
+    Value vectorPtr = accessChain;
+    if (accessChain.getType() != vectorPtrType)
+      vectorPtr =
+          rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(transferReadOp, vectorType,
+                                               vectorPtr);
+    return success();
+  }
+};
+
+struct VectorTransferWriteOpConverter final
----------------
antiagainst wrote:
The comments above apply to this pattern too.
https://github.com/llvm/llvm-project/pull/69708
More information about the Mlir-commits mailing list