[Mlir-commits] [mlir] 79d12de - [MLIR][Shape] Canonicalize `assuming_all` when all operands are `cstr_eq` ops

Frederik Gossen llvmlistbot at llvm.org
Fri Apr 9 02:49:56 PDT 2021


Author: Frederik Gossen
Date: 2021-04-09T11:49:29+02:00
New Revision: 79d12ded535b14e9af242944a588da7cea1202c7

URL: https://github.com/llvm/llvm-project/commit/79d12ded535b14e9af242944a588da7cea1202c7
DIFF: https://github.com/llvm/llvm-project/commit/79d12ded535b14e9af242944a588da7cea1202c7.diff

LOG: [MLIR][Shape] Canonicalize `assuming_all` when all operands are `cstr_eq` ops

Differential Revision: https://reviews.llvm.org/D100104

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0529357e35b6..31943368dbad 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -381,9 +381,46 @@ void AssumingOp::build(
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// Collapses an `assuming_all` whose inputs are all produced by `cstr_eq` ops
+/// into a single `cstr_eq` over the concatenated shape operands.
+///
+/// Merging is only sound when the equalities are connected: given just
+/// `cstr_eq(a, b)` and `cstr_eq(c, d)`, nothing relates `a` to `c`, so
+/// `cstr_eq(a, b, c, d)` would be a strictly stronger constraint. Each
+/// `cstr_eq` after the first must therefore share at least one shape operand
+/// with the shapes collected so far (checked in operand order, so some
+/// mergeable inputs may be conservatively rejected). By transitivity of
+/// equality, the combined `cstr_eq` is then equivalent to the conjunction.
+struct AssumingAllToCstrEqCanonicalization
+    : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> shapes;
+    for (Value v : op.inputs()) {
+      auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+      if (!cstrEqOp)
+        return failure();
+      auto range = cstrEqOp.shapes();
+      // Reject a constraint that is disconnected from the shapes collected so
+      // far; merging it would strengthen the overall constraint.
+      bool disjoint = llvm::none_of(
+          range, [&](Value s) { return llvm::is_contained(shapes, s); });
+      if (!shapes.empty() && !range.empty() && disjoint)
+        return failure();
+      shapes.append(range.begin(), range.end());
+    }
+    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
+    return success();
+  }
+};
+} // namespace
+
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AssumingAllOneOp>(context);
+  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
 }
 
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 883f91672c00..0631c37395d2 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -433,6 +433,20 @@ func @cstr_require_no_fold(%arg0: i1) {
   return
 }
 
+// -----
+// `assuming_all` of `cstr_eq` ops sharing a shape operand can be collapsed.
+// CHECK-LABEL: func @assuming_all_to_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+    %c : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[RESULT:.*]] = shape.cstr_eq %[[A]], %[[B]], %[[B]], %[[C]]
+  // CHECK: return %[[RESULT]]
+  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_eq %b, %c : tensor<?xindex>, tensor<3xindex>
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f


        


More information about the Mlir-commits mailing list