[Mlir-commits] [mlir] [mlir][GPU] gpu.printf: Do not emit duplicate format strings (PR #110504)

Matthias Springer llvmlistbot at llvm.org
Mon Sep 30 05:56:59 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/110504

Even if the same format string is used multiple times, emit just one `LLVM::GlobalOp`.

From 4d253de027df60c92816cd2f386b8614e5aaa9d3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 30 Sep 2024 14:55:27 +0200
Subject: [PATCH] [mlir][GPU] gpu.printf: Do not emit duplicate format strings

Even if the same format string is used multiple times, emit just one `LLVM::GlobalOp`.
---
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 106 ++++++++----------
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |   7 ++
 2 files changed, 56 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 5b590a457f7714..06d759b5f54175 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -340,6 +340,34 @@ static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
   return stringConstName;
 }
 
+/// Create a global that contains the given format string. If a global with
+/// the same format string exists already in the module, return that global.
+static LLVM::GlobalOp getOrCreateFormatStringConstant(
+    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
+    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
+  llvm::SmallString<20> formatString(str);
+  formatString.push_back('\0'); // Null terminate for C
+  auto globalType =
+      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
+  StringAttr attr = b.getStringAttr(formatString);
+
+  // Try to find existing global.
+  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+        globalOp.getValueAttr() == attr &&
+        globalOp.getAlignment().value_or(0) == alignment &&
+        globalOp.getAddrSpace() == addrSpace)
+      return globalOp;
+
+  // Not found: create new global.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
+  return b.create<LLVM::GlobalOp>(loc, globalType,
+                                  /*isConstant=*/true, LLVM::Linkage::Internal,
+                                  name, attr, alignment, addrSpace);
+}
+
 template <typename T>
 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
                                             ConversionPatternRewriter &rewriter,
@@ -391,33 +419,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
   auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
   Value printfDesc = printfBeginCall.getResult();
 
-  // Get a unique global name for the format.
-  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
-  llvm::SmallString<20> formatString(adaptor.getFormat());
-  formatString.push_back('\0'); // Null terminate for C
-  size_t formatStringSize = formatString.size_in_bytes();
-
-  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
-  LLVM::GlobalOp global;
-  {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    global = rewriter.create<LLVM::GlobalOp>(
-        loc, globalType,
-        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-        rewriter.getStringAttr(formatString));
-  }
+  // Create the global op or find an existing one.
+  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
 
   // Get a pointer to the format string's first element and pass it to printf()
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
       loc,
       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
       global.getSymNameAttr());
-  Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
-  Value stringLen =
-      rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
+  Value stringStart =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+  Value stringLen = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
 
   Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
   Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
@@ -486,30 +501,19 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
   LLVM::LLVMFuncOp printfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
 
-  // Get a unique global name for the format.
-  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
-  llvm::SmallString<20> formatString(adaptor.getFormat());
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  LLVM::GlobalOp global;
-  {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    global = rewriter.create<LLVM::GlobalOp>(
-        loc, globalType,
-        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-        rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
-  }
+  // Create the global op or find an existing one.
+  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
+      addressSpace);
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
       loc,
       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
       global.getSymNameAttr());
-  Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+  Value stringStart =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
 
   // Construct arguments and function call
   auto argsRange = adaptor.getArgs();
@@ -541,27 +545,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   LLVM::LLVMFuncOp vprintfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
-  // Get a unique global name for the format.
-  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
-  llvm::SmallString<20> formatString(adaptor.getFormat());
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  LLVM::GlobalOp global;
-  {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    global = rewriter.create<LLVM::GlobalOp>(
-        loc, globalType,
-        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-        rewriter.getStringAttr(formatString), /*allignment=*/0);
-  }
+  // Create the global op or find an existing one.
+  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
-  Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+  Value stringStart =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
   SmallVector<Type> types;
   SmallVector<Value> args;
   // Promote and pack the arguments into a stack allocation.
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ad4e9ec1791a77..748dfe8c68fc7e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -610,6 +610,13 @@ gpu.module @test_module_29 {
     // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
     // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
     gpu.printf "Hello, world\n"
+
+    // Make sure that the same global is reused.
+    // CHECK: %[[FORMATSTR2:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+    // CHECK: %[[FORMATSTART2:.*]] = llvm.getelementptr %[[FORMATSTR2]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+    // CHECK: llvm.call @vprintf(%[[FORMATSTART2]], %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+    gpu.printf "Hello, world\n"
+
     gpu.return
   }
 



More information about the Mlir-commits mailing list