[Mlir-commits] [mlir] edca177 - [mlir] Add canonicalizer to remove redundant shape.cstr_broadcastable ops
Eugene Zhulenev
llvmlistbot at llvm.org
Sun Feb 6 14:46:49 PST 2022
Author: Eugene Zhulenev
Date: 2022-02-06T14:46:42-08:00
New Revision: edca177cbeb66bc0f4cb1b1458633b57c1ee33a5
URL: https://github.com/llvm/llvm-project/commit/edca177cbeb66bc0f4cb1b1458633b57c1ee33a5
DIFF: https://github.com/llvm/llvm-project/commit/edca177cbeb66bc0f4cb1b1458633b57c1ee33a5.diff
LOG: [mlir] Add canonicalizer to remove redundant shape.cstr_broadcastable ops
Depends On D119025
Reviewed By: frgossen
Differential Revision: https://reviews.llvm.org/D119043
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ecf9ade0b05c5..a25f6dd7b1cb5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
@@ -493,6 +494,99 @@ struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
}
};
+// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
+// are subsumed by others.
+//
+//   %0 = shape.cstr_broadcastable %shape0, %shape1
+//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
+//
+//   %2 = shape.cstr_broadcastable %shape3, %shape4
+//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
+//
+//   %4 = shape.assuming_all %0, %1, %2, %3
+//
+// to:
+//
+//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
+//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
+//   %2 = shape.assuming_all %0, %1
+//
+// In this example if shapes [0, 1, 2] are broadcastable, then it means that
+// shapes [0, 1] are broadcastable too, and can be removed from the list of
+// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
+// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
+//
+// The pattern fails to match (and leaves the IR untouched) when any operand of
+// the `assuming_all` is not produced by a `cstr_broadcastable`, when there is
+// at most one distinct `cstr_broadcastable` operand, or when no operand is
+// subsumed by another one.
+struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    // Collect all `CstrBroadcastableOp` operands first. The `SetVector`
+    // de-duplicates operations that appear more than once in the operand list.
+    SetVector<CstrBroadcastableOp> operands;
+    for (Value operand : op.getInputs()) {
+      // TODO: Apply this optimization if some of the witnesses are not
+      // produced by the `cstr_broadcastable`.
+      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
+      if (!broadcastable)
+        return failure();
+
+      operands.insert(broadcastable);
+    }
+
+    // Skip trivial `assuming_all` operations.
+    if (operands.size() <= 1)
+      return failure();
+
+    // Collect shapes checked by `cstr_broadcastable` operands.
+    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
+    for (auto cstr : operands) {
+      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
+      shapes.emplace_back(cstr, std::move(shapesSet));
+    }
+
+    // Sort by the number of distinct shapes (larger to smaller). The set size
+    // is used instead of the raw operand count because the subset test below
+    // works on the de-duplicated sets: a strict superset always has a larger
+    // set, so it is guaranteed to be visited before anything it subsumes. The
+    // sort is stable so that ties keep the original operand order.
+    llvm::stable_sort(shapes, [](const auto &a, const auto &b) {
+      return a.second.size() > b.second.size();
+    });
+
+    // We start from the `cstr_broadcastable` operations with largest number of
+    // shapes, and remove redundant `cstr_broadcastable` operations. We do this
+    // until we find a set of `cstr_broadcastable` operations with
+    // non-overlapping constraints.
+    SmallVector<CstrBroadcastableOp> markedForErase;
+
+    for (size_t i = 0; i < shapes.size(); ++i) {
+      // `shapes[i]` is outside of the partitioned range, so it is not moved
+      // while the predicate is being evaluated.
+      auto isNotSubsumed = [&](const auto &candidate) {
+        return !llvm::set_is_subset(candidate.second, shapes[i].second);
+      };
+
+      // Move redundant `cstr_broadcastable` operations to the tail of the
+      // vector. A partition (unlike `std::remove_if`) keeps the rejected
+      // elements intact, so the tail holds exactly the subsumed operations and
+      // each of them is recorded once.
+      auto *redundantBegin = std::stable_partition(
+          shapes.begin() + i + 1, shapes.end(), isNotSubsumed);
+      for (auto *it = redundantBegin; it != shapes.end(); ++it)
+        markedForErase.push_back(it->first);
+      shapes.erase(redundantBegin, shapes.end());
+    }
+
+    // We didn't find any operands that could be removed.
+    if (markedForErase.empty())
+      return failure();
+
+    // Collect non-overlapping `cstr_broadcastable` constraints.
+    SmallVector<Value> uniqueConstraints;
+    for (auto &shape : shapes)
+      uniqueConstraints.push_back(shape.first.getResult());
+
+    // Replace with a new `assuming_all` operation ...
+    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
+
+    // ... and maybe erase `cstr_broadcastable` ops without uses. Operations
+    // that still have other users must stay in the IR.
+    for (CstrBroadcastableOp cstr : markedForErase)
+      if (cstr->use_empty())
+        rewriter.eraseOp(cstr);
+
+    return success();
+  }
+};
+
struct AssumingAllToCstrEqCanonicalization
: public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@@ -539,9 +633,10 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
- AssumingAllToCstrEqCanonicalization,
- RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
+ patterns
+ .add<MergeAssumingAllOps, AssumingAllOneOp,
+ AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
+ RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 425f8fe71a42f..470f22190f738 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -565,6 +565,46 @@ func @f() {
// -----
+// Merge cstr_broadcastable operations: the witness over (%arg0, %arg1) is
+// subsumed by the witness over (%arg0, %arg1, %arg2), so only the latter is
+// expected to remain and to be passed to the consumer directly.
+//
+// CHECK-LABEL: func @f
+// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
+ // CHECK-NEXT: %[[W:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]], %[[ARG2]]
+ // CHECK-NEXT: "consume.witness"(%[[W]])
+ // CHECK-NEXT: return
+ %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
+ %1 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : !shape.shape, !shape.shape, !shape.shape
+ %2 = shape.assuming_all %0, %1
+ "consume.witness"(%2) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+
+// Do not merge cstr_broadcastable operations: neither of the shape lists
+// (%arg0, %arg1) and (%arg1, %arg2) contains the other, so both witnesses and
+// the assuming_all that combines them are expected to stay.
+//
+// CHECK-LABEL: func @f
+// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
+ // CHECK-NEXT: %[[W0:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
+ // CHECK-NEXT: %[[W1:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
+ // CHECK-NEXT: %[[W2:.*]] = shape.assuming_all %[[W0]], %[[W1]]
+ // CHECK-NEXT: "consume.witness"(%[[W2]])
+ // CHECK-NEXT: return
+ %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
+ %1 = shape.cstr_broadcastable %arg1, %arg2 : !shape.shape, !shape.shape
+ %2 = shape.assuming_all %0, %1
+ "consume.witness"(%2) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
func @f(%arg : !shape.shape) -> !shape.shape {
More information about the Mlir-commits
mailing list