[Mlir-commits] [mlir] 5443d2e - [MLIR][SCF] Simplify scf.if by swapping regions if condition is a not

William S. Moses llvmlistbot at llvm.org
Tue Jan 11 09:57:33 PST 2022


Author: William S. Moses
Date: 2022-01-11T12:57:29-05:00
New Revision: 5443d2ed982dc8e0ccbf2089a78659baec4fcd37

URL: https://github.com/llvm/llvm-project/commit/5443d2ed982dc8e0ccbf2089a78659baec4fcd37
DIFF: https://github.com/llvm/llvm-project/commit/5443d2ed982dc8e0ccbf2089a78659baec4fcd37.diff

LOG: [MLIR][SCF] Simplify scf.if by swapping regions if condition is a not

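Given an scf.if whose condition is a negation (an arith.xori of a value with
the constant true), simplify it by eliminating the not and swapping the
regions. That is, an if of the form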

scf.if not(c) {
  yield origTrue
} else {
  yield origFalse
}

becomes

scf.if c {
  yield origFalse
} else {
  yield origTrue
}

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116990

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index e3ae535f59d3c..c8e51692252d6 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -411,7 +411,7 @@ def IfOp : SCF_Op<"if",
     void getNumRegionInvocations(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<int64_t> &countPerRegion);
   }];
-
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index dd47a55fe6b99..3d6d2052f7fe1 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -13,10 +13,10 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
-
 using namespace mlir;
 using namespace mlir::scf;
 
@@ -1199,6 +1199,30 @@ void IfOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
   }
 }
 
+LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
+                         SmallVectorImpl<OpFoldResult> &results) {
+  // In-place fold: if (!c) then A() else B() -> if c then B() else A()
+  if (getElseRegion().empty())
+    return failure();
+
+  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
+  if (!xorStmt)
+    return failure();
+
+  if (!matchPattern(xorStmt.getRhs(), m_One()))
+    return failure();
+
+  getConditionMutable().assign(xorStmt.getLhs());
+  Block *thenBlock = &getThenRegion().front();
+  // It would be nicer to use iplist::swap, but that has no implemented
+  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
+  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
+                                     getElseRegion().getBlocks());
+  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
+                                     getThenRegion().getBlocks(), thenBlock);
+  return success();
+}
+
 namespace {
 // Pattern to remove unused IfOp results.
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d946c55e11abd..7f424d892b764 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -447,6 +447,29 @@ func @merge_nested_if(%arg0: i1, %arg1: i1) {
 
 // -----
 
+// CHECK-LABEL:   func @if_condition_swap
+// CHECK-NEXT:     %{{.*}} = scf.if %arg0 -> (index) {
+// CHECK-NEXT:       %[[i1:.+]] = "test.origFalse"() : () -> index
+// CHECK-NEXT:       scf.yield %[[i1]] : index
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[i2:.+]] = "test.origTrue"() : () -> index
+// CHECK-NEXT:       scf.yield %[[i2]] : index
+// CHECK-NEXT:     }
+func @if_condition_swap(%cond: i1) -> index {
+  %true = arith.constant true
+  %not = arith.xori %cond, %true : i1
+  %0 = scf.if %not -> (index) {
+    %1 = "test.origTrue"() : () -> index
+    scf.yield %1 : index
+  } else {
+    %1 = "test.origFalse"() : () -> index
+    scf.yield %1 : index
+  }
+  return %0 : index
+}
+
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop
 func @remove_zero_iteration_loop() {
   %c42 = arith.constant 42 : index


        


More information about the Mlir-commits mailing list