[Mlir-commits] [mlir] 5443d2e - [MLIR][SCF] Simplify scf.if by swapping regions if condition is a not
William S. Moses
llvmlistbot at llvm.org
Tue Jan 11 09:57:33 PST 2022
Author: William S. Moses
Date: 2022-01-11T12:57:29-05:00
New Revision: 5443d2ed982dc8e0ccbf2089a78659baec4fcd37
URL: https://github.com/llvm/llvm-project/commit/5443d2ed982dc8e0ccbf2089a78659baec4fcd37
DIFF: https://github.com/llvm/llvm-project/commit/5443d2ed982dc8e0ccbf2089a78659baec4fcd37.diff
LOG: [MLIR][SCF] Simplify scf.if by swapping regions if condition is a not
Given an `scf.if` whose condition is a negation, simplify it by eliminating the `not` and swapping the then/else regions:
scf.if not(c) {
yield origTrue
} else {
yield origFalse
}
becomes
scf.if c {
yield origFalse
} else {
yield origTrue
}
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D116990
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index e3ae535f59d3c..c8e51692252d6 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -411,7 +411,7 @@ def IfOp : SCF_Op<"if",
void getNumRegionInvocations(ArrayRef<Attribute> operands,
SmallVectorImpl<int64_t> &countPerRegion);
}];
-
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index dd47a55fe6b99..3d6d2052f7fe1 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -13,10 +13,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
-
using namespace mlir;
using namespace mlir::scf;
@@ -1199,6 +1199,30 @@ void IfOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
}
}
+/// Folds `scf.if (arith.xori %c, <one>)` into `scf.if %c` by swapping regions:
+///
+///   if (!c) then A() else B()  ->  if c then B() else A()
+///
+/// Applies only when the op has a non-empty else region and its condition is
+/// produced by an `arith.xori` whose right-hand operand matches `m_One()`
+/// (a constant one, i.e. `true` for the i1 condition). On a match the op is
+/// rewritten in place: the condition operand becomes the xori's left-hand
+/// operand and the then/else blocks are exchanged. `results` is left empty and
+/// `success()` is returned, which under MLIR's fold contract signals an
+/// in-place update rather than a replacement. The xori itself is not erased
+/// here.
+///
+/// `operands` (the constant values of the op's operands) is not used.
+LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // if (!c) then A() else B() -> if c then B() else A()
+ // Without an else region there is nothing to swap the then region with.
+ if (getElseRegion().empty())
+ return failure();
+
+ arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
+ if (!xorStmt)
+ return failure();
+
+ // Only the right-hand operand is checked for the constant one.
+ // NOTE(review): a negation spelled `xori %true, %c` (constant on the left)
+ // is not matched here; presumably constants are canonicalized to the rhs
+ // beforehand -- confirm.
+ if (!matchPattern(xorStmt.getRhs(), m_One()))
+ return failure();
+
+ getConditionMutable().assign(xorStmt.getLhs());
+ // Capture the original then block before splicing: the first splice below
+ // prepends the else blocks to the then region, which changes `front()`.
+ // NOTE(review): only this one block is moved back to the else region, so
+ // this assumes each region holds a single block -- confirm against the op's
+ // region constraints.
+ Block *thenBlock = &getThenRegion().front();
+ // It would be nicer to use iplist::swap, but that has no implemented
+ // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
+ // Step 1: move every else block to the front of the then region.
+ getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
+ getElseRegion().getBlocks());
+ // Step 2: move the original then block into the (now empty) else region.
+ getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
+ getThenRegion().getBlocks(), thenBlock);
+ return success();
+}
+
namespace {
// Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d946c55e11abd..7f424d892b764 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -447,6 +447,29 @@ func @merge_nested_if(%arg0: i1, %arg1: i1) {
// -----
+// CHECK-LABEL: func @if_condition_swap
+// CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) {
+// CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index
+// CHECK-NEXT: scf.yield %[[i1]] : index
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[i2:.+]] = "test.origTrue"() : () -> index
+// CHECK-NEXT: scf.yield %[[i2]] : index
+// CHECK-NEXT: }
+// Negated condition: %not computes `%cond xor true`, so the scf.if below
+// branches on `!%cond`. The FileCheck directives above expect canonicalization
+// to drop the negation, branch on the function argument (printed as %arg0)
+// directly, and swap the regions so "test.origFalse" ends up in the then
+// region and "test.origTrue" in the else region. Because the scf.if must be
+// the line immediately after the function label, the now-dead constant and
+// xori are also expected to be gone.
+func @if_condition_swap(%cond: i1) -> index {
+ %true = arith.constant true
+ %not = arith.xori %cond, %true : i1
+ %0 = scf.if %not -> (index) {
+ %1 = "test.origTrue"() : () -> index
+ scf.yield %1 : index
+ } else {
+ %1 = "test.origFalse"() : () -> index
+ scf.yield %1 : index
+ }
+ return %0 : index
+}
+
+// -----
+
// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
%c42 = arith.constant 42 : index
More information about the Mlir-commits mailing list