[Mlir-commits] [mlir] [MLIR][LLVM] Fix memory explosion when converting global variable bodies in ModuleTranslation (PR #82708)

Xiang Li llvmlistbot at llvm.org
Sat Feb 24 09:32:32 PST 2024


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/82708

From cc7b35a2fe31dd84a0e6ddfd07bc6a6f9c62196e Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Thu, 22 Feb 2024 18:38:21 -0500
Subject: [PATCH 1/3] [MLIR][LLVM] Fix memory explosion when converting global
 variable bodies in ModuleTranslation

Converting the body (initializer region) of a large global variable, e.g. a constant array, makes memory usage explode.

For example, when translating a constant array of 100000 strings:

llvm.mlir.global internal constant @cats_strings() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<100000 x ptr<i8>> {
    %0 = llvm.mlir.undef : !llvm.array<100000 x ptr<i8>>
    %1 = llvm.mlir.addressof @om_1 : !llvm.ptr<array<1 x i8>>
    %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
    %3 = llvm.insertvalue %2, %0[0] : !llvm.array<100000 x ptr<i8>>
    %4 = llvm.mlir.addressof @om_2 : !llvm.ptr<array<1 x i8>>
    %5 = llvm.getelementptr %4[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
    %6 = llvm.insertvalue %5, %3[1] : !llvm.array<100000 x ptr<i8>>
    %7 = llvm.mlir.addressof @om_3 : !llvm.ptr<array<1 x i8>>
    %8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
    %9 = llvm.insertvalue %8, %6[2] : !llvm.array<100000 x ptr<i8>>
    %10 = llvm.mlir.addressof @om_4 : !llvm.ptr<array<1 x i8>>
    %11 = llvm.getelementptr %10[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
    %12 = llvm.insertvalue %11, %9[3] : !llvm.array<100000 x ptr<i8>>

    ... (remaining elements omitted)
}

where @om_1, @om_2, ... are string global constants.

Each time an operation is converted to LLVM IR, a new constant is created.
For llvm.insertvalue, every insertion creates a new constant array of 100000 elements, while the old constant array (the input) is never destroyed.
Memory use therefore grows quadratically with the number of elements. On a system with 128 GB of memory, translating 100000 elements was killed after it exhausted all the memory.
On a system with 64 GB, 65536 elements were enough to get the translation killed.

A previous patch (https://reviews.llvm.org/D148487) fixed this issue but was reverted because of
https://github.com/llvm/llvm-project/issues/62802.

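The old patch checked the generated constants and destroyed those that had no uses.
However, that check ran too early. Because LLVM uniques constants, several MLIR values can map to the same LLVM constant, so a constant with no LLVM users yet could be destroyed before the remaining MLIR operations that use it were converted.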


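This patch adds a map that records the expected number of uses of each constant aggregate (the sum of the uses of every MLIR value that maps to it).
The count is decremented each time one of those uses is converted,
and the constant is erased only when its count reaches zero and it has no remaining LLVM users.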

With this patch, the reproducer from https://github.com/llvm/llvm-project/issues/62802 completes correctly.
---
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 64 +++++++++++++++++
 .../LLVMIR/erase-dangling-constants.mlir      | 72 +++++++++++++++++++
 2 files changed, 136 insertions(+)
 create mode 100644 mlir/test/Target/LLVMIR/erase-dangling-constants.mlir

diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ee8fffd959c883..64c37b1d5fa961 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -51,11 +51,15 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include <optional>
 
+#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
+
 using namespace mlir;
 using namespace mlir::LLVM;
 using namespace mlir::LLVM::detail;
@@ -1042,17 +1046,77 @@ LogicalResult ModuleTranslation::convertGlobals() {
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
     if (Block *initializer = op.getInitializerBlock()) {
       llvm::IRBuilder<> builder(llvmModule->getContext());
+
+      int numConstantsHit = 0;
+      int numConstantsErased = 0;
+      DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
+
       for (auto &op : initializer->without_terminator()) {
         if (failed(convertOperation(op, builder)) ||
             !isa<llvm::Constant>(lookupValue(op.getResult(0))))
           return emitError(op.getLoc(), "unemittable constant value");
+
+        // When emitting an LLVM constant, a new constant is created and the old
+        // constant may become dangling and take space. We should remove the
+        // dangling constants to avoid memory explosion especially for constant
+        // arrays whose number of elements is large.
+        // Because multiple operations may refer to the same constant, we need
+        // to count the number of uses of each constant array and remove it only
+        // when the count becomes zero.
+        if (op.getNumResults() == 1) {
+          Value result = op.getResult(0);
+          auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(result));
+          if (!cst)
+            continue;
+          numConstantsHit++;
+          auto iter = constantAggregateUseMap.find(cst);
+          int numUsers = std::distance(result.use_begin(), result.use_end());
+          if (iter == constantAggregateUseMap.end())
+            constantAggregateUseMap.try_emplace(cst, numUsers);
+          else
+            iter->second += numUsers;
+        }
+        for (Value v : op.getOperands()) {
+          auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
+          if (!cst)
+            continue;
+          auto iter = constantAggregateUseMap.find(cst);
+          assert(iter != constantAggregateUseMap.end() && "constant not found");
+          iter->second--;
+          if (iter->second == 0) {
+            cst->removeDeadConstantUsers();
+            if (cst->user_empty()) {
+              cst->destroyConstant();
+              numConstantsErased++;
+            }
+            constantAggregateUseMap.erase(iter);
+          }
+        }
       }
+
       ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
       llvm::Constant *cst =
           cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
       auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
       if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
         global->setInitializer(cst);
+
+      // Try to remove the dangling constants again after all operations are
+      // converted.
+      for (auto it : constantAggregateUseMap) {
+        auto cst = it.first;
+        cst->removeDeadConstantUsers();
+        if (cst->user_empty()) {
+          cst->destroyConstant();
+          numConstantsErased++;
+        }
+      }
+
+      LLVM_DEBUG(llvm::dbgs()
+                     << "Convert initializer for " << op.getName() << "\n";
+                 llvm::dbgs() << numConstantsHit << " new constants hit\n";
+                 llvm::dbgs()
+                 << numConstantsErased << " dangling constants erased\n";);
     }
   }
 
diff --git a/mlir/test/Target/LLVMIR/erase-dangling-constants.mlir b/mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
new file mode 100644
index 00000000000000..b3b5d540ae88fc
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -debug-only=llvm-dialect-to-llvm-ir 2>&1 | FileCheck %s
+
+// CHECK: Convert initializer for dup_const
+// CHECK: 6 new constants hit
+// CHECK: 3 dangling constants erased
+// CHECK: Convert initializer for unique_const
+// CHECK: 6 new constants hit
+// CHECK: 5 dangling constants erased
+
+
+// CHECK:@dup_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02] }
+
+llvm.mlir.global @dup_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
+    %c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
+    %c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64
+
+    %empty0 = llvm.mlir.undef : !llvm.array<2 x f64>
+    %a00 = llvm.insertvalue %c0, %empty0[0] : !llvm.array<2 x f64>
+
+    %empty1 = llvm.mlir.undef : !llvm.array<2 x f64>
+    %a10 = llvm.insertvalue %c0, %empty1[0] : !llvm.array<2 x f64>
+
+    %empty2 = llvm.mlir.undef : !llvm.array<2 x f64>
+    %a20 = llvm.insertvalue %c0, %empty2[0] : !llvm.array<2 x f64>
+
+// NOTE: %a00, %a10 and %a20 all map to the same ConstantAggregate, which has no
+//       users yet; it must not be deleted before all of its uses are converted.
+
+    %a01 = llvm.insertvalue %c1, %a00[1] : !llvm.array<2 x f64>
+    %a11 = llvm.insertvalue %c1, %a10[1] : !llvm.array<2 x f64>
+    %a21 = llvm.insertvalue %c1, %a20[1] : !llvm.array<2 x f64>
+    %empty_r = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+    %r0 = llvm.insertvalue %a01, %empty_r[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+    %r1 = llvm.insertvalue %a11, %r0[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+    %r2 = llvm.insertvalue %a21, %r1[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+    llvm.return %r2 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+  }
+
+// CHECK:@unique_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.312250e-02, double 5.219230e-02], [2 x double] [double 3.412250e-02, double 5.419230e-02] }
+
+llvm.mlir.global @unique_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
+    %c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
+    %c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64
+
+    %c2 = llvm.mlir.constant(3.312250e-02 : f64) : f64
+    %c3 = llvm.mlir.constant(5.219230e-02 : f64) : f64
+
+    %c4 = llvm.mlir.constant(3.412250e-02 : f64) : f64
+    %c5 = llvm.mlir.constant(5.419230e-02 : f64) : f64
+
+    %2 = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+    %3 = llvm.mlir.undef : !llvm.array<2 x f64>
+
+    %4 = llvm.insertvalue %c0, %3[0] : !llvm.array<2 x f64>
+    %5 = llvm.insertvalue %c1, %4[1] : !llvm.array<2 x f64>
+
+    %6 = llvm.insertvalue %5, %2[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+    %7 = llvm.insertvalue %c2, %3[0] : !llvm.array<2 x f64>
+    %8 = llvm.insertvalue %c3, %7[1] : !llvm.array<2 x f64>
+
+    %9 = llvm.insertvalue %8, %6[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+    %10 = llvm.insertvalue %c4, %3[0] : !llvm.array<2 x f64>
+    %11 = llvm.insertvalue %c5, %10[1] : !llvm.array<2 x f64>
+
+    %12 = llvm.insertvalue %11, %9[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+    llvm.return %12 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+}

From 7705e97055ce52b10cf53132a1c5824992bddc65 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Sat, 24 Feb 2024 12:27:57 -0500
Subject: [PATCH 2/3] Save lookup result for op.getResult(0). Use try_emplace
 for map.

---
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 28 ++++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 64c37b1d5fa961..dd6482a173cb61 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1052,8 +1052,10 @@ LogicalResult ModuleTranslation::convertGlobals() {
       DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
 
       for (auto &op : initializer->without_terminator()) {
-        if (failed(convertOperation(op, builder)) ||
-            !isa<llvm::Constant>(lookupValue(op.getResult(0))))
+        if (failed(convertOperation(op, builder)))
+          return emitError(op.getLoc(), "fail to convert global initializer");
+        auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
+        if (!cst)
           return emitError(op.getLoc(), "unemittable constant value");
 
         // When emitting an LLVM constant, a new constant is created and the old
@@ -1063,18 +1065,16 @@ LogicalResult ModuleTranslation::convertGlobals() {
         // Because multiple operations may refer to the same constant, we need
         // to count the number of uses of each constant array and remove it only
         // when the count becomes zero.
-        if (op.getNumResults() == 1) {
-          Value result = op.getResult(0);
-          auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(result));
-          if (!cst)
-            continue;
-          numConstantsHit++;
-          auto iter = constantAggregateUseMap.find(cst);
-          int numUsers = std::distance(result.use_begin(), result.use_end());
-          if (iter == constantAggregateUseMap.end())
-            constantAggregateUseMap.try_emplace(cst, numUsers);
-          else
-            iter->second += numUsers;
+        if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
+           numConstantsHit++;
+           Value result = op.getResult(0);
+           int numUsers = std::distance(result.use_begin(), result.use_end());
+           auto [iterator, inserted] =
+               constantAggregateUseMap.try_emplace(agg, numUsers);
+           if (!inserted) {
+             // Key already exists, update the value
+             iterator->second += numUsers;
+           }
         }
         for (Value v : op.getOperands()) {
           auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));

From 9dde90227a452072b3287c844c18b676656ebf8d Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Sat, 24 Feb 2024 12:32:15 -0500
Subject: [PATCH 3/3] Fix format.

---
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index dd6482a173cb61..32b16a4dd9c8a2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1066,15 +1066,15 @@ LogicalResult ModuleTranslation::convertGlobals() {
         // to count the number of uses of each constant array and remove it only
         // when the count becomes zero.
         if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
-           numConstantsHit++;
-           Value result = op.getResult(0);
-           int numUsers = std::distance(result.use_begin(), result.use_end());
-           auto [iterator, inserted] =
-               constantAggregateUseMap.try_emplace(agg, numUsers);
-           if (!inserted) {
-             // Key already exists, update the value
-             iterator->second += numUsers;
-           }
+          numConstantsHit++;
+          Value result = op.getResult(0);
+          int numUsers = std::distance(result.use_begin(), result.use_end());
+          auto [iterator, inserted] =
+              constantAggregateUseMap.try_emplace(agg, numUsers);
+          if (!inserted) {
+            // Key already exists, update the value
+            iterator->second += numUsers;
+          }
         }
         for (Value v : op.getOperands()) {
           auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));


