[Mlir-commits] [mlir] e413b86 - [MLIR][Shape] Combine `cstr_eq` only if they share shape operands
Frederik Gossen
llvmlistbot at llvm.org
Fri Apr 9 07:55:12 PDT 2021
Author: Frederik Gossen
Date: 2021-04-09T16:54:54+02:00
New Revision: e413b86a2c0c0225f562d16da1aa50aa9a5cb7aa
URL: https://github.com/llvm/llvm-project/commit/e413b86a2c0c0225f562d16da1aa50aa9a5cb7aa
DIFF: https://github.com/llvm/llvm-project/commit/e413b86a2c0c0225f562d16da1aa50aa9a5cb7aa.diff
LOG: [MLIR][Shape] Combine `cstr_eq` only if they share shape operands
Differential Revision: https://reviews.llvm.org/D100198
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 3fa314d74705..b82b28c88f61 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -389,12 +389,16 @@ struct AssumingAllToCstrEqCanonicalization
LogicalResult matchAndRewrite(AssumingAllOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value, 8> shapes;
- for (Value v : op.inputs()) {
- auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+ for (Value w : op.inputs()) {
+ auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
if (!cstrEqOp)
return failure();
- auto range = cstrEqOp.shapes();
- shapes.append(range.begin(), range.end());
+ bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
+ return llvm::is_contained(shapes, s);
+ });
+ if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
+ return failure();
+ shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
}
rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
return success();
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 0631c37395d2..a2aaf7b84afc 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -434,7 +434,7 @@ func @cstr_require_no_fold(%arg0: i1) {
}
// -----
-// `assuming_all` with all `cstr_eq` can be collapsed.
+// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
// CHECK-LABEL: func @assuming_all_to_cstr_eq
// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
@@ -447,6 +447,22 @@ func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
return %2 : !shape.witness
}
+// -----
+// `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
+// CHECK-LABEL: func @assuming_all_to_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>, %[[D:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+ %c : tensor<3xindex>, %d : tensor<3xindex>) -> !shape.witness {
+ // CHECK: %[[EQ0:.*]] = shape.cstr_eq %[[A]], %[[B]]
+ // CHECK: %[[EQ1:.*]] = shape.cstr_eq %[[C]], %[[D]]
+ // CHECK: %[[RESULT:.*]] = shape.assuming_all %[[EQ0]], %[[EQ1]]
+ // CHECK: return %[[RESULT]]
+ %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+ %1 = shape.cstr_eq %c, %d : tensor<3xindex>, tensor<3xindex>
+ %2 = shape.assuming_all %0, %1
+ return %2 : !shape.witness
+}
+
// -----
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f
More information about the Mlir-commits mailing list