[Mlir-commits] [mlir] a89e55c - [mlir][std] Canonicalize a dim(memref_reshape) into a load from the shape operand
Stephan Herhut
llvmlistbot at llvm.org
Fri Nov 20 05:03:17 PST 2020
Author: Stephan Herhut
Date: 2020-11-20T14:03:02+01:00
New Revision: a89e55ca572ee25f6e6104d76a7dcbd6f07ebbe3
URL: https://github.com/llvm/llvm-project/commit/a89e55ca572ee25f6e6104d76a7dcbd6f07ebbe3
DIFF: https://github.com/llvm/llvm-project/commit/a89e55ca572ee25f6e6104d76a7dcbd6f07ebbe3.diff
LOG: [mlir][std] Canonicalize a dim(memref_reshape) into a load from the shape operand
This canonicalization helps propagate shape information through the program.
Differential Revision: https://reviews.llvm.org/D91854
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 1ad3df63c1c9..cae7212c8379 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1753,6 +1753,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
Optional<int64_t> getConstantIndex();
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 04efc25a92ee..000d61012326 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1555,6 +1555,34 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+/// Returns true if `value` is a block argument of, or a result of an operation
+/// placed before, `op` in `op`'s own block or in one of the blocks enclosing
+/// it. Such a value is available at the point directly after `op`. This is a
+/// cheap structural check that avoids computing DominanceInfo; it may return
+/// false for values that do dominate `op` (e.g. values defined in a dominating
+/// predecessor block), which only makes callers more conservative.
+static bool isDefinedBefore(Value value, Operation *op) {
+  Block *defBlock = value.getParentBlock();
+  if (!defBlock)
+    return false;
+  // The operation in `defBlock` that is, or transitively contains, `op`.
+  Operation *ancestor = defBlock->findAncestorOpInBlock(*op);
+  if (!ancestor)
+    return false;
+  Operation *def = value.getDefiningOp();
+  // Block arguments are available to every operation nested in their block.
+  if (!def)
+    return true;
+  // A result is not available inside its own defining op, so `def == ancestor`
+  // is rejected as well.
+  return def != ancestor && def->isBeforeInBlock(ancestor);
+}
+
+namespace {
+/// Fold dim of a memref reshape operation to a load into the reshape's shape
+/// operand:
+///
+///   dim(memref_reshape %src(%shape), %idx) -> load %shape[%idx]
+///
+/// The load is created directly after the reshape so that it reads the shape
+/// memref before any later store can modify it. This requires `%idx` to be
+/// available at that point; if that cannot be shown cheaply, the pattern does
+/// not apply. If the loaded element type differs from the dim's result type,
+/// the loaded value is converted with index_cast.
+struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();
+    if (!reshape)
+      return failure();
+
+    // The load is inserted right after the reshape, so the index operand must
+    // already be defined there; an index computed after the reshape would
+    // otherwise be used before its definition.
+    Value index = dim.index();
+    if (!isDefinedBefore(index, reshape))
+      return failure();
+
+    // Place the load directly after the reshape to ensure that the shape memref
+    // was not mutated.
+    rewriter.setInsertionPointAfter(reshape);
+    Location loc = dim.getLoc();
+    Value size = rewriter
+                     .create<LoadOp>(loc, reshape.shape(),
+                                     llvm::makeArrayRef({index}))
+                     .getResult();
+    // dim always produces the dim's result type; cast when the shape memref's
+    // element type differs (e.g. an integer shape) so types stay consistent.
+    if (size.getType() != dim.getType())
+      size = rewriter.create<IndexCastOp>(loc, size, dim.getType())
+                 .getResult();
+    rewriter.replaceOp(dim, size);
+    return success();
+  }
+};
+} // end anonymous namespace.
+
+/// Adds DimOp's canonicalization patterns to `results`; currently this is the
+/// DimOfMemRefReshape rewrite (dim of memref_reshape -> load from the shape).
+void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<DimOfMemRefReshape>(context);
+}
+
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index ebc59c8dbeac..74401cb6c723 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -95,3 +95,23 @@ func @cmpi_equal_operands(%arg0: i64)
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
+
+// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: store
+// CHECK-NOT: dim
+// CHECK: return %[[DIM]] : index
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+ -> index {
+ %c3 = constant 3 : index
+ %0 = memref_reshape %arg0(%arg1)
+ : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+ // Overwrite the shape after the reshape, so the test checks that the load is
+ // placed before this store (i.e. it reads the shape the reshape used).
+ store %c3, %arg1[%c3] : memref<?xindex>
+ %1 = dim %0, %c3 : memref<*xf32>
+ return %1 : index
+}
More information about the Mlir-commits
mailing list