[llvm-branch-commits] [flang] [flang][CodeGen] Fix use-after-free in BoxedProcedurePass (PR #84376)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Mar 7 12:27:58 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-codegen
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
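Replacing an operation while `walk` is still traversing the module can invalidate the range being traversed, leading to a use-after-free. Instead, first collect all operations into a separate list, and then traverse that list.
Additionally, avoid inspecting an operation after it has been replaced (and therefore erased). Each `replaceOpWithNewOp` call now clears an `opIsValid` flag, and the block-argument update at the end of the loop runs only while that flag is still set.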
This was detected by address sanitizer.
---
Full diff: https://github.com/llvm/llvm-project/pull/84376.diff
1 File Affected:
- (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+16-3)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 4cf39716a73755..2e34b0a1b492b1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -208,7 +208,12 @@ class BoxedProcedurePass
mlir::IRRewriter rewriter(context);
BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
mlir::Dialect *firDialect = context->getLoadedDialect("fir");
- getModule().walk([&](mlir::Operation *op) {
+ llvm::SmallVector<mlir::Operation *> operations;
+
+ getModule().walk([&](mlir::Operation *op) { operations.push_back(op); });
+
+ for (mlir::Operation *op : operations) {
+ bool opIsValid = true;
typeConverter.setLocation(op->getLoc());
if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
mlir::Type ty = addr.getVal().getType();
@@ -220,6 +225,7 @@ class BoxedProcedurePass
rewriter.setInsertionPoint(addr);
rewriter.replaceOpWithNewOp<ConvertOp>(
addr, typeConverter.convertType(addr.getType()), addr.getVal());
+ opIsValid = false;
} else if (typeConverter.needsConversion(resTy)) {
rewriter.startOpModification(op);
op->getResult(0).setType(typeConverter.convertType(resTy));
@@ -271,10 +277,12 @@ class BoxedProcedurePass
llvm::ArrayRef<mlir::Value>{tramp});
rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
adjustCall.getResult(0));
+ opIsValid = false;
} else {
// Just forward the function as a pointer.
rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
embox.getFunc());
+ opIsValid = false;
}
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
auto ty = global.getType();
@@ -297,6 +305,7 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<AllocaOp>(
mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
mem.getShape());
+ opIsValid = false;
}
} else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
auto ty = mem.getType();
@@ -310,6 +319,7 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<AllocMemOp>(
mem, toTy, uniqName, bindcName, mem.getTypeparams(),
mem.getShape());
+ opIsValid = false;
}
} else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
auto ty = coor.getType();
@@ -321,6 +331,7 @@ class BoxedProcedurePass
auto toBaseTy = typeConverter.convertType(baseTy);
rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
coor.getCoor(), toBaseTy);
+ opIsValid = false;
}
} else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
auto ty = index.getType();
@@ -332,6 +343,7 @@ class BoxedProcedurePass
auto toOnTy = typeConverter.convertType(onTy);
rewriter.replaceOpWithNewOp<FieldIndexOp>(
index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
+ opIsValid = false;
}
} else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
auto ty = index.getType();
@@ -343,6 +355,7 @@ class BoxedProcedurePass
auto toOnTy = typeConverter.convertType(onTy);
rewriter.replaceOpWithNewOp<LenParamIndexOp>(
index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
+ opIsValid = false;
}
} else if (op->getDialect() == firDialect) {
rewriter.startOpModification(op);
@@ -354,7 +367,7 @@ class BoxedProcedurePass
rewriter.finalizeOpModification(op);
}
// Ensure block arguments are updated if needed.
- if (op->getNumRegions() != 0) {
+ if (opIsValid && op->getNumRegions() != 0) {
rewriter.startOpModification(op);
for (mlir::Region &region : op->getRegions())
for (mlir::Block &block : region.getBlocks())
@@ -366,7 +379,7 @@ class BoxedProcedurePass
}
rewriter.finalizeOpModification(op);
}
- });
+ }
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/84376
More information about the llvm-branch-commits
mailing list