[Mlir-commits] [mlir] 5a31403 - [MLIR][SCF] Create selects from if yield results which are not defined in the body
William S. Moses
llvmlistbot at llvm.org
Fri Mar 18 08:38:34 PDT 2022
Author: William S. Moses
Date: 2022-03-18T11:38:29-04:00
New Revision: 5a31403b513d7449c4f045347740056800996357
URL: https://github.com/llvm/llvm-project/commit/5a31403b513d7449c4f045347740056800996357
DIFF: https://github.com/llvm/llvm-project/commit/5a31403b513d7449c4f045347740056800996357.diff
LOG: [MLIR][SCF] Create selects from if yield results which are not defined in the body
Previously, the canonicalization pattern that creates selects from ifs would only fire
if neither branch of the if contained anything other than its yield. This patch upgrades the pattern
so that it creates a select for any if result whose yielded values (in both the then and else branches)
are not defined within the if's body.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D121943
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index cd2b8a0b45055..d888a341f0781 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1326,6 +1326,8 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
}
};
+/// Hoist any yielded results whose operands are defined outside
+/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -1334,31 +1336,58 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
if (op->getNumResults() == 0)
return failure();
- if (!llvm::hasSingleElement(op.getThenRegion().front()) ||
- !llvm::hasSingleElement(op.getElseRegion().front()))
+ auto cond = op.getCondition();
+ auto thenYieldArgs = op.thenYield().getOperands();
+ auto elseYieldArgs = op.elseYield().getOperands();
+
+ SmallVector<Type> nonHoistable;
+ for (const auto &it :
+ llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
+ Value trueVal = std::get<0>(it.value());
+ Value falseVal = std::get<1>(it.value());
+ if (&op.getThenRegion() == trueVal.getParentRegion() ||
+ &op.getElseRegion() == falseVal.getParentRegion())
+ nonHoistable.push_back(trueVal.getType());
+ }
+ // Early exit if there aren't any yielded values we can
+ // hoist outside the if.
+ if (nonHoistable.size() == op->getNumResults())
return failure();
- auto cond = op.getCondition();
- auto thenYieldArgs =
- cast<scf::YieldOp>(op.getThenRegion().front().getTerminator())
- .getOperands();
- auto elseYieldArgs =
- cast<scf::YieldOp>(op.getElseRegion().front().getTerminator())
- .getOperands();
+ IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
+ if (replacement.thenBlock())
+ rewriter.eraseBlock(replacement.thenBlock());
+ replacement.getThenRegion().takeBody(op.getThenRegion());
+ replacement.getElseRegion().takeBody(op.getElseRegion());
+
SmallVector<Value> results(op->getNumResults());
assert(thenYieldArgs.size() == results.size());
assert(elseYieldArgs.size() == results.size());
+
+ SmallVector<Value> trueYields;
+ SmallVector<Value> falseYields;
for (const auto &it :
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
Value trueVal = std::get<0>(it.value());
Value falseVal = std::get<1>(it.value());
- if (trueVal == falseVal)
+ if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
+ &replacement.getElseRegion() == falseVal.getParentRegion()) {
+ results[it.index()] = replacement.getResult(trueYields.size());
+ trueYields.push_back(trueVal);
+ falseYields.push_back(falseVal);
+ } else if (trueVal == falseVal)
results[it.index()] = trueVal;
else
results[it.index()] = rewriter.create<arith::SelectOp>(
op.getLoc(), cond, trueVal, falseVal);
}
+ rewriter.setInsertionPointToEnd(replacement.thenBlock());
+ rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
+
+ rewriter.setInsertionPointToEnd(replacement.elseBlock());
+ rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
+
rewriter.replaceOp(op, results);
return success();
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 6edf9ff9ddff7..f06e9833d1211 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -136,26 +136,26 @@ func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
%0, %1 = scf.if %cond -> (index, index) {
call @side_effect() : () -> ()
+ %c0 = "test.value0"() : () -> (index)
+ %c1 = "test.value1"() : () -> (index)
scf.yield %c0, %c1 : index, index
} else {
+ %c2 = "test.value2"() : () -> (index)
+ %c3 = "test.value3"() : () -> (index)
scf.yield %c2, %c3 : index, index
}
return %1 : index
}
// CHECK-LABEL: func @one_unused
-// CHECK-DAG: [[C0:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
// CHECK: call @side_effect() : () -> ()
-// CHECK: scf.yield [[C0]] : index
+// CHECK: [[C1:%.*]] = "test.value1"
+// CHECK: scf.yield [[C1]] : index
// CHECK: } else
+// CHECK: [[C3:%.*]] = "test.value3"
// CHECK: scf.yield [[C3]] : index
// CHECK: }
// CHECK: return [[V0]] : index
@@ -164,37 +164,40 @@ func @one_unused(%cond: i1) -> (index) {
func private @side_effect()
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
%0, %1 = scf.if %cond1 -> (index, index) {
%2, %3 = scf.if %cond2 -> (index, index) {
call @side_effect() : () -> ()
+ %c0 = "test.value0"() : () -> (index)
+ %c1 = "test.value1"() : () -> (index)
scf.yield %c0, %c1 : index, index
} else {
+ %c2 = "test.value2"() : () -> (index)
+ %c3 = "test.value3"() : () -> (index)
scf.yield %c2, %c3 : index, index
}
scf.yield %2, %3 : index, index
} else {
+ %c0 = "test.value0_2"() : () -> (index)
+ %c1 = "test.value1_2"() : () -> (index)
scf.yield %c0, %c1 : index, index
}
return %1 : index
}
// CHECK-LABEL: func @nested_unused
-// CHECK-DAG: [[C0:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: call @side_effect() : () -> ()
-// CHECK: scf.yield [[C0]] : index
+// CHECK: [[C1:%.*]] = "test.value1"
+// CHECK: scf.yield [[C1]] : index
// CHECK: } else
+// CHECK: [[C3:%.*]] = "test.value3"
// CHECK: scf.yield [[C3]] : index
// CHECK: }
// CHECK: scf.yield [[V1]] : index
// CHECK: } else
-// CHECK: scf.yield [[C0]] : index
+// CHECK: [[C1_2:%.*]] = "test.value1_2"
+// CHECK: scf.yield [[C1_2]] : index
// CHECK: }
// CHECK: return [[V0]] : index
@@ -302,6 +305,27 @@ func @to_select_same_val(%cond: i1) -> (index, index) {
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
// CHECK: return [[V0]], [[C1]] : index, index
+
+func @to_select_with_body(%cond: i1) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = scf.if %cond -> index {
+ "test.op"() : () -> ()
+ scf.yield %c0 : index
+ } else {
+ scf.yield %c1 : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @to_select_with_body
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: scf.if {{.*}} {
+// CHECK: "test.op"() : () -> ()
+// CHECK: }
+// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
+// CHECK: return [[V0]] : index
// -----
func @to_select2(%cond: i1) -> (index, index) {
@@ -731,38 +755,32 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
// CHECK-LABEL: @cond_prop
func @cond_prop(%arg0 : i1) -> index {
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
%res = scf.if %arg0 -> index {
%res1 = scf.if %arg0 -> index {
- %v1 = "test.get_some_value"() : () -> i32
- scf.yield %c1 : index
+ %v1 = "test.get_some_value1"() : () -> index
+ scf.yield %v1 : index
} else {
- %v2 = "test.get_some_value"() : () -> i32
- scf.yield %c2 : index
+ %v2 = "test.get_some_value2"() : () -> index
+ scf.yield %v2 : index
}
scf.yield %res1 : index
} else {
%res2 = scf.if %arg0 -> index {
- %v3 = "test.get_some_value"() : () -> i32
- scf.yield %c3 : index
+ %v3 = "test.get_some_value3"() : () -> index
+ scf.yield %v3 : index
} else {
- %v4 = "test.get_some_value"() : () -> i32
- scf.yield %c4 : index
+ %v4 = "test.get_some_value4"() : () -> index
+ scf.yield %v4 : index
}
scf.yield %res2 : index
}
return %res : index
}
-// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
-// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
// CHECK-NEXT: scf.yield %[[c1]] : index
// CHECK-NEXT: } else {
-// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
// CHECK-NEXT: scf.yield %[[c4]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[if]] : index
@@ -808,7 +826,6 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
return %res#0, %res#1 : i32, i1
}
// CHECK-NEXT: %true = arith.constant true
-// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT: scf.yield %[[sv1]] : i32
@@ -816,6 +833,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
// CHECK-NEXT: scf.yield %[[sv2]] : i32
// CHECK-NEXT: }
+// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
// -----
More information about the Mlir-commits
mailing list