[Mlir-commits] [mlir] 62f84c7 - [MLIR][SCF] Allow combining subsequent if statements that yield & negated condition

William S. Moses llvmlistbot at llvm.org
Fri Mar 4 09:07:52 PST 2022


Author: William S. Moses
Date: 2022-03-04T12:07:47-05:00
New Revision: 62f84c73d23a68c1a22391d97e578a304795c86c

URL: https://github.com/llvm/llvm-project/commit/62f84c73d23a68c1a22391d97e578a304795c86c
DIFF: https://github.com/llvm/llvm-project/commit/62f84c73d23a68c1a22391d97e578a304795c86c.diff

LOG: [MLIR][SCF] Allow combining subsequent if statements whose results are used by the next if or whose conditions are negated

This patch extends the existing if-combining canonicalization to also handle the case where a value returned by the first if is used within the body of the second if.

This patch also extends if-combining to support ifs whose conditions are logical negations of each other.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D120924

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index d2eceadb7d78b..23ef80dfb02d8 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1519,51 +1519,98 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
     if (!prevIf)
       return failure();
 
-    if (nextIf.getCondition() != prevIf.getCondition())
-      return failure();
+    // Determine the logical then/else blocks when prevIf's
+    // condition is used. Null means the block does not exist
+    // in that case (e.g. empty else). If neither of these
+    // are set, the two conditions cannot be compared.
+    Block *nextThen = nullptr;
+    Block *nextElse = nullptr;
+    if (nextIf.getCondition() == prevIf.getCondition()) {
+      nextThen = nextIf.thenBlock();
+      if (!nextIf.getElseRegion().empty())
+        nextElse = nextIf.elseBlock();
+    }
+    if (arith::XOrIOp notv =
+            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
+      if (notv.getLhs() == prevIf.getCondition() &&
+          matchPattern(notv.getRhs(), m_One())) {
+        nextElse = nextIf.thenBlock();
+        if (!nextIf.getElseRegion().empty())
+          nextThen = nextIf.elseBlock();
+      }
+    }
+    if (arith::XOrIOp notv =
+            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
+      if (notv.getLhs() == nextIf.getCondition() &&
+          matchPattern(notv.getRhs(), m_One())) {
+        nextElse = nextIf.thenBlock();
+        if (!nextIf.getElseRegion().empty())
+          nextThen = nextIf.elseBlock();
+      }
+    }
 
-    // Don't permit merging if a result of the first if is used
-    // within the second.
-    if (llvm::any_of(prevIf->getUsers(),
-                     [&](Operation *user) { return nextIf->isAncestor(user); }))
+    if (!nextThen && !nextElse)
       return failure();
 
+    SmallVector<Value> prevElseYielded;
+    if (!prevIf.getElseRegion().empty())
+      prevElseYielded = prevIf.elseYield().getOperands();
+    // Replace all uses of return values of op within nextIf with the
+    // corresponding yields
+    for (auto it : llvm::zip(prevIf.getResults(),
+                             prevIf.thenYield().getOperands(), prevElseYielded))
+      for (OpOperand &use :
+           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
+        if (nextThen && nextThen->getParent()->isAncestor(
+                            use.getOwner()->getParentRegion())) {
+          rewriter.startRootUpdate(use.getOwner());
+          use.set(std::get<1>(it));
+          rewriter.finalizeRootUpdate(use.getOwner());
+        } else if (nextElse && nextElse->getParent()->isAncestor(
+                                   use.getOwner()->getParentRegion())) {
+          rewriter.startRootUpdate(use.getOwner());
+          use.set(std::get<2>(it));
+          rewriter.finalizeRootUpdate(use.getOwner());
+        }
+      }
+
     SmallVector<Type> mergedTypes(prevIf.getResultTypes());
     llvm::append_range(mergedTypes, nextIf.getResultTypes());
 
     IfOp combinedIf = rewriter.create<IfOp>(
-        nextIf.getLoc(), mergedTypes, nextIf.getCondition(), /*hasElse=*/false);
+        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
     rewriter.eraseBlock(&combinedIf.getThenRegion().back());
 
-    YieldOp thenYield = prevIf.thenYield();
-    YieldOp thenYield2 = nextIf.thenYield();
-
-    combinedIf.getThenRegion().getBlocks().splice(
-        combinedIf.getThenRegion().getBlocks().begin(),
-        prevIf.getThenRegion().getBlocks());
-
-    rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
-    rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
-
-    SmallVector<Value> mergedYields(thenYield.getOperands());
-    llvm::append_range(mergedYields, thenYield2.getOperands());
-    rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
-    rewriter.eraseOp(thenYield);
-    rewriter.eraseOp(thenYield2);
+    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
+                                combinedIf.getThenRegion(),
+                                combinedIf.getThenRegion().begin());
+
+    if (nextThen) {
+      YieldOp thenYield = combinedIf.thenYield();
+      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
+      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
+      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
+
+      SmallVector<Value> mergedYields(thenYield.getOperands());
+      llvm::append_range(mergedYields, thenYield2.getOperands());
+      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
+      rewriter.eraseOp(thenYield);
+      rewriter.eraseOp(thenYield2);
+    }
 
-    combinedIf.getElseRegion().getBlocks().splice(
-        combinedIf.getElseRegion().getBlocks().begin(),
-        prevIf.getElseRegion().getBlocks());
+    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
+                                combinedIf.getElseRegion(),
+                                combinedIf.getElseRegion().begin());
 
-    if (!nextIf.getElseRegion().empty()) {
+    if (nextElse) {
       if (combinedIf.getElseRegion().empty()) {
-        combinedIf.getElseRegion().getBlocks().splice(
-            combinedIf.getElseRegion().getBlocks().begin(),
-            nextIf.getElseRegion().getBlocks());
+        rewriter.inlineRegionBefore(*nextElse->getParent(),
+                                    combinedIf.getElseRegion(),
+                                    combinedIf.getElseRegion().begin());
       } else {
         YieldOp elseYield = combinedIf.elseYield();
-        YieldOp elseYield2 = nextIf.elseYield();
-        rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());
+        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
+        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
 
         rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
 

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 80af8174d2f59..86c478ec4eb68 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1119,6 +1119,79 @@ func @combineIfs4(%arg0 : i1, %arg2: i64) {
 // CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT:     }
 
+// CHECK-LABEL: @combineIfsUsed
+// CHECK-SAME: %[[arg0:.+]]: i1
+func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  %res2 = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"(%res) : (i32) -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"(%res) : (i32) -> i32
+    scf.yield %v2 : i32
+  }
+  return %res, %res2 : i32, i32
+}
+// CHECK-NEXT:     %[[res:.+]]:2 = scf.if %[[arg0]] -> (i32, i32) {
+// CHECK-NEXT:       %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"(%[[tval0]]) : (i32) -> i32
+// CHECK-NEXT:       scf.yield %[[tval0]], %[[tval]] : i32, i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"(%[[fval0]]) : (i32) -> i32
+// CHECK-NEXT:       scf.yield %[[fval0]], %[[fval]] : i32, i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
+
+// CHECK-LABEL: @combineIfsNot
+// CHECK-SAME: %[[arg0:.+]]: i1
+func @combineIfsNot(%arg0 : i1, %arg2: i64) {
+  %true = arith.constant true
+  %not = arith.xori %arg0, %true : i1
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  scf.if %not {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// CHECK-NEXT:     scf.if %[[arg0]] {
+// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT:     }
+
+// CHECK-LABEL: @combineIfsNot2
+// CHECK-SAME: %[[arg0:.+]]: i1
+func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
+  %true = arith.constant true
+  %not = arith.xori %arg0, %true : i1
+  scf.if %not {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// CHECK-NEXT:     scf.if %[[arg0]] {
+// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT:     }
 // -----
 
 // CHECK-LABEL: func @propagate_into_execute_region


        


More information about the Mlir-commits mailing list