[Mlir-commits] [mlir] [mlir][GPU] gpu.printf: Do not emit duplicate format strings (PR #110504)
Matthias Springer
llvmlistbot at llvm.org
Mon Sep 30 05:56:59 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/110504
Even if the same format string is used multiple times, emit just one `LLVM::GlobalOp`.
>From 4d253de027df60c92816cd2f386b8614e5aaa9d3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 30 Sep 2024 14:55:27 +0200
Subject: [PATCH] [mlir][GPU] gpu.printf: Do not emit duplicate format strings
Even if the same format string is used multiple times, emit just one `LLVM::GlobalOp`.
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 106 ++++++++----------
.../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 7 ++
2 files changed, 56 insertions(+), 57 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 5b590a457f7714..06d759b5f54175 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -340,6 +340,34 @@ static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
return stringConstName;
}
+/// Create a global that contains the given format string. If a constant global
+/// with the same type, value, alignment and address space exists, return it.
+static LLVM::GlobalOp getOrCreateFormatStringConstant(
+ OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
+ llvm::SmallString<20> formatString(str);
+ formatString.push_back('\0'); // Null terminate for C
+ auto globalType =
+ LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
+ StringAttr attr = b.getStringAttr(formatString);
+
+ // Try to find an existing, matching global among the module's globals.
+ for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+ globalOp.getValueAttr() == attr &&
+ globalOp.getAlignment().value_or(0) == alignment &&
+ globalOp.getAddrSpace() == addrSpace)
+ return globalOp;
+
+ // Not found: create a new global at the start of the module body.
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
+ return b.create<LLVM::GlobalOp>(loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
+}
+
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
@@ -391,33 +419,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
- // Get a unique global name for the format.
- SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
- llvm::SmallString<20> formatString(adaptor.getFormat());
- formatString.push_back('\0'); // Null terminate for C
- size_t formatStringSize = formatString.size_in_bytes();
-
- auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
- LLVM::GlobalOp global;
- {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- global = rewriter.create<LLVM::GlobalOp>(
- loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
- rewriter.getStringAttr(formatString));
- }
+ // Create the global op or find an existing one.
+ LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+ rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
- Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
- Value stringLen =
- rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
+ Value stringStart =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringLen = rewriter.create<LLVM::ConstantOp>(
+ loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
@@ -486,30 +501,19 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFuncOp printfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
- // Get a unique global name for the format.
- SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
- llvm::SmallString<20> formatString(adaptor.getFormat());
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- LLVM::GlobalOp global;
- {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- global = rewriter.create<LLVM::GlobalOp>(
- loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
- rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
- }
+ // Create the global op or find an existing one.
+ LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+ rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
+ addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
- Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringStart =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.getArgs();
@@ -541,27 +545,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
LLVM::LLVMFuncOp vprintfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
- // Get a unique global name for the format.
- SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
- llvm::SmallString<20> formatString(adaptor.getFormat());
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- LLVM::GlobalOp global;
- {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- global = rewriter.create<LLVM::GlobalOp>(
- loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
- rewriter.getStringAttr(formatString), /*allignment=*/0);
- }
+ // Create the global op or find an existing one.
+ LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+ rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
- Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringStart =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ad4e9ec1791a77..748dfe8c68fc7e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -610,6 +610,13 @@ gpu.module @test_module_29 {
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello, world\n"
+
+ // Make sure that the same global is reused.
+ // CHECK: %[[FORMATSTR2:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+ // CHECK: %[[FORMATSTART2:.*]] = llvm.getelementptr %[[FORMATSTR2]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+ // CHECK: llvm.call @vprintf(%[[FORMATSTART2]], %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+ gpu.printf "Hello, world\n"
+
gpu.return
}
More information about the Mlir-commits
mailing list