[Mlir-commits] [mlir] 9d7da41 - Fix crash in scf.parallel verifier
Mehdi Amini
llvmlistbot at llvm.org
Tue Jan 17 08:22:15 PST 2023
Author: Mehdi Amini
Date: 2023-01-17T16:21:28Z
New Revision: 9d7da415d244124af93e42a7a378eb79c2fb391f
URL: https://github.com/llvm/llvm-project/commit/9d7da415d244124af93e42a7a378eb79c2fb391f
DIFF: https://github.com/llvm/llvm-project/commit/9d7da415d244124af93e42a7a378eb79c2fb391f.diff
LOG: Fix crash in scf.parallel verifier
Fixes #59989
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D141911
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4e6cd2e5e64bd..fc7ce764ea8f0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -78,6 +78,23 @@ void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc);
}
+/// Verifies that the first block of the given `region` is terminated by a
+/// TerminatorTy. Reports errors on the given operation if it is not the case.
+template <typename TerminatorTy>
+static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
+ StringRef errorMessage) {
+ Operation *terminatorOperation = nullptr;
+ if (!region.empty() && !region.front().empty()) {
+ terminatorOperation = &region.front().back();
+ if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
+ return yield;
+ }
+ auto diag = op->emitOpError(errorMessage);
+ if (terminatorOperation)
+ diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@@ -2323,10 +2340,13 @@ LogicalResult ParallelOp::verify() {
"expects arguments for the induction variable to be of index type");
// Check that the yield has no results
- Operation *yield = body->getTerminator();
+ auto yield = verifyAndGetTerminator<scf::YieldOp>(
+ *this, getRegion(), "expects body to terminate with 'scf.yield'");
+ if (!yield)
+ return failure();
if (yield->getNumOperands() != 0)
- return yield->emitOpError() << "not allowed to have operands inside '"
- << ParallelOp::getOperationName() << "'";
+ return yield.emitOpError() << "not allowed to have operands inside '"
+ << ParallelOp::getOperationName() << "'";
// Check that the number of results is the same as the number of ReduceOps.
SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
@@ -2854,23 +2874,6 @@ static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
return success();
}
-/// Verifies that the first block of the given `region` is terminated by a
-/// YieldOp. Reports errors on the given operation if it is not the case.
-template <typename TerminatorTy>
-static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
- StringRef errorMessage) {
- Operation *terminatorOperation = nullptr;
- if (!region.empty() && !region.front().empty()) {
- terminatorOperation = &region.front().back();
- if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
- return yield;
- }
- auto diag = op.emitOpError(errorMessage);
- if (terminatorOperation)
- diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
- return nullptr;
-}
-
LogicalResult scf::WhileOp::verify() {
auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
*this, getBefore(),
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 498a3bc28041e..c1c66393a5ac8 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -672,3 +672,16 @@ func.func @switch_missing_terminator(%arg0: index, %arg1: i32) {
return
}) {cases = array<i64: 1>} : (index) -> ()
}
+
+// -----
+
+func.func @parallel_missing_terminator(%0 : index) {
+ // expected-error @below {{'scf.parallel' op expects body to terminate with 'scf.yield'}}
+ "scf.parallel"(%0, %0, %0) ({
+ ^bb0(%arg1: index):
+ // expected-note @below {{terminator here}}
+ %2 = "arith.constant"() {value = 1.000000e+00 : f32} : () -> f32
+ }) {operand_segment_sizes = array<i32: 1, 1, 1, 0>} : (index, index, index) -> ()
+ return
+}
+
More information about the Mlir-commits
mailing list