[Mlir-commits] [mlir] 1e3a021 - [mlir][scf] Update IfOp to have getInvocationBounds
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 27 15:15:56 PST 2022
Author: Mogball
Date: 2022-01-27T23:15:53Z
New Revision: 1e3a02162db20264e9615b1346420c8d199cb347
URL: https://github.com/llvm/llvm-project/commit/1e3a02162db20264e9615b1346420c8d199cb347
DIFF: https://github.com/llvm/llvm-project/commit/1e3a02162db20264e9615b1346420c8d199cb347.diff
LOG: [mlir][scf] Update IfOp to have getInvocationBounds
This allows `scf.if` to be used by Control-Flow sink.
Depends on D115088
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D115089
Added:
mlir/test/Dialect/SCF/control-flow-sink.mlir
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/lib/Dialect/SCF/SCF.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 74834c03cae13..423fcbd19e0b4 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -315,7 +315,9 @@ def ForOp : SCF_Op<"for",
}
def IfOp : SCF_Op<"if",
- [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getNumRegionInvocations",
+ "getRegionInvocationBounds"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects,
NoRegionArguments]> {
let summary = "if-then-else operation";
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 896eb501d3799..1fadf5bf57edc 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -140,9 +140,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
of invocations cannot be statically determined, then it will not have a
value (i.e., it is set to `llvm::None`).
- `operands` is a set of optional attributes that either correspond to a
- constant values for each operand of this operation, or null if that
+ `operands` is a set of optional attributes that either correspond to
+ constant values for each operand of this operation or null if that
operand is not a constant.
+
+ This method may be called speculatively on operations where the provided
+ operands are not necessarily the same as the operation's current
+ operands. This may occur in analyses that wish to determine "what would
+ be the region invocations if these were the operands?"
}],
"void", "getRegionInvocationBounds",
(ins "::mlir::ArrayRef<::mlir::Attribute>":$operands,
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 382e3c662009c..b05b10584b334 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1188,6 +1188,20 @@ LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
return success();
}
+void IfOp::getRegionInvocationBounds(
+ ArrayRef<Attribute> operands,
+ SmallVectorImpl<InvocationBounds> &invocationBounds) {
+ if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
+ // If the condition is known, then one region is known to be executed once
+ // and the other zero times.
+ invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
+ invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
+ } else {
+ // Non-constant condition. Each region may be executed 0 or 1 times.
+ invocationBounds.assign(2, {0, 1});
+ }
+}
+
namespace {
// Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
diff --git a/mlir/test/Dialect/SCF/control-flow-sink.mlir b/mlir/test/Dialect/SCF/control-flow-sink.mlir
new file mode 100644
index 0000000000000..787c8d0c2914a
--- /dev/null
+++ b/mlir/test/Dialect/SCF/control-flow-sink.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s
+
+// CHECK-LABEL: @test_scf_if_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: %[[V0:.*]] = scf.if %[[ARG0]]
+// CHECK: %[[V1:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK: scf.yield %[[V1]]
+// CHECK: else
+// CHECK: %[[V1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+// CHECK: scf.yield %[[V1]]
+// CHECK: return %[[V0]]
+func @test_scf_if_sink(%arg0: i1, %arg1: i32) -> i32 {
+ %0 = arith.addi %arg1, %arg1 : i32
+ %1 = arith.muli %arg1, %arg1 : i32
+ %result = scf.if %arg0 -> i32 {
+ scf.yield %0 : i32
+ } else {
+ scf.yield %1 : i32
+ }
+ return %result : i32
+}
+
+// -----
+
+func private @consume(i32) -> ()
+
+// CHECK-LABEL: @test_scf_if_then_only_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: scf.if %[[ARG0]]
+// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK: call @consume(%[[V0]])
+func @test_scf_if_then_only_sink(%arg0: i1, %arg1: i32) {
+ %0 = arith.addi %arg1, %arg1 : i32
+ scf.if %arg0 {
+ call @consume(%0) : (i32) -> ()
+ scf.yield
+ }
+ return
+}
+
+// -----
+
+func private @consume(i32) -> ()
+
+// CHECK-LABEL: @test_scf_if_double_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32)
+// CHECK: scf.if %[[ARG0]]
+// CHECK: scf.if %[[ARG0]]
+// CHECK: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK: call @consume(%[[V0]])
+func @test_scf_if_double_sink(%arg0: i1, %arg1: i32) {
+ %0 = arith.addi %arg1, %arg1 : i32
+ scf.if %arg0 {
+ scf.if %arg0 {
+ call @consume(%0) : (i32) -> ()
+ scf.yield
+ }
+ }
+ return
+}
More information about the Mlir-commits
mailing list