[Mlir-commits] [mlir] c50f335 - [mlir][spirv] `memref.cast` to SPIR-V conversion

Ivan Butygin llvmlistbot at llvm.org
Wed Jul 26 04:20:44 PDT 2023


Author: Ivan Butygin
Date: 2023-07-26T13:20:21+02:00
New Revision: c50f335ba556b2c2780a34c252e7efdc27e06483

URL: https://github.com/llvm/llvm-project/commit/c50f335ba556b2c2780a34c252e7efdc27e06483
DIFF: https://github.com/llvm/llvm-project/commit/c50f335ba556b2c2780a34c252e7efdc27e06483.diff

LOG: [mlir][spirv] `memref.cast` to SPIR-V conversion

Differential Revision: https://reviews.llvm.org/D156251

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index ce1f74cfaeb488..16bedc6b9858e1 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -267,6 +267,39 @@ class ReinterpretCastPattern final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Lowers `memref.cast` by forwarding the already-converted source value.
+///
+/// This only succeeds when the converted source type is identical to the
+/// converted result type, i.e. when the cast becomes a no-op after type
+/// conversion; any other cast is reported as a match failure.
+class CastPattern final : public OpConversionPattern<memref::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value src = adaptor.getSource();
+    Type srcType = src.getType();
+
+    TypeConverter *converter = getTypeConverter();
+    Type dstType = converter->convertType(op.getType());
+    // A null type means the result memref has no SPIR-V equivalent. Reject it
+    // up front so a null Type is never streamed into the diagnostic below.
+    if (!dstType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // Forwarding is only valid when both sides lowered to the same type.
+    if (srcType != dstType)
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "types doesn't match: " << srcType << " and " << dstType;
+      });
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -779,10 +801,12 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 namespace mlir {
+/// Adds the MemRef-to-SPIR-V conversion patterns listed below to `patterns`,
+/// constructing each one with `typeConverter` and the pattern set's context.
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns
-      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
-           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern>(
-          typeConverter, patterns.getContext());
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
+               ReinterpretCastPattern, CastPattern>(typeConverter, ctx);
 }
 } // namespace mlir

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 895316358beab0..2b3678542e8db4 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -566,3 +566,46 @@ func.func @reinterpret_cast_5(%arg: memref<?xf32, #spirv.storage_class<CrossWork
 }
 
 } // end module
+
+
+// -----
+
+// Check memref.cast: it is lowered only when both sides convert to the same type.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0,
+      [Kernel, Addresses, GenericPointer], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func.func @cast
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @cast(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+//       CHECK:  %[[MEM2:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr<f32, CrossWorkgroup> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[MEM2]]
+  %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
+}
+
+// TODO: Not supported yet: static -> dynamic shape cast, memref.cast is kept.
+// CHECK-LABEL: func.func @cast_from_static
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @cast_from_static(%arg: memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] =  memref.cast %[[MEM]] : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[MEM1]]
+  %ret = memref.cast %arg : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> to memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<?x4xf32, #spirv.storage_class<CrossWorkgroup>>
+}
+
+// TODO: Not supported yet: dynamic -> static shape cast, memref.cast is kept.
+// CHECK-LABEL: func.func @cast_to_static
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @cast_to_static(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] =  memref.cast %[[MEM]] : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[MEM1]]
+  %ret = memref.cast %arg : memref<4x?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
+}
+
+}


        


More information about the Mlir-commits mailing list