[Mlir-commits] [mlir] 7d03c8e - [mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition (#127778)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 19 23:42:31 PST 2025


Author: Matthias Springer
Date: 2025-02-20T08:42:27+01:00
New Revision: 7d03c8e256a78b67a645b78e3ca93287bee0cd37

URL: https://github.com/llvm/llvm-project/commit/7d03c8e256a78b67a645b78e3ca93287bee0cd37
DIFF: https://github.com/llvm/llvm-project/commit/7d03c8e256a78b67a645b78e3ca93287bee0cd37.diff

LOG: [mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition (#127778)

This commit fixes a use-after-free crash when parsing the following
invalid IR:
```mlir
scf.for ... iter_args(%var = %foo) -> tensor<?xf32> {
  %foo = "test.inner"() : () -> (tensor<?xf32>)
  scf.yield %arg0 : tensor<?xf32>
}
```

The `scf.for` parser was implemented as follows:
1. Resolve operands (including `%foo`).
2. Parse the region.

During operand resolution, a forward reference
(`unrealized_conversion_cast`) is added by the parser because `%foo` has
not been defined yet. During region parsing, the definition of `%foo` is
found and the forward reference is replaced with the actual definition.
(And the forward reference is deleted.) However, the operand of the
`scf.for` op is not updated because the `scf.for` op has not been
created yet; all we have is an `OperationState` object.

All parsers should be written in such a way that they first parse the
region and then resolve the operands. That way, no forward reference is
inserted in the first place. Before parsing the region, it may be
necessary to set the argument types if they are defined as part of the
assembly format of the op (as is the case with `scf.for`). Note: Ops in
generic format are parsed in the same way.

To make the parsing infrastructure more robust, this commit also delays
erasing forward references until the end of the parser's lifetime.
Instead of a use-after-free crash, users will then see more
descriptive error messages such as:
```
error: operation's operand is unlinked
```

Note: The proper way to fix the parser is to first parse the region,
then resolve the operands. The change to `Parser.cpp` is merely to help
users find the root cause of the problem.

Added: 
    

Modified: 
    mlir/lib/AsmParser/Parser.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index b5f1d2e27c9ba..2982757a6c5ce 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<Value, SMLoc> forwardRefPlaceholders;
 
+  /// Operations that define the placeholders. These are kept until the end of
+  /// the lifetime of the parser because some custom parsers may store
+  /// references to them in local state and use them after forward references
+  /// have been resolved.
+  DenseSet<Operation *> forwardRefOps;
+
   /// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
   /// map. After parsing the definition `#loc42 = ...` we'll patch back users
   /// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
 }
 
 OperationParser::~OperationParser() {
-  for (auto &fwd : forwardRefPlaceholders) {
+  for (Operation *op : forwardRefOps) {
     // Drop all uses of undefined forward declared reference and destroy
     // defining operation.
-    fwd.first.dropAllUses();
-    fwd.first.getDefiningOp()->destroy();
+    op->dropAllUses();
+    op->destroy();
   }
   for (const auto &scope : forwardRef) {
     for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
     // the actual definition instead, delete the forward ref, and remove it
     // from our set of forward references we track.
     existing.replaceAllUsesWith(value);
-    existing.getDefiningOp()->destroy();
     forwardRefPlaceholders.erase(existing);
 
     // If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
       /*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
       /*numRegions=*/0);
   forwardRefPlaceholders[op->getResult(0)] = loc;
+  forwardRefOps.insert(op);
   return op->getResult(0);
 }
 

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 448141735ba7f..1cfb866db0b51 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,15 +499,27 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   else if (parser.parseType(type))
     return failure();
 
-  // Resolve input operands.
+  // Set block argument types, so that they are known when parsing the region.
   regionArgs.front().type = type;
+  for (auto [iterArg, type] :
+       llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
+    iterArg.type = type;
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Resolve input operands. This should be done after parsing the region to
+  // catch invalid IR where operands were defined inside of the region.
   if (parser.resolveOperand(lb, type, result.operands) ||
       parser.resolveOperand(ub, type, result.operands) ||
       parser.resolveOperand(step, type, result.operands))
     return failure();
   if (hasIterArgs) {
-    for (auto argOperandType :
-         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+    for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
+                                               operands, result.types)) {
       Type type = std::get<2>(argOperandType);
       std::get<0>(argOperandType).type = type;
       if (parser.resolveOperand(std::get<1>(argOperandType), type,
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
     }
   }
 
-  // Parse the body region.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, regionArgs))
-    return failure();
-
-  ForOp::ensureTerminator(*body, builder, result.location);
-
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 80576be880127..76c785f3e6166 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @invalid_reference(%a: index) {
+  // expected-error @below{{use of undeclared SSA value name}}
+  scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
+    %foo = "test.inner"() : () -> (tensor<?xf32>)
+    scf.yield %foo : tensor<?xf32>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list