[Mlir-commits] [mlir] c710381 - [mlir][scf] Add getNumRegionInvocations to IfOp

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 15 06:56:24 PST 2021


Author: Mogball
Date: 2021-12-15T14:56:20Z
New Revision: c7103810bde9e300f0a272f0dc55eb324f5415f2

URL: https://github.com/llvm/llvm-project/commit/c7103810bde9e300f0a272f0dc55eb324f5415f2
DIFF: https://github.com/llvm/llvm-project/commit/c7103810bde9e300f0a272f0dc55eb324f5415f2.diff

LOG: [mlir][scf] Add getNumRegionInvocations to IfOp

Implements the RegionBranchOpInterface method `getNumRegionInvocations` for `scf::IfOp` so that, when the condition is constant, the number of region executions can be analyzed by `NumberOfExecutions`.

Reviewed By: jpienaar, ftynse

Differential Revision: https://reviews.llvm.org/D115087

Added: 
    mlir/unittests/Dialect/SCF/CMakeLists.txt
    mlir/unittests/Dialect/SCF/SCFOps.cpp

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/unittests/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 89293bd4cd9bf..29e36cb977e8a 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -403,6 +403,12 @@ def IfOp : SCF_Op<"if",
     YieldOp thenYield();
     Block* elseBlock();
     YieldOp elseYield();
+
+    /// If the condition is a constant, returns 1 for the executed block and 0
+    /// for the other. Otherwise, returns `kUnknownNumRegionInvocations` for
+    /// both successors.
+    void getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<int64_t> &countPerRegion);
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index c7f1436f9fbbc..cea88d1e52a66 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -474,7 +474,7 @@ void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
 
   // Loop bounds are not known statically.
   if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
-    countPerRegion[0] = -1;
+    countPerRegion[0] = kUnknownNumRegionInvocations;
     return;
   }
 
@@ -1181,6 +1181,23 @@ void IfOp::getSuccessorRegions(Optional<unsigned> index,
   regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
 }
 
+/// Fills `countPerRegion` with {then, else} counts. For a constant condition
+/// the taken region gets 1 and the other 0; otherwise both entries are set to
+/// `kUnknownNumRegionInvocations`. Any previous contents are overwritten.
+void IfOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                   SmallVectorImpl<int64_t> &countPerRegion) {
+  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+    // A constant condition is true iff its value is one: index 0 (`then`)
+    // and index 1 (`else`) receive complementary counts of 1 and 0.
+    bool cond = condAttr.getValue().isOneValue();
+    countPerRegion.assign(1, cond ? 1 : 0);
+    countPerRegion.push_back(cond ? 0 : 1);
+  } else {
+    // Null or non-integer attribute (condition not constant): counts unknown.
+    countPerRegion.assign(2, kUnknownNumRegionInvocations);
+  }
+}
+
 namespace {
 // Pattern to remove unused IfOp results.
 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {

diff  --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index f37f578572b8f..91aec5054af25 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,6 +7,7 @@ target_link_libraries(MLIRDialectTests
   MLIRDialect)
 
 add_subdirectory(Quant)
+add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
 add_subdirectory(Utils)

diff  --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
new file mode 100644
index 0000000000000..81e05d353d49d
--- /dev/null
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_unittest(MLIRSCFTests
+  SCFOps.cpp
+  )
+target_link_libraries(MLIRSCFTests
+  PRIVATE
+  MLIRIR
+  MLIRParser
+  MLIRSCF
+  MLIRStandard
+  )

diff  --git a/mlir/unittests/Dialect/SCF/SCFOps.cpp b/mlir/unittests/Dialect/SCF/SCFOps.cpp
new file mode 100644
index 0000000000000..099bd2764e7d7
--- /dev/null
+++ b/mlir/unittests/Dialect/SCF/SCFOps.cpp
@@ -0,0 +1,67 @@
+//===- SCFOps.cpp - SCF Op Unit Tests -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Parser.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+class SCFOpsTest : public testing::Test {
+public:
+  SCFOpsTest() { // Load the SCF and Standard dialects into `context`.
+    context.getOrLoadDialect<scf::SCFDialect>();
+    context.getOrLoadDialect<StandardOpsDialect>();
+  }
+
+protected:
+  MLIRContext context; // Passed to the parser and the builder in the tests.
+};
+
+TEST_F(SCFOpsTest, IfOpNumRegionInvocations) {
+  const char *const code = R"mlir(
+func @test(%cond : i1) -> () {
+  scf.if %cond {
+    scf.yield
+  } else {
+    scf.yield
+  }
+  return
+}
+)mlir";
+  Builder builder(&context);
+
+  auto module = parseSourceString(code, &context);
+  ASSERT_TRUE(module);
+  scf::IfOp op;
+  module->walk([&](scf::IfOp ifOp) { op = ifOp; });
+  ASSERT_TRUE(op);
+  // Null attribute (non-constant condition): both counts are unknown.
+  SmallVector<int64_t> countPerRegion;
+  op.getNumRegionInvocations({Attribute()}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u); // Fatal: entries are indexed below.
+  EXPECT_EQ(countPerRegion[0], kUnknownNumRegionInvocations);
+  EXPECT_EQ(countPerRegion[1], kUnknownNumRegionInvocations);
+  // Constant true: `then` (index 0) runs once, `else` (index 1) never.
+  countPerRegion.clear();
+  op.getNumRegionInvocations(
+      {builder.getIntegerAttr(builder.getI1Type(), true)}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  EXPECT_EQ(countPerRegion[0], 1);
+  EXPECT_EQ(countPerRegion[1], 0);
+  // Constant false: `then` (index 0) never runs, `else` (index 1) runs once.
+  countPerRegion.clear();
+  op.getNumRegionInvocations(
+      {builder.getIntegerAttr(builder.getI1Type(), false)}, countPerRegion);
+  ASSERT_EQ(countPerRegion.size(), 2u);
+  EXPECT_EQ(countPerRegion[0], 0);
+  EXPECT_EQ(countPerRegion[1], 1);
+}
+} // end anonymous namespace


        


More information about the Mlir-commits mailing list