[Mlir-commits] [mlir] [mlir][RegionBranchOpInterface] explicitly check for existance of block terminator (PR #76831)
Maksim Levental
llvmlistbot at llvm.org
Wed Jan 3 15:21:44 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/76831
>From f7e5abb37a4e8eaf5ca5e7880a84584c419ec696 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 3 Jan 2024 10:26:48 -0600
Subject: [PATCH 1/2] [mlir][RegionBranchOpInterface] explicitly check for
existance of block terminator
---
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 4ed024ddae247b..8768ef9060c6ba 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -177,9 +177,10 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
- if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
- block.getTerminator()))
- regionReturnOps.push_back(terminator);
+ if (block.mightHaveTerminator())
+ if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
+ block.getTerminator()))
+ regionReturnOps.push_back(terminator);
// If there is no return-like terminator, the op itself should verify
// type consistency.
>From 22d286203e46455570e29490b9a02f44623f7b4c Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 3 Jan 2024 17:21:31 -0600
Subject: [PATCH 2/2] add test
---
.../IR/test-region-branch-op-verifier.mlir | 15 +++++-
mlir/test/lib/Dialect/Test/TestDialect.cpp | 50 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 12 +++++
3 files changed, 76 insertions(+), 1 deletion(-)
diff --git a/mlir/test/IR/test-region-branch-op-verifier.mlir b/mlir/test/IR/test-region-branch-op-verifier.mlir
index f5fb7fc2b25cb9..b94f6beb9796f2 100644
--- a/mlir/test/IR/test-region-branch-op-verifier.mlir
+++ b/mlir/test/IR/test-region-branch-op-verifier.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s
+// RUN: mlir-opt %s -split-input-file
func.func @test_ops_verify(%arg: i32) -> f32 {
%0 = "test.constant"() { value = 5.3 : f32 } : () -> f32
@@ -8,3 +8,16 @@ func.func @test_ops_verify(%arg: i32) -> f32 {
}
return %1 : f32
}
+
+// -----
+
+func.func @test_no_terminator(%arg: index) {
+ test.switch_with_no_break %arg
+ case 0 {
+ ^bb:
+ }
+ case 1 {
+ ^bb:
+ }
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a1b30705f16a98..cb5ee6014b6113 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -53,6 +53,7 @@ using namespace test;
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
return StringAttr::get(ctx, content);
}
+
LogicalResult
MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
@@ -64,6 +65,7 @@ MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
prop.content = strAttr.getValue();
return success();
}
+
llvm::hash_code MyPropStruct::hash() const {
return hash_value(StringRef(content));
}
@@ -127,6 +129,12 @@ static void customPrintProperties(OpAsmPrinter &p,
const VersionedProperties &prop);
static ParseResult customParseProperties(OpAsmParser &parser,
VersionedProperties &prop);
+static ParseResult
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+ SmallVectorImpl<std::unique_ptr<Region>> &caseRegions);
+
+static void printSwitchCases(OpAsmPrinter &p, Operation *op,
+ DenseI64ArrayAttr cases, RegionRange caseRegions);
void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
@@ -230,6 +238,7 @@ void TestDialect::initialize() {
// unregistered op.
fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
+
TestDialect::~TestDialect() {
delete static_cast<TestOpEffectInterfaceFallback *>(
fallbackEffectOpInterfaces);
@@ -1013,6 +1022,13 @@ LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
return getNextIterArgMutable();
}
+//===----------------------------------------------------------------------===//
+// SwitchWithNoBreakOp
+//===----------------------------------------------------------------------===//
+
+void TestNoTerminatorOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
+
//===----------------------------------------------------------------------===//
// SingleNoTerminatorCustomAsmOp
//===----------------------------------------------------------------------===//
@@ -1160,6 +1176,7 @@ setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
prop.value = valueAttr.getValue().getSExtValue();
return success();
}
+
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx,
const PropertiesWithCustomPrint &prop) {
@@ -1169,14 +1186,17 @@ getPropertiesAsAttribute(MLIRContext *ctx,
attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
return b.getDictionaryAttr(attrs);
}
+
static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
return llvm::hash_combine(prop.value, StringRef(*prop.label));
}
+
static void customPrintProperties(OpAsmPrinter &p,
const PropertiesWithCustomPrint &prop) {
p.printKeywordOrString(*prop.label);
p << " is " << prop.value;
}
+
static ParseResult customParseProperties(OpAsmParser &parser,
PropertiesWithCustomPrint &prop) {
std::string label;
@@ -1186,6 +1206,31 @@ static ParseResult customParseProperties(OpAsmParser &parser,
prop.label = std::make_shared<std::string>(std::move(label));
return success();
}
+
+static ParseResult
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+ SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+ SmallVector<int64_t> caseValues;
+ while (succeeded(p.parseOptionalKeyword("case"))) {
+ int64_t value;
+ Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
+ if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+ return failure();
+ caseValues.push_back(value);
+ }
+ cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+ return success();
+}
+
+static void printSwitchCases(OpAsmPrinter &p, Operation *op,
+ DenseI64ArrayAttr cases, RegionRange caseRegions) {
+ for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+ p.printNewline();
+ p << "case " << value << ' ';
+ p.printRegion(*region, /*printEntryBlockArgs=*/false);
+ }
+}
+
static LogicalResult
setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
@@ -1209,6 +1254,7 @@ setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
prop.value2 = value2Attr.getValue().getSExtValue();
return success();
}
+
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
SmallVector<NamedAttribute> attrs;
@@ -1217,13 +1263,16 @@ getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
return b.getDictionaryAttr(attrs);
}
+
static llvm::hash_code computeHash(const VersionedProperties &prop) {
return llvm::hash_combine(prop.value1, prop.value2);
}
+
static void customPrintProperties(OpAsmPrinter &p,
const VersionedProperties &prop) {
p << prop.value1 << " | " << prop.value2;
}
+
static ParseResult customParseProperties(OpAsmParser &parser,
VersionedProperties &prop) {
if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
@@ -1393,6 +1442,7 @@ ::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
prop.value2 = value2;
return success();
}
+
void TestOpWithVersionedProperties::writeToMlirBytecode(
::mlir::DialectBytecodeWriter &writer,
const test::VersionedProperties &prop) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 48b41d8698762d..61e181999ff2fb 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2213,6 +2213,18 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
}];
}
+def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
+ NoTerminator,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
+ ]> {
+ let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
+ let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+ let assemblyFormat = [{
+ $arg attr-dict custom<SwitchCases>($cases, $caseRegions)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Test TableGen generated build() methods
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list