[Mlir-commits] [mlir] 79d12de - [MLIR][Shape] Canonicalize `assuming_all` when all operands are `cstr_eq` ops
Frederik Gossen
llvmlistbot at llvm.org
Fri Apr 9 02:49:56 PDT 2021
Author: Frederik Gossen
Date: 2021-04-09T11:49:29+02:00
New Revision: 79d12ded535b14e9af242944a588da7cea1202c7
URL: https://github.com/llvm/llvm-project/commit/79d12ded535b14e9af242944a588da7cea1202c7
DIFF: https://github.com/llvm/llvm-project/commit/79d12ded535b14e9af242944a588da7cea1202c7.diff
LOG: [MLIR][Shape] Canonicalize `assuming_all` when all operands are `cstr_eq` ops
Differential Revision: https://reviews.llvm.org/D100104
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0529357e35b6..31943368dbad 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -381,9 +381,30 @@ void AssumingOp::build(
// AssumingAllOp
//===----------------------------------------------------------------------===//
+namespace {
+/// Rewrites a `shape.assuming_all` whose operands are all produced by
+/// `shape.cstr_eq` into a single `shape.cstr_eq` over the concatenated shapes.
+///
+/// Merging is only sound when the equality constraints are connected through a
+/// common shape value: `cstr_eq(a, b)` and `cstr_eq(b, c)` together imply that
+/// `a`, `b` and `c` are all equal, but `cstr_eq(a, b)` and `cstr_eq(c, d)` do
+/// NOT imply `a == c`, so collapsing them into `cstr_eq(a, b, c, d)` would
+/// strengthen the constraint. The check is conservative: walking the operands
+/// in order, each `cstr_eq` must share at least one shape value with the ones
+/// collected before it, otherwise the pattern does not apply.
+struct AssumingAllToCstrEqCanonicalization
+    : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> shapes;
+    for (Value v : op.inputs()) {
+      auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+      if (!cstrEqOp)
+        return failure();
+      auto range = cstrEqOp.shapes();
+      // Refuse to merge constraints that are disjoint from everything seen so
+      // far; an empty `cstr_eq` is trivially satisfied and may always merge.
+      bool overlaps = llvm::any_of(
+          range, [&](Value s) { return llvm::is_contained(shapes, s); });
+      if (!shapes.empty() && !range.empty() && !overlaps)
+        return failure();
+      shapes.append(range.begin(), range.end());
+    }
+    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
+    return success();
+  }
+};
+} // namespace
+
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AssumingAllOneOp>(context);
+  // Canonicalization patterns for `shape.assuming_all`; the second one merges
+  // operands that are all produced by `shape.cstr_eq` into a single op.
+ patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
}
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 883f91672c00..0631c37395d2 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -433,6 +433,20 @@ func @cstr_require_no_fold(%arg0: i1) {
return
}
+// -----
+// `assuming_all` with all `cstr_eq` operands can be collapsed into one
+// `cstr_eq` because the constraints share a shape (%b): a == b and b == c
+// together imply a == b == c.
+// CHECK-LABEL: func @assuming_all_to_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+ %c : tensor<3xindex>) -> !shape.witness {
+ // CHECK: %[[RESULT:.*]] = shape.cstr_eq %[[A]], %[[B]], %[[B]], %[[C]]
+ // CHECK: return %[[RESULT]]
+ %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+ %1 = shape.cstr_eq %b, %c : tensor<?xindex>, tensor<3xindex>
+ %2 = shape.assuming_all %0, %1
+ return %2 : !shape.witness
+}
+
// -----
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f
More information about the Mlir-commits
mailing list