[Mlir-commits] [mlir] [mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition (PR #127778)
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 19 23:33:01 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/127778
>From 5b9a4e1ab4b08b354d9581cae4f5e2ad608d819e Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 19 Feb 2025 11:15:44 +0100
Subject: [PATCH 1/2] fix scf.for parser
---
mlir/lib/AsmParser/Parser.cpp | 14 ++++++++++----
mlir/lib/Dialect/SCF/IR/SCF.cpp | 21 +++++++++++++--------
mlir/test/Dialect/SCF/invalid.mlir | 10 ++++++++++
3 files changed, 33 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index b5f1d2e27c9ba..2982757a6c5ce 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
/// their first reference, to allow checking for use of undefined values.
DenseMap<Value, SMLoc> forwardRefPlaceholders;
+ /// Operations that define the placeholders. These are kept until the end
+ /// of the lifetime of the parser because some custom parsers may store
+ /// references to them in local state and use them after forward references
+ /// have been resolved.
+ DenseSet<Operation *> forwardRefOps;
+
/// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
/// map. After parsing the definition `#loc42 = ...` we'll patch back users
/// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
}
OperationParser::~OperationParser() {
- for (auto &fwd : forwardRefPlaceholders) {
+ for (Operation *op : forwardRefOps) {
// Drop all uses of undefined forward declared reference and destroy
// defining operation.
- fwd.first.dropAllUses();
- fwd.first.getDefiningOp()->destroy();
+ op->dropAllUses();
+ op->destroy();
}
for (const auto &scope : forwardRef) {
for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
// the actual definition instead, delete the forward ref, and remove it
// from our set of forward references we track.
existing.replaceAllUsesWith(value);
- existing.getDefiningOp()->destroy();
forwardRefPlaceholders.erase(existing);
// If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
/*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
/*numRegions=*/0);
forwardRefPlaceholders[op->getResult(0)] = loc;
+ forwardRefOps.insert(op);
return op->getResult(0);
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 448141735ba7f..1f70ad57d986b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
else if (parser.parseType(type))
return failure();
- // Resolve input operands.
+ // Set block argument types, so that they are known when parsing the region.
regionArgs.front().type = type;
+ for (auto [iterArg, type] :
+ llvm::zip(llvm::drop_begin(regionArgs), result.types))
+ iterArg.type = type;
+
+ // Parse the body region.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+ ForOp::ensureTerminator(*body, builder, result.location);
+
+ // Resolve input operands. This should be done after parsing the region to
+ // catch invalid IR where operands were defined inside of the region.
if (parser.resolveOperand(lb, type, result.operands) ||
parser.resolveOperand(ub, type, result.operands) ||
parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
}
}
- // Parse the body region.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
-
- ForOp::ensureTerminator(*body, builder, result.location);
-
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 80576be880127..76c785f3e6166 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
return
}
+// -----
+
+func.func @invalid_reference(%a: index) {
+ // expected-error @below{{use of undeclared SSA value name}}
+ scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
+ %foo = "test.inner"() : () -> (tensor<?xf32>)
+ scf.yield %foo : tensor<?xf32>
+ }
+ return
+}
>From d6177255b8046d88035bccbbd9e67adcb96653ac Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Feb 2025 08:32:39 +0100
Subject: [PATCH 2/2] address comments
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1f70ad57d986b..1cfb866db0b51 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -502,7 +502,7 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
// Set block argument types, so that they are known when parsing the region.
regionArgs.front().type = type;
for (auto [iterArg, type] :
- llvm::zip(llvm::drop_begin(regionArgs), result.types))
+ llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
iterArg.type = type;
// Parse the body region.
@@ -518,8 +518,8 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperand(step, type, result.operands))
return failure();
if (hasIterArgs) {
- for (auto argOperandType :
- llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+ for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
+ operands, result.types)) {
Type type = std::get<2>(argOperandType);
std::get<0>(argOperandType).type = type;
if (parser.resolveOperand(std::get<1>(argOperandType), type,
More information about the Mlir-commits
mailing list