[Mlir-commits] [mlir] ba916c0 - [mlir][MemRef] Canonicalize reinterpret_cast(extract_strided_metadata)

Quentin Colombet llvmlistbot at llvm.org
Mon Aug 29 10:01:39 PDT 2022


Author: Quentin Colombet
Date: 2022-08-29T17:00:50Z
New Revision: ba916c0cf6d0149f81bf1137e88f7d6fd3b0cc76

URL: https://github.com/llvm/llvm-project/commit/ba916c0cf6d0149f81bf1137e88f7d6fd3b0cc76
DIFF: https://github.com/llvm/llvm-project/commit/ba916c0cf6d0149f81bf1137e88f7d6fd3b0cc76.diff

LOG: [mlir][MemRef] Canonicalize reinterpret_cast(extract_strided_metadata)

Add a canonicalization step for
reinterpret_cast(extract_strided_metadata).
This step replaces this sequence of operations by either:
- A noop, i.e., the original memref is directly used, or
- A plain cast of the original memref

The choice is ultimately made based on whether the original memref type
is equal to what the reinterpret_cast is producing. For instance, the
reinterpret_cast could be changing some dimensions from static to
dynamic and in such case, we need to keep a cast.

The transformation is currently only performed when the reinterpret_cast
uses exactly the same arguments as what the extract_strided_metadata
produces. It may be possible to be more aggressive here but I wanted to
start with a relatively simple MLIR patch for my first one!

Differential Revision: https://reviews.llvm.org/D132776

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 5ef8d8fcefb5a..a9b9d54f0e127 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1142,6 +1142,7 @@ def MemRef_ReinterpretCastOp
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 616b228051e5e..a2c49db4b51b8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1600,6 +1600,65 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
   return nullptr;
 }
 
+/// Return true if every entry of `mixed` is an SSA value and the sequence of
+/// those values is element-wise identical to `values`. Entries held as static
+/// attributes never match: comparing only the dynamic operand list of an op
+/// is not enough to decide that two static/dynamic lists are equal.
+static bool isSameValueSequence(ArrayRef<OpFoldResult> mixed,
+                                ValueRange values) {
+  if (mixed.size() != values.size())
+    return false;
+  for (auto [ofr, value] : llvm::zip(mixed, values))
+    if (ofr.dyn_cast<Value>() != value)
+      return false;
+  return true;
+}
+
+namespace {
+/// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
+///
+/// The rewrite only fires when the reinterpret_cast rebuilds the memref from
+/// exactly the offset, sizes, and strides produced by the
+/// extract_strided_metadata (same SSA values, same order, no static entries).
+/// The replacement is either the original memref itself or a memref.cast of
+/// it when the result type differs from the source type.
+struct ReinterpretCastOpExtractStridedMetadataFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto extractStridedMetadata =
+        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
+    if (!extractStridedMetadata)
+      return failure();
+    // Check if the reinterpret cast reconstructs a memref with the exact same
+    // properties as the extract strided metadata. Compare the *mixed*
+    // (static + dynamic) lists: getStrides()/getSizes()/getOffsets() only
+    // return the dynamic operands, so a cast mixing static and dynamic
+    // entries could otherwise look identical while describing another layout.
+
+    // First, check that the strides are the same.
+    if (!isSameValueSequence(op.getMixedStrides(),
+                             extractStridedMetadata.getStrides()))
+      return failure();
+
+    // Second, check the sizes.
+    if (!isSameValueSequence(op.getMixedSizes(),
+                             extractStridedMetadata.getSizes()))
+      return failure();
+
+    // Finally, check the offset. It must be exactly the dynamic value produced
+    // by the extract: a static offset (e.g. `offset: [1]`) or any other SSA
+    // value may describe a different layout. Both conditions are required
+    // (`||`), and front() is only read once the size is known to be 1.
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    Value extractOffset = extractStridedMetadata.getOffset();
+    if (offsets.size() != 1 ||
+        offsets.front().dyn_cast<Value>() != extractOffset)
+      return failure();
+
+    // At this point, we know that the back and forth between extract strided
+    // metadata and reinterpret cast is a noop. However, the final type of the
+    // reinterpret cast may not be exactly the same as the original memref.
+    // E.g., it could be changing a dimension from static to dynamic. Check that
+    // here and add a cast if necessary.
+    Value source = extractStridedMetadata.getSource();
+    Type srcTy = source.getType();
+    Type dstTy = op.getType();
+    if (srcTy == dstTy) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+    // The result type may still claim static sizes/strides that disagree with
+    // the source type; memref.cast would reject such a pair, so bail out
+    // rather than create invalid IR.
+    if (!CastOp::areCastCompatible(srcTy, dstTy))
+      return failure();
+    rewriter.replaceOpWithNewOp<CastOp>(op, dstTy, source);
+    return success();
+  }
+};
+} // namespace
+
+/// Register the canonicalization patterns of memref.reinterpret_cast.
+void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  // Single pattern for now: fold reinterpret_cast(extract_strided_metadata(x))
+  // back to `x` (or a memref.cast of `x`).
+  using ExtractMetadataFolder = ReinterpretCastOpExtractStridedMetadataFolder;
+  results.add<ExtractMetadataFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 138a0f4474084..454277599dc71 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -740,6 +740,63 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
 
 // -----
 
+// Check that a reinterpret cast of an equivalent extract strided metadata
+// is canonicalized to a plain cast when the destination type is different
+// from the type of the original memref.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mismatch
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_of_extract_strided_metadata_w_type_mismatch(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, offset : ?, strides : [?, ?]> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  return %m2 : memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
+
+// -----
+
+// When the reinterpret cast rebuilds the original memref type from exactly
+// the metadata extracted from it, the whole round trip must fold away and the
+// function argument is returned directly.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//       CHECK: return %[[ARG]]
+func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2xf32>) -> memref<8x2xf32> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %rebuilt = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<8x2xf32>
+  return %rebuilt : memref<8x2xf32>
+}
+
+// -----
+
+// Check that we don't simplify reinterpret cast of extract strided metadata
+// when the strides don't match.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1]
+//       CHECK: return %[[RES]]
+func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+  return %m2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+}
+// -----
+
+// Check that we don't simplify reinterpret cast of extract strided metadata
+// when the offset doesn't match. Here the offset is the static value 1, so the
+// reinterpret cast has no dynamic offset operand at all.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1]
+//       CHECK: return %[[RES]]
+func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, offset : ?, strides : [?, ?]> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  return %m2 : memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
+
+// -----
+
 func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
     %arg1 : index) -> memref<?xf32, offset : ?, strides : [?]> {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list