[Mlir-commits] [mlir] 538254e - [MLIR] Do not yield values from an assuming op that are never used

Frederik Gossen llvmlistbot at llvm.org
Fri Apr 9 02:07:05 PDT 2021


Author: Frederik Gossen
Date: 2021-04-09T11:06:41+02:00
New Revision: 538254e8e0e09e89776f21bd39c23be1f5868fa1

URL: https://github.com/llvm/llvm-project/commit/538254e8e0e09e89776f21bd39c23be1f5868fa1
DIFF: https://github.com/llvm/llvm-project/commit/538254e8e0e09e89776f21bd39c23be1f5868fa1.diff

LOG: [MLIR] Do not yield values from an assuming op that are never used

Differential Revision: https://reviews.llvm.org/D100042

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 388a3a5763b1..0529357e35b6 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -268,12 +268,57 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
     return success();
   }
 };
+
+struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
+
+    // Collect the yield operands whose corresponding op result has uses.
+    SmallVector<Value, 4> newYieldOperands;
+    Value opResult, yieldOperand;
+    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
+      std::tie(opResult, yieldOperand) = it;
+      if (!opResult.getUses().empty()) {
+        newYieldOperands.push_back(yieldOperand);
+      }
+    }
+
+    // Bail out if every result is used, i.e. no yield operand was dropped.
+    if (newYieldOperands.size() == yieldOp->getNumOperands())
+      return failure();
+
+    // Replace the yield op with one that yields only the used operands, then
+    // create the new assuming op and move the old op's region into it.
+    rewriter.setInsertionPointToEnd(body);
+    auto newYieldOp =
+        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
+    rewriter.setInsertionPoint(op);
+    auto newOp = rewriter.create<AssumingOp>(
+        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
+    newOp.doRegion().takeBody(op.doRegion());
+
+    // Map each used old result to the next new result; unused ones get null.
+    SmallVector<Value, 4> replacementValues;
+    auto src = newOp.getResults().begin();
+    for (auto it : op.getResults()) {
+      if (it.getUses().empty())
+        replacementValues.push_back(nullptr);
+      else
+        replacementValues.push_back(*src++);
+    }
+    rewriter.replaceOp(op, replacementValues);
+    return success();
+  }
+};
 } // namespace
 
 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  // If taking a passing witness, inline region.
-  patterns.add<AssumingWithTrue>(context);
+  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
 }
 
 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 86ac4c9af963..883f91672c00 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -543,6 +543,27 @@ func @f() {
   return
 }
 
+// -----
+
+// Remove unused results from assuming ops.
+// CHECK-LABEL: func @unused_assuming_results
+func @unused_assuming_results() {
+  // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %{{.*}} -> (f32) {
+  // CHECK:   %{{.*}} = "produce.redundant"
+  // CHECK:   %[[MEANINGFUL:.*]] = "produce.meaningful"
+  // CHECK:   shape.assuming_yield %[[MEANINGFUL]] : f32
+  // CHECK: }
+  // CHECK: "use"(%[[ASSUMING_RESULT]])
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1:2 = shape.assuming %0 -> (f32, f32) {
+    %2 = "produce.redundant"() : () -> (f32)
+    %3 = "produce.meaningful"() : () -> (f32)
+    shape.assuming_yield %2, %3 : f32, f32
+  }
+  "use"(%1#1) : (f32) -> ()
+  return
+}
+
 // -----
 // Broadcastable with broadcastable constant shapes can be removed.
 // CHECK-LABEL: func @f


        


More information about the Mlir-commits mailing list