[Mlir-commits] [mlir] f4a2dbf - [MLIR][SCF] Combine adjacent scf.if with same condition

William S. Moses llvmlistbot at llvm.org
Tue May 4 21:40:15 PDT 2021


Author: William S. Moses
Date: 2021-05-05T00:39:58-04:00
New Revision: f4a2dbfe29031f02c02d6045159f22785dd611cf

URL: https://github.com/llvm/llvm-project/commit/f4a2dbfe29031f02c02d6045159f22785dd611cf
DIFF: https://github.com/llvm/llvm-project/commit/f4a2dbfe29031f02c02d6045159f22785dd611cf.diff

LOG: [MLIR][SCF] Combine adjacent scf.if with same condition

Differential Revision: https://reviews.llvm.org/D101798

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index c3c64e04ae08..fcaba0e6f2b2 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -311,6 +311,10 @@ def IfOp : SCF_Op<"if",
       return results().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                : OpBuilder::atBlockEnd(body, listener);
     }
+    Block* thenBlock();
+    YieldOp thenYield();
+    Block* elseBlock();
+    YieldOp elseYield();
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index c28e438fc819..91f1e7a3e7c0 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1266,15 +1266,125 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
   }
 };
 
+/// Merge any consecutive scf.if's with the same condition.
+///
+///    scf.if %cond {
+///       firstCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///    }
+///    %res = scf.if %cond {
+///       secondCodeTrue();...
+///    } else {
+///       secondCodeFalse();...
+///    }
+///
+///  becomes
+///    %res = scf.if %cond {
+///       firstCodeTrue();...
+///       secondCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///       secondCodeFalse();...
+///    }
+///
+/// The combined op yields the results of the first if followed by the results
+/// of the second. The pattern does not fire when a result of the first if is
+/// used inside the second. All region and block movement goes through the
+/// rewriter so that the pattern driver is notified of every IR change.
+struct CombineIfs : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp nextIf,
+                                PatternRewriter &rewriter) const override {
+    // nextIf must be immediately preceded by another scf.if. When nextIf is
+    // the first op of its block the previous node is null and this also fails.
+    auto prevIf = dyn_cast_or_null<IfOp>(nextIf->getPrevNode());
+    if (!prevIf)
+      return failure();
+
+    if (nextIf.condition() != prevIf.condition())
+      return failure();
+
+    // Don't permit merging if a result of the first if is used within the
+    // second: after merging, that result is only produced once the combined
+    // op has finished, so it cannot be consumed inside it.
+    if (llvm::any_of(prevIf->getUsers(),
+                     [&](Operation *user) { return nextIf->isAncestor(user); }))
+      return failure();
+
+    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
+    llvm::append_range(mergedTypes, nextIf.getResultTypes());
+
+    // Place the combined op where nextIf is. It then follows prevIf, and the
+    // remaining users of prevIf's results (all located after nextIf) are
+    // still dominated.
+    rewriter.setInsertionPoint(nextIf);
+    IfOp combinedIf = rewriter.create<IfOp>(
+        nextIf.getLoc(), mergedTypes, nextIf.condition(), /*hasElse=*/false);
+    // Drop the freshly built then-block; prevIf's then-region replaces it.
+    rewriter.eraseBlock(&combinedIf.thenRegion().back());
+
+    // Capture both terminators before their blocks are moved or merged.
+    YieldOp thenYield = prevIf.thenYield();
+    YieldOp thenYield2 = nextIf.thenYield();
+
+    rewriter.inlineRegionBefore(prevIf.thenRegion(), combinedIf.thenRegion(),
+                                combinedIf.thenRegion().begin());
+
+    // Append nextIf's then-body after prevIf's, then replace the two yields
+    // with a single yield carrying both operand lists (prevIf's first).
+    rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
+    rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
+
+    SmallVector<Value> mergedYields(thenYield.getOperands());
+    llvm::append_range(mergedYields, thenYield2.getOperands());
+    rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
+    rewriter.eraseOp(thenYield);
+    rewriter.eraseOp(thenYield2);
+
+    // Move prevIf's else-region (which may be empty) into the combined op.
+    rewriter.inlineRegionBefore(prevIf.elseRegion(), combinedIf.elseRegion(),
+                                combinedIf.elseRegion().begin());
+
+    if (!nextIf.elseRegion().empty()) {
+      if (combinedIf.elseRegion().empty()) {
+        // prevIf had no else-region: reuse nextIf's else-region unchanged.
+        rewriter.inlineRegionBefore(nextIf.elseRegion(),
+                                    combinedIf.elseRegion(),
+                                    combinedIf.elseRegion().begin());
+      } else {
+        // Both ifs have an else-region. Append nextIf's else-body and merge
+        // the yields, the same way as for the then-region.
+        YieldOp elseYield = combinedIf.elseYield();
+        YieldOp elseYield2 = nextIf.elseYield();
+        rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());
+
+        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
+
+        SmallVector<Value> mergedElseYields(elseYield.getOperands());
+        llvm::append_range(mergedElseYields, elseYield2.getOperands());
+
+        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
+        rewriter.eraseOp(elseYield);
+        rewriter.eraseOp(elseYield2);
+      }
+    }
+
+    // The leading results of combinedIf belong to prevIf, the rest to nextIf.
+    unsigned numPrevResults = prevIf.getNumResults();
+    auto combinedResults = combinedIf.getResults();
+    rewriter.replaceOp(prevIf, combinedResults.take_front(numPrevResults));
+    rewriter.replaceOp(nextIf, combinedResults.drop_front(numPrevResults));
+    return success();
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results
-      .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
-           ConditionPropagation, ReplaceIfYieldWithConditionOrValue>(context);
+  // Register the scf.if canonicalization patterns. CombineIfs merges an scf.if
+  // into an immediately preceding scf.if that has the same condition.
+  results.add<RemoveUnusedResults, RemoveStaticCondition,
+              ConvertTrivialIfToSelect, ConditionPropagation,
+              ReplaceIfYieldWithConditionOrValue, CombineIfs>(context);
 }
 
+/// Returns the last block of the then-region.
+Block *IfOp::thenBlock() {
+  Region &region = thenRegion();
+  return &region.back();
+}
+
+/// Returns the terminator of thenBlock() viewed as a YieldOp; `cast` asserts
+/// if the last operation is not a YieldOp.
+YieldOp IfOp::thenYield() {
+  Operation *terminator = &thenBlock()->back();
+  return cast<YieldOp>(terminator);
+}
+
+/// Returns the last block of the else-region. Must not be called when the
+/// else-region is empty; callers check `elseRegion().empty()` first.
+Block *IfOp::elseBlock() {
+  Region &region = elseRegion();
+  return &region.back();
+}
+
+/// Returns the terminator of elseBlock() viewed as a YieldOp; same
+/// non-empty-else precondition as elseBlock().
+YieldOp IfOp::elseYield() {
+  Operation *terminator = &elseBlock()->back();
+  return cast<YieldOp>(terminator);
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 3ba8e8023155..8f12c90b7729 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -747,3 +747,104 @@ func @while_cond_true() {
 // CHECK-NEXT:           "test.use"(%[[true]]) : (i1) -> ()
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
+
+// -----
+
+// Both ifs yield an i32 and have else-regions: the merged op yields
+// (i32, i32), with the first if's value first in each branch.
+// CHECK-LABEL: @combineIfs
+func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  %res2 = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  return %res, %res2 : i32, i32
+}
+// CHECK-NEXT:     %[[res:.+]]:2 = scf.if %arg0 -> (i32, i32) {
+// CHECK-NEXT:       %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[tval0]], %[[tval]] : i32, i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[fval0]], %[[fval]] : i32, i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
+
+
+// The first if has no results and no else-region: its then-body is placed
+// ahead of the second if's then-body, and the else-region comes from the
+// second if alone.
+// CHECK-LABEL: @combineIfs2
+func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  return %res : i32
+}
+// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[tval]] : i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[fval]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]] : i32
+
+
+// The second if has no results and no else-region: its then-body is appended
+// to the first if's then-body, and the else-region is the first if's alone.
+// CHECK-LABEL: @combineIfs3
+func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return %res : i32
+}
+// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       %[[tval:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT:       scf.yield %[[tval]] : i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[fval]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]] : i32
+
+// Neither if has results or an else-region: the then-bodies are concatenated.
+// CHECK-LABEL: @combineIfs4
+func @combineIfs4(%arg0 : i1, %arg2: i64) {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// CHECK-NEXT:     scf.if %arg0 {
+// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT:     }
+
+// Negative case: a result of the first if is used inside the second if, so
+// the two ops must not be merged.
+// CHECK-LABEL: @combineIfsUsedResult
+func @combineIfsUsedResult(%arg0 : i1) -> i32 {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  %res2 = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"(%res) : (i32) -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"(%res) : (i32) -> i32
+    scf.yield %v2 : i32
+  }
+  return %res2 : i32
+}
+// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK:          scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       "test.secondCodeTrue"(%[[res]]) : (i32) -> i32
+// CHECK:          } else {
+// CHECK-NEXT:       "test.secondCodeFalse"(%[[res]]) : (i32) -> i32


        


More information about the Mlir-commits mailing list