[Mlir-commits] [mlir] [mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition (PR #127778)

Matthias Springer llvmlistbot at llvm.org
Wed Feb 19 23:33:01 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/127778

>From 5b9a4e1ab4b08b354d9581cae4f5e2ad608d819e Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 19 Feb 2025 11:15:44 +0100
Subject: [PATCH 1/2] fix scf.for parser

---
 mlir/lib/AsmParser/Parser.cpp      | 14 ++++++++++----
 mlir/lib/Dialect/SCF/IR/SCF.cpp    | 21 +++++++++++++--------
 mlir/test/Dialect/SCF/invalid.mlir | 10 ++++++++++
 3 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index b5f1d2e27c9ba..2982757a6c5ce 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<Value, SMLoc> forwardRefPlaceholders;
 
+  /// Operations that define the placeholders. These are kept until the end of
+  /// the lifetime of the parser because some custom parsers may store
+  /// references to them in local state and use them after forward references
+  /// have been resolved.
+  DenseSet<Operation *> forwardRefOps;
+
   /// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
   /// map. After parsing the definition `#loc42 = ...` we'll patch back users
   /// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
 }
 
 OperationParser::~OperationParser() {
-  for (auto &fwd : forwardRefPlaceholders) {
+  for (Operation *op : forwardRefOps) {
     // Drop all uses of undefined forward declared reference and destroy
     // defining operation.
-    fwd.first.dropAllUses();
-    fwd.first.getDefiningOp()->destroy();
+    op->dropAllUses();
+    op->destroy();
   }
   for (const auto &scope : forwardRef) {
     for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
     // the actual definition instead, delete the forward ref, and remove it
     // from our set of forward references we track.
     existing.replaceAllUsesWith(value);
-    existing.getDefiningOp()->destroy();
     forwardRefPlaceholders.erase(existing);
 
     // If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
       /*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
       /*numRegions=*/0);
   forwardRefPlaceholders[op->getResult(0)] = loc;
+  forwardRefOps.insert(op);
   return op->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 448141735ba7f..1f70ad57d986b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   else if (parser.parseType(type))
     return failure();
 
-  // Resolve input operands.
+  // Set block argument types, so that they are known when parsing the region.
   regionArgs.front().type = type;
+  for (auto [iterArg, type] :
+       llvm::zip(llvm::drop_begin(regionArgs), result.types))
+    iterArg.type = type;
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Resolve input operands. This should be done after parsing the region to
+  // catch invalid IR where operands were defined inside of the region.
   if (parser.resolveOperand(lb, type, result.operands) ||
       parser.resolveOperand(ub, type, result.operands) ||
       parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
     }
   }
 
-  // Parse the body region.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, regionArgs))
-    return failure();
-
-  ForOp::ensureTerminator(*body, builder, result.location);
-
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 80576be880127..76c785f3e6166 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @invalid_reference(%a: index) {
+  // expected-error @below{{use of undeclared SSA value name}}
+  scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
+    %foo = "test.inner"() : () -> (tensor<?xf32>)
+    scf.yield %foo : tensor<?xf32>
+  }
+  return
+}

>From d6177255b8046d88035bccbbd9e67adcb96653ac Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Feb 2025 08:32:39 +0100
Subject: [PATCH 2/2] address comments

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1f70ad57d986b..1cfb866db0b51 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -502,7 +502,7 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   // Set block argument types, so that they are known when parsing the region.
   regionArgs.front().type = type;
   for (auto [iterArg, type] :
-       llvm::zip(llvm::drop_begin(regionArgs), result.types))
+       llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
     iterArg.type = type;
 
   // Parse the body region.
@@ -518,8 +518,8 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.resolveOperand(step, type, result.operands))
     return failure();
   if (hasIterArgs) {
-    for (auto argOperandType :
-         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+    for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
+                                               operands, result.types)) {
       Type type = std::get<2>(argOperandType);
       std::get<0>(argOperandType).type = type;
       if (parser.resolveOperand(std::get<1>(argOperandType), type,



More information about the Mlir-commits mailing list