[Mlir-commits] [mlir] e413b86 - [MLIR][Shape] Combine `cstr_eq` only if they share shape operands

Frederik Gossen llvmlistbot at llvm.org
Fri Apr 9 07:55:12 PDT 2021


Author: Frederik Gossen
Date: 2021-04-09T16:54:54+02:00
New Revision: e413b86a2c0c0225f562d16da1aa50aa9a5cb7aa

URL: https://github.com/llvm/llvm-project/commit/e413b86a2c0c0225f562d16da1aa50aa9a5cb7aa
DIFF: https://github.com/llvm/llvm-project/commit/e413b86a2c0c0225f562d16da1aa50aa9a5cb7aa.diff

LOG: [MLIR][Shape] Combine `cstr_eq` only if they share shape operands

Differential Revision: https://reviews.llvm.org/D100198

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3fa314d74705..b82b28c88f61 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -389,12 +389,16 @@ struct AssumingAllToCstrEqCanonicalization
   LogicalResult matchAndRewrite(AssumingAllOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value, 8> shapes;
-    for (Value v : op.inputs()) {
-      auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+    for (Value w : op.inputs()) {
+      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
       if (!cstrEqOp)
         return failure();
-      auto range = cstrEqOp.shapes();
-      shapes.append(range.begin(), range.end());
+      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
+        return llvm::is_contained(shapes, s);
+      });
+      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
+        return failure();
+      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
     }
     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
     return success();

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 0631c37395d2..a2aaf7b84afc 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -434,7 +434,7 @@ func @cstr_require_no_fold(%arg0: i1) {
 }
 
 // -----
-// `assuming_all` with all `cstr_eq` can be collapsed.
+// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
 // CHECK-LABEL: func @assuming_all_to_cstr_eq
 // CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
 func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
@@ -447,6 +447,22 @@ func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
   return %2 : !shape.witness
 }
 
+// -----
+// `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
+// CHECK-LABEL: func @assuming_all_to_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>, %[[D:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+    %c : tensor<3xindex>, %d : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[EQ0:.*]] = shape.cstr_eq %[[A]], %[[B]]
+  // CHECK: %[[EQ1:.*]] = shape.cstr_eq %[[C]], %[[D]]
+  // CHECK: %[[RESULT:.*]] = shape.assuming_all %[[EQ0]], %[[EQ1]]
+  // CHECK: return %[[RESULT]]
+  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_eq %c, %d : tensor<3xindex>, tensor<3xindex>
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f


        


More information about the Mlir-commits mailing list