[Mlir-commits] [mlir] 5fca4ce - [mlir][spirv] Lower memref.reinterpret_cast

Ivan Butygin llvmlistbot at llvm.org
Thu Jul 13 06:54:59 PDT 2023


Author: Ivan Butygin
Date: 2023-07-13T15:54:21+02:00
New Revision: 5fca4ce1fd5aeac13cbb917ab6083c8358ce39ae

URL: https://github.com/llvm/llvm-project/commit/5fca4ce1fd5aeac13cbb917ab6083c8358ce39ae
DIFF: https://github.com/llvm/llvm-project/commit/5fca4ce1fd5aeac13cbb917ab6083c8358ce39ae.diff

LOG: [mlir][spirv] Lower memref.reinterpret_cast

For kernel SPIR-V, memrefs are lowered to bare pointers, so memref.reinterpret_cast can be lowered to the source pointer adjusted by the offset value.

Differential Revision: https://reviews.llvm.org/D155011

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 28da42966e7337..ce1f74cfaeb488 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -257,6 +257,16 @@ class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Lowers memref.reinterpret_cast when the source memref has been converted
+/// to a bare !spirv.ptr value (as done for kernel SPIR-V targets): the result
+/// is the source pointer advanced by the cast's offset.
+class ReinterpretCastPattern final
+    : public OpConversionPattern<memref::ReinterpretCastOp> {
+public:
+  // Inherit OpConversionPattern's constructors.
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -716,6 +726,52 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   return success();
 }
 
+/// Lowers memref.reinterpret_cast to pointer arithmetic on the converted
+/// source. Only the offset is used; sizes and strides are not read, since the
+/// converted result is a bare !spirv.ptr that carries no shape.
+///
+/// Fails to match if the converted source is not a !spirv.ptr, if the result
+/// type does not convert to that same pointer type, or (for a non-zero
+/// offset) if `index` cannot be converted to an integer type.
+LogicalResult ReinterpretCastPattern::matchAndRewrite(
+    memref::ReinterpretCastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // The adaptor yields the already-converted source value.
+  Value src = adaptor.getSource();
+  auto srcType = dyn_cast<spirv::PointerType>(src.getType());
+
+  // NOTE(review): any !spirv.ptr is accepted here, not only a kernel-style
+  // bare element pointer; confirm that lowerings whose pointee is not the
+  // element type cannot reach this pattern, since the offset below is applied
+  // in units of the pointee type.
+  if (!srcType)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "invalid src type " << src.getType();
+    });
+
+  TypeConverter *converter = getTypeConverter();
+
+  // The cast only moves the pointer, so the result must convert to exactly
+  // the source pointer type. A failed conversion presumably yields a null
+  // type, which also fails this comparison.
+  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
+  if (dstType != srcType)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "invalid dst type " << op.getType();
+    });
+
+  // Merge static and dynamic offsets into OpFoldResults; `.front()` assumes
+  // the op carries exactly one offset entry (the `offset: [...]` list).
+  OpFoldResult offset =
+      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
+          .front();
+  // A constant zero offset needs no access chain: forward the source pointer.
+  if (isConstantIntValue(offset, 0)) {
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+
+  // Offsets are expressed in the integer type that `index` converts to
+  // (i32 in this commit's tests).
+  Type intType = converter->convertType(rewriter.getIndexType());
+  if (!intType)
+    return rewriter.notifyMatchFailure(op, "failed to convert index type");
+
+  Location loc = op.getLoc();
+  // A dynamic offset comes from the adaptor and is therefore already a
+  // converted value; a static offset is materialized as a spirv.Constant.
+  auto offsetValue = [&]() -> Value {
+    if (auto val = dyn_cast<Value>(offset))
+      return val;
+
+    // `cast` asserts if the static offset is not an IntegerAttr.
+    int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
+    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
+    return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+  }();
+
+  // Advance `src` by passing `offsetValue` as the access chain's element
+  // operand, with no further indices (std::nullopt).
+  // NOTE(review): if the result memref type declares a static offset, confirm
+  // that downstream load/store lowering does not apply it a second time on
+  // top of this adjusted pointer.
+  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
+      op, src, offsetValue, std::nullopt);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -723,9 +779,10 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
-               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern>(
-      typeConverter, patterns.getContext());
+  // Every pattern is constructed with the SPIR-V type converter and the
+  // context that owns `patterns`.
+  patterns
+      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern>(
+          typeConverter, patterns.getContext());
 }
 } // namespace mlir

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index ef77dc9e75933e..895316358beab0 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -520,3 +520,49 @@ func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
 }
 
 } // end module
+
+// -----
+
+// Tests for lowering memref.reinterpret_cast. The target environment enables
+// the Kernel capability, under which memrefs lower to bare !spirv.ptr values.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0,
+      [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
+} {
+
+// Dynamic offset: the converted index operand (i32) feeds
+// spirv.InBoundsPtrAccessChain on the converted source pointer.
+// CHECK-LABEL: func.func @reinterpret_cast
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %[[OFF:.*]]: index)
+func.func @reinterpret_cast(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>, %arg1: index) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+//       CHECK:  %[[OFF1:.*]] = builtin.unrealized_conversion_cast %[[OFF]] : index to i32
+//       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF1]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[RET1]]
+  %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+}
+
+// Constant zero offset: no access chain is emitted; the converted source
+// pointer is forwarded unchanged.
+// CHECK-LABEL: func.func @reinterpret_cast_0
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @reinterpret_cast_0(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+//       CHECK:  %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[RET]]
+  %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+}
+
+// Constant non-zero offset: the offset is materialized as an i32
+// spirv.Constant and fed to spirv.InBoundsPtrAccessChain.
+// CHECK-LABEL: func.func @reinterpret_cast_5
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+//       CHECK:  %[[OFF:.*]] = spirv.Constant 5 : i32
+//       CHECK:  %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr<f32, CrossWorkgroup>, i32
+//       CHECK:  %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[RET1]]
+  %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?xf32, strided<[1], offset: ?>, #spirv.storage_class<CrossWorkgroup>>
+}
+
+} // end module


        


More information about the Mlir-commits mailing list