[llvm] Fix use after free error in NVVMReflect (PR #81471)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 03:55:07 PST 2024


https://github.com/PeterZhizhin created https://github.com/llvm/llvm-project/pull/81471

I have a Triton kernel, which triggered a heap-use-after-free error in LLVM.

The problem was that the same instruction may be added to the `ToSimplify` array multiple times. If such an instruction is trivially dead, it is erased the first time it is processed. When the duplicate entry is processed later, the pass accesses the already-freed instruction.

To fix this, I add trivially dead instructions to the `ToRemove` array instead of erasing them immediately, and I remove duplicates from that array before erasing, which avoids possible double frees.
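
The deduplicate-then-erase idea can be sketched outside of LLVM as follows. This is only an illustration under stated assumptions: `FakeInst` and `eraseUnique` are hypothetical names, not part of LLVM, and `delete` stands in for `eraseFromParent()`. The sketch uses `std::less` for the pointer comparison, whereas the patch calls `std::sort` with the default comparison.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for llvm::Instruction; counts destroyed objects.
struct FakeInst {
  static int Destroyed;
  ~FakeInst() { ++Destroyed; }
};
int FakeInst::Destroyed = 0;

// Sort the pointers so duplicates become adjacent, drop the duplicates,
// then delete each distinct object exactly once.
// Returns the number of distinct objects deleted.
std::size_t eraseUnique(std::vector<FakeInst *> &ToRemove) {
  std::sort(ToRemove.begin(), ToRemove.end(), std::less<FakeInst *>());
  ToRemove.erase(std::unique(ToRemove.begin(), ToRemove.end()),
                 ToRemove.end());
  for (FakeInst *I : ToRemove)
    delete I; // stands in for I->eraseFromParent()
  return ToRemove.size();
}
```

Without the `std::unique` step, a vector such as `{A, B, A}` would delete `A` twice. Sorting by address changes the order of deletion, so this approach only fits cases where that order does not matter.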

From 8ba8c2970c0fa7d95b0ee3cbe2499c2983f929e6 Mon Sep 17 00:00:00 2001
From: Peter Zhizhin <pzhizhin at google.com>
Date: Mon, 12 Feb 2024 12:50:10 +0100
Subject: [PATCH] Fix use after free error in LLVM

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 64fedf32e9a269..29c95e4226bf40 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -39,6 +39,7 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include <algorithm>
 #include <sstream>
 #include <string>
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
@@ -185,9 +186,6 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
     ToRemove.push_back(Call);
   }
 
-  for (Instruction *I : ToRemove)
-    I->eraseFromParent();
-
   // The code guarded by __nvvm_reflect may be invalid for the target machine.
   // Traverse the use-def chain, continually simplifying constant expressions
   // until we find a terminator that we can then remove.
@@ -200,13 +198,23 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
           ToSimplify.push_back(I);
 
       I->replaceAllUsesWith(C);
-      if (isInstructionTriviallyDead(I))
-        I->eraseFromParent();
+      if (isInstructionTriviallyDead(I)) {
+        ToRemove.push_back(I);
+      }
     } else if (I->isTerminator()) {
       ConstantFoldTerminator(I->getParent());
     }
   }
 
+  // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove
+  // array. Filter out the duplicates before starting to erase from parent.
+  std::sort(ToRemove.begin(), ToRemove.end());
+  auto NewLastIter = std::unique(ToRemove.begin(), ToRemove.end());
+  ToRemove.erase(NewLastIter, ToRemove.end());
+
+  for (Instruction *I : ToRemove)
+    I->eraseFromParent();
+
   return ToRemove.size() > 0;
 }
 


