[Mlir-commits] [mlir] c710381 - [mlir][scf] Add getNumRegionInvocations to IfOp
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 15 06:56:24 PST 2021
Author: Mogball
Date: 2021-12-15T14:56:20Z
New Revision: c7103810bde9e300f0a272f0dc55eb324f5415f2
URL: https://github.com/llvm/llvm-project/commit/c7103810bde9e300f0a272f0dc55eb324f5415f2
DIFF: https://github.com/llvm/llvm-project/commit/c7103810bde9e300f0a272f0dc55eb324f5415f2.diff
LOG: [mlir][scf] Add getNumRegionInvocations to IfOp
Implements the RegionBranchOpInterface method getNumRegionInvocations for `scf::IfOp` so that, when the condition is constant, the number of region executions can be analyzed by `NumberOfExecutions`.
Reviewed By: jpienaar, ftynse
Differential Revision: https://reviews.llvm.org/D115087
Added:
mlir/unittests/Dialect/SCF/CMakeLists.txt
mlir/unittests/Dialect/SCF/SCFOps.cpp
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/unittests/Dialect/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 89293bd4cd9bf..29e36cb977e8a 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -403,6 +403,12 @@ def IfOp : SCF_Op<"if",
YieldOp thenYield();
Block* elseBlock();
YieldOp elseYield();
+
+ /// If the condition is a constant, returns 1 for the executed region and 0
+ /// for the other. Otherwise, returns `kUnknownNumRegionInvocations` for
+ /// both regions.
+ void getNumRegionInvocations(ArrayRef<Attribute> operands,
+ SmallVectorImpl<int64_t> &countPerRegion);
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index c7f1436f9fbbc..cea88d1e52a66 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -474,7 +474,7 @@ void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
// Loop bounds are not known statically.
if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
- countPerRegion[0] = -1;
+ countPerRegion[0] = kUnknownNumRegionInvocations;
return;
}
@@ -1181,6 +1181,23 @@ void IfOp::getSuccessorRegions(Optional<unsigned> index,
regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
}
+/// Reports how many times each region of this `scf.if` runs. With a constant
+/// condition, the taken region (`then` for true, `else` for false) runs
+/// exactly once and the other never runs; with any other condition both
+/// counts are `kUnknownNumRegionInvocations`.
+void IfOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                   SmallVectorImpl<int64_t> &countPerRegion) {
+  auto constCond = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!constCond) {
+    // The condition is not a known constant, so neither count is known.
+    countPerRegion.assign(2, kUnknownNumRegionInvocations);
+    return;
+  }
+  // Exactly one region is taken: index 0 is `then`, index 1 is `else`.
+  const int64_t thenCount = constCond.getValue().isOneValue() ? 1 : 0;
+  countPerRegion.assign(1, thenCount);
+  countPerRegion.push_back(1 - thenCount);
+}
+
namespace {
// Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index f37f578572b8f..91aec5054af25 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,6 +7,7 @@ target_link_libraries(MLIRDialectTests
MLIRDialect)
add_subdirectory(Quant)
+add_subdirectory(SCF)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
add_subdirectory(Utils)
diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
new file mode 100644
index 0000000000000..81e05d353d49d
--- /dev/null
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -0,0 +1,10 @@
+# Unit-test binary for SCF dialect ops; the tests live in SCFOps.cpp.
+add_mlir_unittest(MLIRSCFTests
+ SCFOps.cpp
+ )
+# The tests parse MLIR source containing `scf` and Standard-dialect ops.
+target_link_libraries(MLIRSCFTests
+ PRIVATE
+ MLIRIR
+ MLIRParser
+ MLIRSCF
+ MLIRStandard
+ )
diff --git a/mlir/unittests/Dialect/SCF/SCFOps.cpp b/mlir/unittests/Dialect/SCF/SCFOps.cpp
new file mode 100644
index 0000000000000..099bd2764e7d7
--- /dev/null
+++ b/mlir/unittests/Dialect/SCF/SCFOps.cpp
@@ -0,0 +1,67 @@
+//===- SCFOps.cpp - SCF Op Unit Tests -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Parser.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+/// Fixture that owns an `MLIRContext` with the SCF and Standard dialects
+/// loaded, so the test snippets (which use `scf.if` and `return`) parse.
+class SCFOpsTest : public testing::Test {
+public:
+  SCFOpsTest() {
+    context.getOrLoadDialect<scf::SCFDialect>();
+    context.getOrLoadDialect<StandardOpsDialect>();
+  }
+
+protected:
+  MLIRContext context;
+};
+
+/// `scf.if` must report `kUnknownNumRegionInvocations` for both regions when
+/// the condition is unknown, {1, 0} for a constant `true` and {0, 1} for a
+/// constant `false` (index 0 is the `then` region, index 1 the `else`).
+TEST_F(SCFOpsTest, IfOpNumRegionInvocations) {
+  const char *const code = R"mlir(
+func @test(%cond : i1) -> () {
+ scf.if %cond {
+ scf.yield
+ } else {
+ scf.yield
+ }
+ return
+}
+)mlir";
+  Builder builder(&context);
+
+  auto module = parseSourceString(code, &context);
+  ASSERT_TRUE(module);
+  scf::IfOp op;
+  module->walk([&](scf::IfOp ifOp) { op = ifOp; });
+  ASSERT_TRUE(op);
+
+  // The size checks are fatal (ASSERT_EQ, not EXPECT_EQ): the element
+  // accesses that follow would index out of bounds if fewer than two counts
+  // were produced.
+  SmallVector<int64_t> countPerRegion;
+  op.getNumRegionInvocations({Attribute()}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  EXPECT_EQ(countPerRegion[0], kUnknownNumRegionInvocations);
+  EXPECT_EQ(countPerRegion[1], kUnknownNumRegionInvocations);
+
+  countPerRegion.clear();
+  op.getNumRegionInvocations(
+      {builder.getIntegerAttr(builder.getI1Type(), true)}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  EXPECT_EQ(countPerRegion[0], 1);
+  EXPECT_EQ(countPerRegion[1], 0);
+
+  countPerRegion.clear();
+  op.getNumRegionInvocations(
+      {builder.getIntegerAttr(builder.getI1Type(), false)}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  EXPECT_EQ(countPerRegion[0], 0);
+  EXPECT_EQ(countPerRegion[1], 1);
+}
+} // end anonymous namespace
More information about the Mlir-commits
mailing list