[Mlir-commits] [mlir] [mlir][RegionBranchOpInterface] explicitly check for existence of block terminator (PR #76831)

Maksim Levental llvmlistbot at llvm.org
Thu Jan 4 09:43:23 PST 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/76831

>From d7d455702af64077bf0f9caff034915af50f5c37 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 3 Jan 2024 10:26:48 -0600
Subject: [PATCH 1/4] [mlir][RegionBranchOpInterface] explicitly check for
 existence of block terminator

---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 4ed024ddae247b..8768ef9060c6ba 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -177,9 +177,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
 
     SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
     for (Block &block : region)
-      if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
-              block.getTerminator()))
-        regionReturnOps.push_back(terminator);
+      if (block.mightHaveTerminator())
+        if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
+                block.getTerminator()))
+          regionReturnOps.push_back(terminator);
 
     // If there is no return-like terminator, the op itself should verify
     // type consistency.

>From 0df86060ea51a60efa92121663f7e66be1f7818a Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 3 Jan 2024 17:21:31 -0600
Subject: [PATCH 2/4] add test

---
 .../IR/test-region-branch-op-verifier.mlir    | 15 +++++-
 mlir/test/lib/Dialect/Test/TestDialect.cpp    | 50 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 12 +++++
 3 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/mlir/test/IR/test-region-branch-op-verifier.mlir b/mlir/test/IR/test-region-branch-op-verifier.mlir
index f5fb7fc2b25cb9..b94f6beb9796f2 100644
--- a/mlir/test/IR/test-region-branch-op-verifier.mlir
+++ b/mlir/test/IR/test-region-branch-op-verifier.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s
+// RUN: mlir-opt %s -split-input-file
 
 func.func @test_ops_verify(%arg: i32) -> f32 {
   %0 = "test.constant"() { value = 5.3 : f32 } : () -> f32
@@ -8,3 +8,16 @@ func.func @test_ops_verify(%arg: i32) -> f32 {
   }
   return %1 : f32
 }
+
+// -----
+
+func.func @test_no_terminator(%arg: index) {
+  test.switch_with_no_break %arg // Every case block below is empty.
+  case 0 {
+  ^bb:
+  }
+  case 1 {
+  ^bb:
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a1b30705f16a98..cb5ee6014b6113 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -53,6 +53,7 @@ using namespace test;
 Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
   return StringAttr::get(ctx, content);
 }
+
 LogicalResult
 MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
                           function_ref<InFlightDiagnostic()> emitError) {
@@ -64,6 +65,7 @@ MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
   prop.content = strAttr.getValue();
   return success();
 }
+
 llvm::hash_code MyPropStruct::hash() const {
   return hash_value(StringRef(content));
 }
@@ -127,6 +129,12 @@ static void customPrintProperties(OpAsmPrinter &p,
                                   const VersionedProperties &prop);
 static ParseResult customParseProperties(OpAsmParser &parser,
                                          VersionedProperties &prop);
+static ParseResult
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions);
+
+static void printSwitchCases(OpAsmPrinter &p, Operation *op,
+                             DenseI64ArrayAttr cases, RegionRange caseRegions);
 
 void test::registerTestDialect(DialectRegistry &registry) {
   registry.insert<TestDialect>();
@@ -230,6 +238,7 @@ void TestDialect::initialize() {
   // unregistered op.
   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
 }
+
 TestDialect::~TestDialect() {
   delete static_cast<TestOpEffectInterfaceFallback *>(
       fallbackEffectOpInterfaces);
@@ -1013,6 +1022,13 @@ LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
   return getNextIterArgMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// TestNoTerminatorOp (test.switch_with_no_break)
+//===----------------------------------------------------------------------===//
+
+void TestNoTerminatorOp::getSuccessorRegions( // Leaves `regions` empty.
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
+
 //===----------------------------------------------------------------------===//
 // SingleNoTerminatorCustomAsmOp
 //===----------------------------------------------------------------------===//
@@ -1160,6 +1176,7 @@ setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
   prop.value = valueAttr.getValue().getSExtValue();
   return success();
 }
+
 static DictionaryAttr
 getPropertiesAsAttribute(MLIRContext *ctx,
                          const PropertiesWithCustomPrint &prop) {
@@ -1169,14 +1186,17 @@ getPropertiesAsAttribute(MLIRContext *ctx,
   attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
   return b.getDictionaryAttr(attrs);
 }
+
 static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
   return llvm::hash_combine(prop.value, StringRef(*prop.label));
 }
+
 static void customPrintProperties(OpAsmPrinter &p,
                                   const PropertiesWithCustomPrint &prop) {
   p.printKeywordOrString(*prop.label);
   p << " is " << prop.value;
 }
+
 static ParseResult customParseProperties(OpAsmParser &parser,
                                          PropertiesWithCustomPrint &prop) {
   std::string label;
@@ -1186,6 +1206,31 @@ static ParseResult customParseProperties(OpAsmParser &parser,
   prop.label = std::make_shared<std::string>(std::move(label));
   return success();
 }
+
+static ParseResult // Parses zero or more `case <int> <region>` entries.
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues; // Integer label of each parsed case.
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value; // The integer that follows the `case` keyword.
+    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure(); // Malformed case: propagate the parse failure.
+    caseValues.push_back(value);
+  } // Loop ends at the first token that is not the `case` keyword.
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); // May be empty.
+  return success();
+}
+
+static void printSwitchCases(OpAsmPrinter &p, Operation *op, // op is unused.
+                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    p.printNewline(); // Each case starts on its own line.
+    p << "case " << value << ' '; // Matches parseSwitchCases.
+    p.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+}
+
 static LogicalResult
 setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
                            function_ref<InFlightDiagnostic()> emitError) {
@@ -1209,6 +1254,7 @@ setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
   prop.value2 = value2Attr.getValue().getSExtValue();
   return success();
 }
+
 static DictionaryAttr
 getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
   SmallVector<NamedAttribute> attrs;
@@ -1217,13 +1263,16 @@ getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
   attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
   return b.getDictionaryAttr(attrs);
 }
+
 static llvm::hash_code computeHash(const VersionedProperties &prop) {
   return llvm::hash_combine(prop.value1, prop.value2);
 }
+
 static void customPrintProperties(OpAsmPrinter &p,
                                   const VersionedProperties &prop) {
   p << prop.value1 << " | " << prop.value2;
 }
+
 static ParseResult customParseProperties(OpAsmParser &parser,
                                          VersionedProperties &prop) {
   if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
@@ -1393,6 +1442,7 @@ ::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
   prop.value2 = value2;
   return success();
 }
+
 void TestOpWithVersionedProperties::writeToMlirBytecode(
     ::mlir::DialectBytecodeWriter &writer,
     const test::VersionedProperties &prop) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 48b41d8698762d..61e181999ff2fb 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2213,6 +2213,18 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
   }];
 }
 
+def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
+    NoTerminator, // Case blocks need not end in a terminator.
+    DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
+  ]> { // Switch-like test op whose case blocks may be empty.
+  let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let assemblyFormat = [{
+    $arg attr-dict custom<SwitchCases>($cases, $caseRegions)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test TableGen generated build() methods
 //===----------------------------------------------------------------------===//

>From 29a55509f11aa242a7eb14c3196289e3eb1d392c Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 4 Jan 2024 11:18:55 -0600
Subject: [PATCH 3/4] comments

---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 8768ef9060c6ba..a563ec5cb8db58 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -177,9 +177,9 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
 
     SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
     for (Block &block : region)
-      if (block.mightHaveTerminator())
-        if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
-                block.getTerminator()))
+      if (!block.empty())
+        if (auto terminator =
+                dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
           regionReturnOps.push_back(terminator);
 
     // If there is no return-like terminator, the op itself should verify



More information about the Mlir-commits mailing list