[Mlir-commits] [mlir] 9337788 - [mlir] add scf.if op canonicalization pattern that removes unused results

Tobias Gysi llvmlistbot at llvm.org
Sun Oct 11 01:44:55 PDT 2020


Author: Tobias Gysi
Date: 2020-10-11T10:40:28+02:00
New Revision: 93377888ae89560ba6d3976e2762d3d4724c4dfd

URL: https://github.com/llvm/llvm-project/commit/93377888ae89560ba6d3976e2762d3d4724c4dfd
DIFF: https://github.com/llvm/llvm-project/commit/93377888ae89560ba6d3976e2762d3d4724c4dfd.diff

LOG: [mlir] add scf.if op canonicalization pattern that removes unused results

The patch adds a canonicalization pattern that removes the unused results of an scf.if operation. As a result, CSE may remove unused computations in the then and else regions of the scf.if operation.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D89029

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index d7ff8b6352bb..476898ab2072 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -262,6 +262,8 @@ def IfOp : SCF_Op<"if",
                                : OpBuilder::atBlockEnd(body, listener);
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def ParallelOp : SCF_Op<"parallel",

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index e36ffc2e6b81..f25ccc454fbc 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -508,6 +508,67 @@ void IfOp::getSuccessorRegions(Optional<unsigned> index,
   regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
 }
 
+namespace {
+// Pattern that rebuilds an IfOp so that it only returns results with uses.
+struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
+                    PatternRewriter &rewriter) const {
+    // Move all operations from `source` into `dest`.
+    rewriter.mergeBlocks(source, dest);
+    // Rewrite the yield in place so it returns only the used values.
+    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
+    SmallVector<Value, 4> usedOperands;
+    llvm::transform(usedResults, std::back_inserter(usedOperands),
+                    [&](OpResult result) {
+                      return yieldOp.getOperand(result.getResultNumber());
+                    });
+    rewriter.updateRootInPlace(
+        yieldOp, [&]() { yieldOp.getOperation()->setOperands(usedOperands); });
+  }
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Collect the results that have at least one use.
+    SmallVector<OpResult, 4> usedResults;
+    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
+                  [](OpResult result) { return !result.use_empty(); });
+
+    // Bail out unless at least one result is unused.
+    if (usedResults.size() == op.getNumResults())
+      return failure();
+
+    // The replacement operation keeps only the types of the used results.
+    SmallVector<Type, 4> newTypes;
+    llvm::transform(usedResults, std::back_inserter(newTypes),
+                    [](OpResult result) { return result.getType(); });
+
+    // Create the replacement op; the no-op builders add nothing to its bodies.
+    auto emptyBuilder = [](OpBuilder &, Location) {};
+    auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.condition(),
+                                       emptyBuilder, emptyBuilder);
+
+    // Transfer both bodies and trim their terminators (an else region is
+    // expected here since the original operation returns results).
+    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
+    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
+
+    // Remap used results to the new op; unused slots stay default-constructed.
+    SmallVector<Value, 4> repResults(op.getNumResults());
+    for (auto en : llvm::enumerate(usedResults))
+      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
+    rewriter.replaceOp(op, repResults);
+    return success();
+  }
+};
+} // namespace
+
+void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                       MLIRContext *context) {
+  results.insert<RemoveUnusedResults>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index fc98dabc0d2d..a96786076109 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -53,3 +53,87 @@ func @no_iteration(%A: memref<?x?xi32>) {
 // CHECK:             scf.yield
 // CHECK:           }
 // CHECK:           return
+
+// -----
+
+func @one_unused() -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    scf.yield %c0, %c1 : index, index
+  } else {
+    scf.yield %c0, %c1 : index, index
+  }
+  return %1 : index
+}
+
+// CHECK-LABEL:   func @one_unused
+// CHECK:           [[C0:%.*]] = constant 1 : index
+// CHECK:           [[C1:%.*]] = constant true
+// CHECK:           [[V0:%.*]] = scf.if [[C1]] -> (index) {
+// CHECK:             scf.yield [[C0]] : index
+// CHECK:           } else
+// CHECK:             scf.yield [[C0]] : index
+// CHECK:           }
+// CHECK:           return [[V0]] : index
+
+// -----
+
+func @nested_unused() -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    %2, %3 = scf.if %true -> (index, index) {
+      scf.yield %c0, %c1 : index, index
+    } else {
+      scf.yield %c0, %c1 : index, index
+    }
+    scf.yield %2, %3 : index, index
+  } else {
+    scf.yield %c0, %c1 : index, index
+  }
+  return %1 : index
+}
+
+// CHECK-LABEL:   func @nested_unused
+// CHECK:           [[C0:%.*]] = constant 1 : index
+// CHECK:           [[C1:%.*]] = constant true
+// CHECK:           [[V0:%.*]] = scf.if [[C1]] -> (index) {
+// CHECK:             [[V1:%.*]] = scf.if [[C1]] -> (index) {
+// CHECK:               scf.yield [[C0]] : index
+// CHECK:             } else
+// CHECK:               scf.yield [[C0]] : index
+// CHECK:             }
+// CHECK:             scf.yield [[V1]] : index
+// CHECK:           } else
+// CHECK:             scf.yield [[C0]] : index
+// CHECK:           }
+// CHECK:           return [[V0]] : index
+
+// -----
+
+func @side_effect() {}
+func @all_unused() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    call @side_effect() : () -> ()
+    scf.yield %c0, %c1 : index, index
+  } else {
+    call @side_effect() : () -> ()
+    scf.yield %c0, %c1 : index, index
+  }
+  return
+}
+
+// CHECK-LABEL:   func @all_unused
+// CHECK:           [[C1:%.*]] = constant true
+// CHECK:           scf.if [[C1]] {
+// CHECK:             call @side_effect() : () -> ()
+// CHECK:           } else
+// CHECK:             call @side_effect() : () -> ()
+// CHECK:           }
+// CHECK:           return


        


More information about the Mlir-commits mailing list