[Mlir-commits] [mlir] edca177 - [mlir] Add canonicalizer to remove redundant shape.cstr_broadcastable ops

Eugene Zhulenev llvmlistbot at llvm.org
Sun Feb 6 14:46:49 PST 2022


Author: Eugene Zhulenev
Date: 2022-02-06T14:46:42-08:00
New Revision: edca177cbeb66bc0f4cb1b1458633b57c1ee33a5

URL: https://github.com/llvm/llvm-project/commit/edca177cbeb66bc0f4cb1b1458633b57c1ee33a5
DIFF: https://github.com/llvm/llvm-project/commit/edca177cbeb66bc0f4cb1b1458633b57c1ee33a5.diff

LOG: [mlir] Add canonicalizer to remove redundant shape.cstr_broadcastable ops

Depends On D119025

Reviewed By: frgossen

Differential Revision: https://reviews.llvm.org/D119043

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ecf9ade0b05c5..a25f6dd7b1cb5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
@@ -493,6 +494,99 @@ struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
   }
 };
 
+// Eliminate `cstr_broadcastable` operands of an `assuming_all` operation that
+// are subsumed by other `cstr_broadcastable` operands.
+//
+//   %0 = shape.cstr_broadcastable %shape0, %shape1
+//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
+//
+//   %2 = shape.cstr_broadcastable %shape3, %shape4
+//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
+//
+//   %4 = shape.assuming_all %0, %1, %2, %3
+//
+// to:
+//
+//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
+//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
+//   %2 = shape.assuming_all %0, %1
+//
+// In this example if shapes [0, 1, 2] are broadcastable, then it means that
+// shapes [0, 1] are broadcastable too, and can be removed from the list of
+// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
+// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
+struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    // Collect all `CstrBroadcastableOp` operands first.
+    SetVector<CstrBroadcastableOp> operands;
+    for (Value operand : op.getInputs()) {
+      // TODO: Apply this optimization if some of the witnesses are not
+      // produced by the `cstr_broadcastable`.
+      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
+      if (!broadcastable)
+        return failure();
+
+      operands.insert(broadcastable);
+    }
+
+    // Skip trivial `assuming_all` operations.
+    if (operands.size() <= 1)
+      return failure();
+
+    // Collect the distinct shapes checked by each `cstr_broadcastable`.
+    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
+    for (CstrBroadcastableOp cstr : operands) {
+      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
+      shapes.emplace_back(cstr, std::move(shapesSet));
+    }
+
+    // Sort by the number of distinct shapes, largest first, so every superset
+    // precedes its subsets. Stable, so equal-sized sets keep operand order.
+    llvm::stable_sort(shapes, [](const auto &a, const auto &b) {
+      return a.second.size() > b.second.size();
+    });
+
+    // Starting from the constraints with the most shapes, drop every later
+    // constraint whose shapes are a subset of the current one, until no
+    // remaining constraint is subsumed by another.
+    SmallVector<CstrBroadcastableOp> markedForErase;
+
+    for (unsigned i = 0; i < shapes.size(); ++i) {
+      // Move constraints subsumed by `shapes[i]` to the tail. Unlike
+      // `std::remove_if`, `std::stable_partition` keeps them intact there.
+      auto *firstRedundant = std::stable_partition(
+          shapes.begin() + i + 1, shapes.end(), [&](const auto &entry) {
+            return !llvm::set_is_subset(entry.second, shapes[i].second);
+          });
+      for (auto *it = firstRedundant; it != shapes.end(); ++it)
+        markedForErase.push_back(it->first);
+      shapes.erase(firstRedundant, shapes.end());
+    }
+
+    // We didn't find any operands that could be removed.
+    if (markedForErase.empty())
+      return failure();
+
+    // Collect the remaining, mutually non-subsumed constraints.
+    SmallVector<Value> uniqueConstraints;
+    for (auto &entry : shapes)
+      uniqueConstraints.push_back(entry.first.getResult());
+
+    // Replace with a new `assuming_all` operation ...
+    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
+
+    // ... and maybe erase `cstr_broadcastable` ops without uses.
+    for (CstrBroadcastableOp redundant : markedForErase)
+      if (redundant->use_empty())
+        rewriter.eraseOp(redundant);
+
+    return success();
+  }
+};
+
 struct AssumingAllToCstrEqCanonicalization
     : public OpRewritePattern<AssumingAllOp> {
   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@@ -539,9 +633,10 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
 
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
-               AssumingAllToCstrEqCanonicalization,
-               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
+  patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
+               AssumingAllOfCstrBroadcastable,
+               AssumingAllToCstrEqCanonicalization,
+               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
 }
 
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 425f8fe71a42f..470f22190f738 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -565,6 +565,46 @@ func @f() {
 
 // -----
 
+// Merge cstr_broadcastable operations: the constraint on %arg0, %arg1 is
+// subsumed by the one on %arg0, %arg1, %arg2, so only the latter remains.
+// CHECK-LABEL: func @f
+// CHECK:         %[[ARG0:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME:    %[[ARG1:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME:    %[[ARG2:[a-z0-9]*]]: !shape.shape
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
+  // CHECK-NEXT: %[[W:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]], %[[ARG2]]
+  // CHECK-NEXT: "consume.witness"(%[[W]])
+  // CHECK-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
+  %1 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : !shape.shape, !shape.shape, !shape.shape
+  %2 = shape.assuming_all %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Do not merge cstr_broadcastable operations: the shape sets {%arg0, %arg1}
+// and {%arg1, %arg2} overlap, but neither is a subset of the other.
+// CHECK-LABEL: func @f
+// CHECK:         %[[ARG0:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME:    %[[ARG1:[a-z0-9]*]]: !shape.shape
+// CHECK-SAME:    %[[ARG2:[a-z0-9]*]]: !shape.shape
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
+  // CHECK-NEXT: %[[W0:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
+  // CHECK-NEXT: %[[W1:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
+  // CHECK-NEXT: %[[W2:.*]] = shape.assuming_all %[[W0]], %[[W1]]
+  // CHECK-NEXT: "consume.witness"(%[[W2]])
+  // CHECK-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
+  %1 = shape.cstr_broadcastable %arg1, %arg2 : !shape.shape, !shape.shape
+  %2 = shape.assuming_all %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
 func @f(%arg : !shape.shape) -> !shape.shape {


        


More information about the Mlir-commits mailing list