[flang-commits] [flang] [flang][CodeGen] Fix use-after-free in BoxedProcedurePass (PR #84376)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Fri Mar 8 05:39:31 PST 2024


https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/84376

From 82c0d1ee3dd93c631234efa5f767217f786a1275 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 7 Mar 2024 14:20:25 -0600
Subject: [PATCH 1/3] [flang][CodeGen] Replace correct op in BoxedProcedurePass

---
 flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 846a78931acba7..4cf39716a73755 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -342,7 +342,7 @@ class BoxedProcedurePass
             auto toTy = typeConverter.convertType(ty);
             auto toOnTy = typeConverter.convertType(onTy);
             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
-                mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
+                index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
           }
         } else if (op->getDialect() == firDialect) {
           rewriter.startOpModification(op);

From 58cda6db7030e178fbd861312dcee372e1558611 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 7 Mar 2024 14:21:45 -0600
Subject: [PATCH 2/3] [flang][CodeGen] Fix use-after-free in BoxedProcedurePass

Replacing an element of an operation range while traversing the range
can make the range invalid. Store the operations in a separate list,
and traverse the list instead.
Additionally, avoid inspecting an operation that has been replaced.

This was detected by address sanitizer.
---
 .../lib/Optimizer/CodeGen/BoxedProcedure.cpp  | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 4cf39716a73755..2e34b0a1b492b1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -208,7 +208,12 @@ class BoxedProcedurePass
       mlir::IRRewriter rewriter(context);
       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
       mlir::Dialect *firDialect = context->getLoadedDialect("fir");
-      getModule().walk([&](mlir::Operation *op) {
+      llvm::SmallVector<mlir::Operation *> operations;
+
+      getModule().walk([&](mlir::Operation *op) { operations.push_back(op); });
+
+      for (mlir::Operation *op : operations) {
+        bool opIsValid = true;
         typeConverter.setLocation(op->getLoc());
         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
           mlir::Type ty = addr.getVal().getType();
@@ -220,6 +225,7 @@ class BoxedProcedurePass
             rewriter.setInsertionPoint(addr);
             rewriter.replaceOpWithNewOp<ConvertOp>(
                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
+            opIsValid = false;
           } else if (typeConverter.needsConversion(resTy)) {
             rewriter.startOpModification(op);
             op->getResult(0).setType(typeConverter.convertType(resTy));
@@ -271,10 +277,12 @@ class BoxedProcedurePass
                 llvm::ArrayRef<mlir::Value>{tramp});
             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
                                                    adjustCall.getResult(0));
+            opIsValid = false;
           } else {
             // Just forward the function as a pointer.
             rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
                                                    embox.getFunc());
+            opIsValid = false;
           }
         } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
           auto ty = global.getType();
@@ -297,6 +305,7 @@ class BoxedProcedurePass
             rewriter.replaceOpWithNewOp<AllocaOp>(
                 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
                 mem.getShape());
+            opIsValid = false;
           }
         } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
           auto ty = mem.getType();
@@ -310,6 +319,7 @@ class BoxedProcedurePass
             rewriter.replaceOpWithNewOp<AllocMemOp>(
                 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
                 mem.getShape());
+            opIsValid = false;
           }
         } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
           auto ty = coor.getType();
@@ -321,6 +331,7 @@ class BoxedProcedurePass
             auto toBaseTy = typeConverter.convertType(baseTy);
             rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
                                                       coor.getCoor(), toBaseTy);
+            opIsValid = false;
           }
         } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
           auto ty = index.getType();
@@ -332,6 +343,7 @@ class BoxedProcedurePass
             auto toOnTy = typeConverter.convertType(onTy);
             rewriter.replaceOpWithNewOp<FieldIndexOp>(
                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
+            opIsValid = false;
           }
         } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
           auto ty = index.getType();
@@ -343,6 +355,7 @@ class BoxedProcedurePass
             auto toOnTy = typeConverter.convertType(onTy);
             rewriter.replaceOpWithNewOp<LenParamIndexOp>(
                 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
+            opIsValid = false;
           }
         } else if (op->getDialect() == firDialect) {
           rewriter.startOpModification(op);
@@ -354,7 +367,7 @@ class BoxedProcedurePass
           rewriter.finalizeOpModification(op);
         }
         // Ensure block arguments are updated if needed.
-        if (op->getNumRegions() != 0) {
+        if (opIsValid && op->getNumRegions() != 0) {
           rewriter.startOpModification(op);
           for (mlir::Region &region : op->getRegions())
             for (mlir::Block &block : region.getBlocks())
@@ -366,7 +379,7 @@ class BoxedProcedurePass
                 }
           rewriter.finalizeOpModification(op);
         }
-      });
+      }
     }
   }
 

From 05be5a25ef093dd958021aa70863ea962ec672d5 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 8 Mar 2024 07:36:32 -0600
Subject: [PATCH 3/3] Caching ops is unnecessary, since it's a post-order walk

---
 flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 2e34b0a1b492b1..746c275f37eaca 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -208,11 +208,7 @@ class BoxedProcedurePass
       mlir::IRRewriter rewriter(context);
       BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
       mlir::Dialect *firDialect = context->getLoadedDialect("fir");
-      llvm::SmallVector<mlir::Operation *> operations;
-
-      getModule().walk([&](mlir::Operation *op) { operations.push_back(op); });
-
-      for (mlir::Operation *op : operations) {
+      getModule().walk([&](mlir::Operation *op) {
         bool opIsValid = true;
         typeConverter.setLocation(op->getLoc());
         if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
@@ -379,7 +375,7 @@ class BoxedProcedurePass
                 }
           rewriter.finalizeOpModification(op);
         }
-      }
+      });
     }
   }
 



More information about the flang-commits mailing list