[Mlir-commits] [mlir] a89e55c - [mlir][std] Canonicalize a dim(memref_reshape) into a load from the shape operand

Stephan Herhut llvmlistbot at llvm.org
Fri Nov 20 05:03:17 PST 2020


Author: Stephan Herhut
Date: 2020-11-20T14:03:02+01:00
New Revision: a89e55ca572ee25f6e6104d76a7dcbd6f07ebbe3

URL: https://github.com/llvm/llvm-project/commit/a89e55ca572ee25f6e6104d76a7dcbd6f07ebbe3
DIFF: https://github.com/llvm/llvm-project/commit/a89e55ca572ee25f6e6104d76a7dcbd6f07ebbe3.diff

LOG: [mlir][std] Canonicalize a dim(memref_reshape) into a load from the shape operand

This canonicalization helps propagate shape information through the program.

Differential Revision: https://reviews.llvm.org/D91854

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 1ad3df63c1c9..cae7212c8379 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1753,6 +1753,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
     Optional<int64_t> getConstantIndex();
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 04efc25a92ee..000d61012326 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1555,6 +1555,34 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+namespace {
+/// Canonicalize dim(memref_reshape(%v, %shp), %idx) into load %shp[%idx],
+/// i.e. read the requested extent from the reshape's shape operand.
+struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Insert the load right after the reshape, before any later store to the
+    // shape memref. NOTE(review): confirm dim.index() dominates this point.
+    rewriter.setInsertionPointAfter(reshape);
+    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
+                                        llvm::makeArrayRef({dim.index()}));
+    return success();
+  }
+};
+} // end anonymous namespace.
+
+void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<DimOfMemRefReshape>(context); // dim(memref_reshape) -> load
+}
+
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index ebc59c8dbeac..74401cb6c723 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -95,3 +95,23 @@ func @cmpi_equal_operands(%arg0: i64)
   return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
       : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 }
+
+// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   store
+//   CHECK-NOT:   dim
+//       CHECK:   return %[[DIM]] : index
+func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref_reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  store %c3, %arg1[%c3] : memref<?xindex>
+  %1 = dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}


        


More information about the Mlir-commits mailing list