[Mlir-commits] [mlir] [mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition (PR #127778)

Matthias Springer llvmlistbot at llvm.org
Wed Feb 19 02:36:18 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/127778

This commit fixes a use-after-free crash when parsing the following invalid IR:
```mlir
scf.for ... iter_args(%var = %foo) -> tensor<?xf32> {
  %foo = "test.inner"() : () -> (tensor<?xf32>)
  scf.yield %arg0 : tensor<?xf32>
}
```

The `scf.for` parser was implemented as follows:
1. Resolve operands (including `%foo`).
2. Parse the region.

During operand resolution, the parser adds a forward reference (`unrealized_conversion_cast`) because `%foo` has not been defined yet. During region parsing, the definition of `%foo` is found, all uses of the forward reference are replaced with the actual definition, and the forward reference is deleted. However, the operand of the `scf.for` op is not updated because the `scf.for` op has not been created yet; all we have is an `OperationState` object, which still points to the deleted forward reference.

All parsers should be written in such a way that they first parse the region and then resolve the operands. That way, no forward reference is inserted in the first place. Before parsing the region, it may be necessary to set the argument types if they are defined as part of the assembly format of the op (as is the case with `scf.for`). Note: Ops in generic format are parsed in the same way.

To make the parsing infrastructure more robust, this commit also delays the erasure of forward references until the end of the lifetime of the parser. Instead of a use-after-free crash, users will then see more descriptive error messages such as:
```
error: operation's operand is unlinked
```

Note: The proper way to fix the parser is to first parse the region, then resolve the operands. The change to `Parser.cpp` is merely to help users find the root cause of the problem.


>From 5b9a4e1ab4b08b354d9581cae4f5e2ad608d819e Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 19 Feb 2025 11:15:44 +0100
Subject: [PATCH] fix scf.for parser

---
 mlir/lib/AsmParser/Parser.cpp      | 14 ++++++++++----
 mlir/lib/Dialect/SCF/IR/SCF.cpp    | 21 +++++++++++++--------
 mlir/test/Dialect/SCF/invalid.mlir | 10 ++++++++++
 3 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index b5f1d2e27c9ba..2982757a6c5ce 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<Value, SMLoc> forwardRefPlaceholders;
 
+  /// Operations that define the placeholders. These are kept until the end of
+  /// the lifetime of the parser because some custom parsers may store
+  /// references to them in local state and use them after forward references
+  /// have been resolved.
+  DenseSet<Operation *> forwardRefOps;
+
   /// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
   /// map. After parsing the definition `#loc42 = ...` we'll patch back users
   /// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
 }
 
 OperationParser::~OperationParser() {
-  for (auto &fwd : forwardRefPlaceholders) {
+  for (Operation *op : forwardRefOps) {
     // Drop all uses of undefined forward declared reference and destroy
     // defining operation.
-    fwd.first.dropAllUses();
-    fwd.first.getDefiningOp()->destroy();
+    op->dropAllUses();
+    op->destroy();
   }
   for (const auto &scope : forwardRef) {
     for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
     // the actual definition instead, delete the forward ref, and remove it
     // from our set of forward references we track.
     existing.replaceAllUsesWith(value);
-    existing.getDefiningOp()->destroy();
     forwardRefPlaceholders.erase(existing);
 
     // If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
       /*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
       /*numRegions=*/0);
   forwardRefPlaceholders[op->getResult(0)] = loc;
+  forwardRefOps.insert(op);
   return op->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 448141735ba7f..1f70ad57d986b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   else if (parser.parseType(type))
     return failure();
 
-  // Resolve input operands.
+  // Set block argument types, so that they are known when parsing the region.
   regionArgs.front().type = type;
+  for (auto [iterArg, type] :
+       llvm::zip(llvm::drop_begin(regionArgs), result.types))
+    iterArg.type = type;
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Resolve input operands. This should be done after parsing the region to
+  // catch invalid IR where operands were defined inside of the region.
   if (parser.resolveOperand(lb, type, result.operands) ||
       parser.resolveOperand(ub, type, result.operands) ||
       parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
     }
   }
 
-  // Parse the body region.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, regionArgs))
-    return failure();
-
-  ForOp::ensureTerminator(*body, builder, result.location);
-
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 80576be880127..76c785f3e6166 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @invalid_reference(%a: index) {
+  // expected-error @below{{use of undeclared SSA value name}}
+  scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
+    %foo = "test.inner"() : () -> (tensor<?xf32>)
+    scf.yield %foo : tensor<?xf32>
+  }
+  return
+}



More information about the Mlir-commits mailing list