[Mlir-commits] [mlir] 5fca4ce - [mlir][spirv] Lower memref.reinterpret_cast
Ivan Butygin
llvmlistbot at llvm.org
Thu Jul 13 06:54:59 PDT 2023
Author: Ivan Butygin
Date: 2023-07-13T15:54:21+02:00
New Revision: 5fca4ce1fd5aeac13cbb917ab6083c8358ce39ae
URL: https://github.com/llvm/llvm-project/commit/5fca4ce1fd5aeac13cbb917ab6083c8358ce39ae
DIFF: https://github.com/llvm/llvm-project/commit/5fca4ce1fd5aeac13cbb917ab6083c8358ce39ae.diff
LOG: [mlir][spirv] Lower memref.reinterpret_cast
For kernel SPIR-V, we lower memrefs to bare pointers, so memref.reinterpret_cast can be lowered to the source pointer adjusted by the offset value.
Differential Revision: https://reviews.llvm.org/D155011
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 28da42966e7337..ce1f74cfaeb488 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -257,6 +257,16 @@ class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Lowers memref.reinterpret_cast when the source memref has already been
+/// converted to a SPIR-V pointer (the bare-pointer memref model used for
+/// kernel SPIR-V).
+class ReinterpretCastPattern final
+    : public OpConversionPattern<memref::ReinterpretCastOp> {
+public:
+  // Inherit the (TypeConverter, MLIRContext, ...) constructors of the base.
+  using OpConversionPattern<memref::ReinterpretCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -716,6 +726,52 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
return success();
}
+/// Converts memref.reinterpret_cast for the bare-pointer memref model: the
+/// converted source pointer is advanced by the cast's offset with
+/// spirv.InBoundsPtrAccessChain. Sizes and strides are not materialized, so the
+/// pattern only applies when the source value and the converted result type are
+/// the same spirv::PointerType.
+LogicalResult ReinterpretCastPattern::matchAndRewrite(
+    memref::ReinterpretCastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value src = adaptor.getSource();
+  auto srcType = dyn_cast<spirv::PointerType>(src.getType());
+
+  if (!srcType)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "invalid src type " << src.getType();
+    });
+
+  TypeConverter *converter = getTypeConverter();
+
+  // srcType is non-null here, so a null dstType (failed conversion) is also
+  // rejected by this inequality.
+  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
+  if (dstType != srcType)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "invalid dst type " << op.getType();
+    });
+
+  // Take the cast's single offset, either as a static attribute or as an
+  // already type-converted dynamic operand.
+  OpFoldResult offset =
+      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
+          .front();
+  // A zero offset leaves the bare pointer unchanged: forward the source.
+  if (isConstantIntValue(offset, 0)) {
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+
+  Type intType = converter->convertType(rewriter.getIndexType());
+  if (!intType)
+    return rewriter.notifyMatchFailure(op, "failed to convert index type");
+
+  Location loc = op.getLoc();
+  // Materialize the offset as an SSA value of the converted index type.
+  auto offsetValue = [&]() -> Value {
+    if (auto val = dyn_cast<Value>(offset))
+      return val;
+
+    // Use the free-function cast API (as dyn_cast above) rather than the
+    // deprecated PointerUnion::get<T>() member.
+    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
+    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
+    return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+  }();
+
+  // NOTE(review): the offset is applied to the pointer here; confirm that the
+  // load/store lowering of the result memref does not apply its layout offset
+  // a second time.
+  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
+      op, src, offsetValue, std::nullopt);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -723,9 +779,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
namespace mlir {
+/// Registers the MemRef-to-SPIR-V conversion patterns, including
+/// ReinterpretCastPattern, in `patterns`; every pattern is constructed with
+/// `typeConverter` and the pattern set's context.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern>(
+ typeConverter, patterns.getContext());
}
} // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index ef77dc9e75933e..895316358beab0 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -520,3 +520,49 @@ func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
}
} // end module
+
+// -----
+
+// Check lowering of memref.reinterpret_cast for a Kernel target env, where
+// memrefs are converted to bare !spirv.ptr values.
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0,
+ [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
+} {
+
+// Dynamic offset: the converted (i32) index operand is applied to the source
+// pointer with spirv.InBoundsPtrAccessChain.
+// CHECK-LABEL: func.func @reinterpret_cast
+// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %[[OFF:.*]]: index)
+func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %arg1: index) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+// CHECK: %[[OFF1:.*]] = builtin.unrealized_conversion_cast %[[OFF]] : index to i32
+// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: return %[[RET1]]
+ %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+}
+
+// Constant zero offset: the cast is replaced by the source pointer itself and
+// no access chain is emitted.
+// CHECK-LABEL: func.func @reinterpret_cast_0
+// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+// CHECK: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: return %[[RET]]
+ %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+}
+
+// Constant non-zero offset: materialized as a spirv.Constant of the converted
+// index type (i32 here) and then applied with spirv.InBoundsPtrAccessChain.
+// CHECK-LABEL: func.func @reinterpret_cast_5
+// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+// CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+// CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32
+// CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: return %[[RET1]]
+ %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+ return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+}
+
+} // end module
More information about the Mlir-commits mailing list