[Mlir-commits] [mlir] e1da629 - [MLIR][GPU] Define gpu.printf op and its lowerings

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Dec 9 07:54:36 PST 2021


Author: Krzysztof Drewniak
Date: 2021-12-09T15:54:31Z
New Revision: e1da62910e140cf45eafec64193c813e79796f05

URL: https://github.com/llvm/llvm-project/commit/e1da62910e140cf45eafec64193c813e79796f05
DIFF: https://github.com/llvm/llvm-project/commit/e1da62910e140cf45eafec64193c813e79796f05.diff

LOG: [MLIR][GPU] Define gpu.printf op and its lowerings

- Define a gpu.printf op, which can be lowered to the device-side printf() support present in CUDA, HIP, and OpenCL. This op only supports constant format strings and scalar arguments.
- Define the lowering of gpu.printf to a call to printf() (which is what is required for AMD GPUs when using OpenCL), as well as to the hostcall interface present in the AMD Open Compute device library, which is the interface present when kernels are running under HIP.
- Add a "runtime" enum that allows specifying which of the possible runtimes a ROCDL kernel will be executed under, or that the runtime is unknown. This enum controls how gpu.printf is lowered.

This change does not enable lowering for Nvidia GPUs, but such a lowering should be possible in principle.

And:
[MLIR][AMDGPU] Always set amdgpu-implicitarg-num-bytes=56 on kernels

This is something that Clang always sets on both OpenCL and HIP kernels, and failing to include it causes mysterious crashes with printf() support.

In addition, revert the max-flat-work-group-size to (1, 256) to avoid triggering bugs in the AMDGPU backend.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D110448

Added: 
    mlir/include/mlir/Conversion/GPUToROCDL/Runtimes.h
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
    mlir/test/Integration/GPU/ROCM/printf.mlir

Modified: 
    mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/PassDetail.h
    mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
    mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 570a97633e7fd..83d8a08bb5a0b 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -8,6 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
 #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
 
+#include "mlir/Conversion/GPUToROCDL/Runtimes.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include <memory>
 
@@ -25,8 +26,11 @@ class GPUModuleOp;
 } // namespace gpu
 
 /// Collect a set of patterns to convert from the GPU dialect to ROCDL.
+/// If `runtime` is Unknown, gpu.printf will not be lowered
+/// The resulting pattern set should be run over a gpu.module op
 void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
-                                          RewritePatternSet &patterns);
+                                          RewritePatternSet &patterns,
+                                          gpu::amd::Runtime runtime);
 
 /// Configure target to convert from the GPU dialect to ROCDL.
 void configureGpuToROCDLConversionLegality(ConversionTarget &target);
@@ -36,7 +40,8 @@ void configureGpuToROCDLConversionLegality(ConversionTarget &target);
 /// is configurable.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 createLowerGpuOpsToROCDLOpsPass(
-    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/GPUToROCDL/Runtimes.h b/mlir/include/mlir/Conversion/GPUToROCDL/Runtimes.h
new file mode 100644
index 0000000000000..da47d4cf97e7e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/Runtimes.h
@@ -0,0 +1,24 @@
+//===- Runtimes.h - Possible runtimes for AMD GPUs ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_GPUTOROCDL_RUNTIMES_H
+#define MLIR_CONVERSION_GPUTOROCDL_RUNTIMES_H
+
+namespace mlir {
+namespace gpu {
+namespace amd {
+/// Potential runtimes for AMD GPU kernels. The GPU-to-ROCDL conversion uses
+/// this value to decide which gpu.printf lowering pattern (if any) to add.
+enum Runtime {
+  /// The runtime is not known: no gpu.printf lowering pattern is added.
+  Unknown = 0,
+  /// HIP: gpu.printf is lowered to the __ockl_printf_* hostcall functions.
+  HIP = 1,
+  /// OpenCL: gpu.printf is lowered to a call to printf() whose format string
+  /// is placed in address space 4.
+  OpenCL = 2,
+};
+} // end namespace amd
+} // end namespace gpu
+} // end namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOROCDL_RUNTIMES_H

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 31ef83579b2f9..791d4643b7d2c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -203,7 +203,15 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
   let options = [
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
-           "Bitwidth of the index type, 0 to use size of machine word">
+           "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
+          "::mlir::gpu::amd::Runtime::Unknown",
+          "Runtime code will be run on (default is Unknown, can also use HIP or OpenCL)",
+          [{::llvm::cl::values(
+            clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
+            clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
+            clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
+          )}]>
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 8bb3ebbef0396..11f2863ff2a18 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -547,6 +547,22 @@ def GPU_LaunchOp : GPU_Op<"launch">,
   let hasCanonicalizer = 1;
 }
 
+def GPU_PrintfOp : GPU_Op<"printf", [MemoryEffects<[MemWrite]>]>,
+  Arguments<(ins StrAttr:$format,
+                Variadic<AnyTypeOf<[AnyInteger, Index, AnyFloat]>>:$args)> {
+  let summary = "Device-side printf, as in CUDA or OpenCL, for debugging";
+  let description = [{
+    `gpu.printf` takes a literal format string `format` and an arbitrary number of
+    scalar arguments that should be printed.
+
+    The format string is a C-style printf string, subject to any restrictions
+    imposed by one's target platform.
+
+    Only integer, index, and floating-point arguments are accepted.
+
+    Example:
+
+    ```mlir
+    gpu.printf "Hello from %d\n" %tid : index
+    ```
+  }];
+  let assemblyFormat = [{
+    $format attr-dict ($args^ `:` type($args))?
+  }];
+}
+
 def GPU_ReturnOp : GPU_Op<"return", [HasParent<"GPUFuncOp">, NoSideEffect,
                                      Terminator]>,
     Arguments<(ins Variadic<AnyType>:$operands)>, Results<(outs)> {

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index ba986aa87dbbc..c72d2110a8892 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "GPUOpsLowering.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -144,3 +145,200 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   rewriter.eraseOp(gpuFuncOp);
   return success();
 }
+
+static const char formatStringPrefix[] = "printfFormat_";
+
+/// Returns the LLVM function `name` found in `moduleOp`. If no LLVM function
+/// of that name is found, declares one with external linkage and signature
+/// `type` at the start of the module body. The rewriter's insertion point is
+/// restored afterwards.
+template <typename T>
+static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            StringRef name,
+                                            LLVM::LLVMFunctionType type) {
+  if (auto existing = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return existing;
+
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                           LLVM::Linkage::External);
+}
+
+/// Lowers gpu.printf to the printf hostcall interface of the AMD device
+/// library (__ockl_printf_begin, __ockl_printf_append_string_n and
+/// __ockl_printf_append_args), which is what kernels running under HIP use.
+///
+/// The format string is emitted as an internal constant global in the
+/// enclosing gpu.module and appended first. The scalar arguments are then
+/// widened to i64 and appended in groups of up to seven per call.
+LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = gpuPrintfOp->getLoc();
+
+  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
+  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
+  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
+  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
+  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
+  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
+  // This ensures that global constants and declarations are placed within
+  // the device code, not the host code.
+  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+
+  auto ocklBegin =
+      getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
+                          LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
+  // __ockl_printf_append_args is only needed (and declared) when the op has
+  // arguments to print.
+  LLVM::LLVMFuncOp ocklAppendArgs;
+  if (!adaptor.args().empty()) {
+    ocklAppendArgs = getOrDefineFunction(
+        moduleOp, loc, rewriter, "__ockl_printf_append_args",
+        LLVM::LLVMFunctionType::get(
+            llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
+                      llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
+  }
+  auto ocklAppendStringN = getOrDefineFunction(
+      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
+      LLVM::LLVMFunctionType::get(
+          llvmI64,
+          {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
+
+  // Start the printf hostcall.
+  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmI64, rewriter.getI64IntegerAttr(0));
+  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
+  Value printfDesc = printfBeginCall.getResult(0);
+
+  // Create a global constant for the format string, named with the first
+  // "printfFormat_<N>" symbol not already present in the module.
+  unsigned stringNumber = 0;
+  SmallString<16> stringConstName;
+  do {
+    stringConstName.clear();
+    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
+  } while (moduleOp.lookupSymbol(stringConstName));
+
+  llvm::SmallString<20> formatString(adaptor.format().getValue());
+  formatString.push_back('\0'); // Null terminate for C
+  size_t formatStringSize = formatString.size_in_bytes();
+
+  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
+  LLVM::GlobalOp global;
+  {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    global = rewriter.create<LLVM::GlobalOp>(
+        loc, globalType,
+        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
+        rewriter.getStringAttr(formatString));
+  }
+
+  // Get a pointer to the format string's first element and append it. The
+  // length passed to the hostcall counts the terminating NUL.
+  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
+  Value zero = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
+  Value stringStart = rewriter.create<LLVM::GEPOp>(
+      loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
+  Value stringLen = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
+
+  Value oneI32 = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmI32, rewriter.getI32IntegerAttr(1));
+  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmI32, rewriter.getI32IntegerAttr(0));
+
+  // The format string is the last piece (isLast = 1) only when there are no
+  // arguments. The operands are kept in an owning SmallVector on purpose. A
+  // ValueRange initialized from a braced list points into the temporary
+  // initializer_list's backing array. That array is destroyed at the end of
+  // the declaration, leaving the range dangling.
+  SmallVector<Value, 4> appendFormatArgs = {
+      printfDesc, stringStart, stringLen,
+      adaptor.args().empty() ? oneI32 : zeroI32};
+  auto appendFormatCall =
+      rewriter.create<LLVM::CallOp>(loc, ocklAppendStringN, appendFormatArgs);
+  printfDesc = appendFormatCall.getResult(0);
+
+  // __ockl_printf_append_args takes 7 values per append call.
+  constexpr size_t argsPerAppend = 7;
+  size_t nArgs = adaptor.args().size();
+  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
+    size_t bound = std::min(group + argsPerAppend, nArgs);
+    size_t numArgsThisCall = bound - group;
+
+    SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
+    arguments.push_back(printfDesc);
+    arguments.push_back(rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
+    for (size_t i = group; i < bound; ++i) {
+      Value arg = adaptor.args()[i];
+      // Floats are passed as the bit pattern of an f64.
+      if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
+        if (!floatType.isF64())
+          arg = rewriter.create<LLVM::FPExtOp>(
+              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
+        arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
+      }
+      // Integers narrower than 64 bits are zero-extended into the i64 slot.
+      // NOTE(review): the op's AnyInteger constraint also admits integers
+      // wider than 64 bits, which cannot be zero-extended to i64. Confirm
+      // whether those should be rejected or truncated.
+      if (arg.getType().getIntOrFloatBitWidth() != 64)
+        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
+
+      arguments.push_back(arg);
+    }
+    // Pad out to 7 arguments since the hostcall always needs 7.
+    for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
+      arguments.push_back(zeroI64);
+    }
+
+    // Only the final group is flagged with isLast = 1.
+    Value isLast = (bound == nArgs) ? oneI32 : zeroI32;
+    arguments.push_back(isLast);
+    auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
+    printfDesc = call.getResult(0);
+  }
+  rewriter.eraseOp(gpuPrintfOp);
+  return success();
+}
+
+/// Lowers gpu.printf to a call to an external variadic printf().
+///
+/// The NUL-terminated format string is placed in an internal constant global
+/// (in `addressSpace`) inside the enclosing gpu.module. printf() then receives
+/// a pointer to its first character, followed by the op's operands unchanged.
+LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = gpuPrintfOp->getLoc();
+
+  // Work in the device-side gpu.module rather than the host ModuleOp around
+  // it, so the declaration and the string global end up in device code.
+  auto gpuModule = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+
+  mlir::Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+  mlir::Type indexType = typeConverter->convertType(rewriter.getIndexType());
+  mlir::Type charPtrType = LLVM::LLVMPointerType::get(i8Type, addressSpace);
+
+  // Declare `i32 printf(i8 addrspace(N)*, ...)` unless it already exists.
+  LLVM::LLVMFuncOp printfFn = getOrDefineFunction(
+      gpuModule, loc, rewriter, "printf",
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {charPtrType},
+                                  /*isVarArg=*/true));
+
+  // Choose the first "printfFormat_<N>" symbol name that is still free.
+  SmallString<16> stringConstName;
+  for (unsigned suffix = 0;; ++suffix) {
+    stringConstName.clear();
+    (formatStringPrefix + Twine(suffix)).toStringRef(stringConstName);
+    if (!gpuModule.lookupSymbol(stringConstName))
+      break;
+  }
+
+  // C strings need an explicit terminator.
+  llvm::SmallString<20> formatString(adaptor.format().getValue());
+  formatString.push_back('\0');
+  auto stringType =
+      LLVM::LLVMArrayType::get(i8Type, formatString.size_in_bytes());
+
+  LLVM::GlobalOp formatGlobal;
+  {
+    // Emit the global at the top of the module, then restore the insertion
+    // point to where the op is being rewritten.
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(gpuModule.getBody());
+    formatGlobal = rewriter.create<LLVM::GlobalOp>(
+        loc, stringType, /*isConstant=*/true, LLVM::Linkage::Internal,
+        stringConstName, rewriter.getStringAttr(formatString),
+        /*alignment=*/0, addressSpace);
+  }
+
+  // &formatGlobal[0][0]: decay the array global to a pointer to its first i8.
+  Value arrayPtr = rewriter.create<LLVM::AddressOfOp>(loc, formatGlobal);
+  Value zeroIndex = rewriter.create<LLVM::ConstantOp>(
+      loc, indexType, rewriter.getIntegerAttr(indexType, 0));
+  Value formatStart = rewriter.create<LLVM::GEPOp>(
+      loc, charPtrType, arrayPtr, mlir::ValueRange({zeroIndex, zeroIndex}));
+
+  // printf(formatStart, args...)
+  SmallVector<Value, 4> callOperands;
+  callOperands.reserve(adaptor.args().size() + 1);
+  callOperands.push_back(formatStart);
+  for (Value operand : adaptor.args())
+    callOperands.push_back(operand);
+
+  rewriter.create<LLVM::CallOp>(loc, printfFn, callOperands);
+  rewriter.eraseOp(gpuPrintfOp);
+  return success();
+}

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 4abafb9686922..0800a00ec0d87 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -33,6 +33,40 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
   StringAttr kernelAttributeName;
 };
 
+/// The lowering of gpu.printf to a call to HIP hostcalls.
+///
+/// Simplifies llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp, as we don't have
+/// to deal with %s (even if there were first-class strings in MLIR, they're not
+/// legal input to gpu.printf) or non-constant format strings.
+struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
+  using ConvertOpToLLVMPattern<gpu::PrintfOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// The lowering of gpu.printf to a call to an external printf() function.
+///
+/// This pattern will add a declaration of printf() to the GPUModule if needed
+/// and separate out the format strings into global constants. For some
+/// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
+/// will lower printf calls to appropriate device-side code.
+struct GPUPrintfOpToLLVMCallLowering
+    : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
+  /// `addressSpace` is the LLVM address space used both for the format-string
+  /// globals and for the pointer parameter of the printf() declaration.
+  GPUPrintfOpToLLVMCallLowering(LLVMTypeConverter &converter,
+                                int addressSpace = 0)
+      : ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
+        addressSpace(addressSpace) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  /// Address space of the format-string globals and of printf()'s pointer
+  /// parameter.
+  int addressSpace;
+};
+
 struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
   using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
 

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 1a570dc5985b6..d345434a379e3 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -51,8 +51,9 @@ namespace {
 struct LowerGpuOpsToROCDLOpsPass
     : public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
   LowerGpuOpsToROCDLOpsPass() = default;
-  LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
+  LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth, gpu::amd::Runtime runtime) {
     this->indexBitwidth = indexBitwidth;
+    this->runtime = runtime;
   }
 
   void runOnOperation() override {
@@ -79,7 +80,7 @@ struct LowerGpuOpsToROCDLOpsPass
     populateVectorToROCDLConversionPatterns(converter, llvmPatterns);
     populateStdToLLVMConversionPatterns(converter, llvmPatterns);
     populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
-    populateGpuToROCDLConversionPatterns(converter, llvmPatterns);
+    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
     LLVMConversionTarget target(getContext());
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
@@ -102,8 +103,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
 
-void mlir::populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
-                                                RewritePatternSet &patterns) {
+void mlir::populateGpuToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    mlir::gpu::amd::Runtime runtime) {
+  using mlir::gpu::amd::Runtime;
+
   populateWithGenerated(patterns);
   patterns
       .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
@@ -119,6 +123,13 @@ void mlir::populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
       converter, /*allocaAddrSpace=*/5,
       StringAttr::get(&converter.getContext(),
                       ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+  if (Runtime::HIP == runtime) {
+    patterns.add<GPUPrintfOpToHIPLowering>(converter);
+  } else if (Runtime::OpenCL == runtime) {
+    // Use address space = 4 to match the OpenCL definition of printf()
+    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/4);
+  }
+
   patterns.add<OpToFuncCallLowering<math::AbsOp>>(converter, "__ocml_fabs_f32",
                                                   "__ocml_fabs_f64");
   patterns.add<OpToFuncCallLowering<math::AtanOp>>(converter, "__ocml_atan_f32",
@@ -158,6 +169,7 @@ void mlir::populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) {
-  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth);
+mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth,
+                                      gpu::amd::Runtime runtime) {
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(indexBitwidth, runtime);
 }

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index d9f47138aae53..628d6995feeb3 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -11,6 +11,8 @@
 
 #include "mlir/Pass/Pass.h"
 
+#include "mlir/Conversion/GPUToROCDL/Runtimes.h"
+
 namespace mlir {
 class AffineDialect;
 class StandardOpsDialect;

diff  --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
index 6f488b0362a6e..8ffa83f893fb6 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
@@ -306,6 +306,12 @@ SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
       return nullptr;
     }
   }
+
+  // Set amdgpu_hostcall if host calls have been linked, as needed by newer LLVM
+  // FIXME: Is there a way to set this during printf() lowering that makes sense
+  if (ret->getFunction("__ockl_hostcall_internal"))
+    if (!ret->getModuleFlag("amdgpu_hostcall"))
+      ret->addModuleFlag(llvm::Module::Override, "amdgpu_hostcall", 1);
   return ret;
 }
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index a0c61b8de98c8..95d9939f2fa1a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -71,11 +71,14 @@ class ROCDLDialectLLVMIRTranslationInterface
 
       // For GPU kernels,
       // 1. Insert AMDGPU_KERNEL calling convention.
-      // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute.
+      // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
+      // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL
+      // and HIP kernels per Clang)
       llvm::Function *llvmFunc =
           moduleTranslation.lookupFunction(func.getName());
       llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
-      llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
+      llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 256");
+      llvmFunc->addFnAttr("amdgpu-implicitarg-num-bytes", "56");
     }
     return success();
   }

diff  --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
new file mode 100644
index 0000000000000..b2efa495fede4
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl=runtime=HIP -split-input-file | FileCheck %s
+
+gpu.module @test_module {
+  // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
+  // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
+  // CHECK-DAG: llvm.func @__ockl_printf_append_args(i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
+  // CHECK-DAG: llvm.func @__ockl_printf_append_string_n(i64, !llvm.ptr<i8>, i64, i32) -> i64
+  // CHECK-DAG: llvm.func @__ockl_printf_begin(i64) -> i64
+
+  // CHECK-LABEL: func @test_const_printf
+  gpu.func @test_const_printf() {
+    // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%[[CST0]]) : (i64) -> i64
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
+    // CHECK-NEXT: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][%[[CST1]], %[[CST1]]] : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(14 : i64) : i64
+    // CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISLAST]]) : (i64, !llvm.ptr<i8>, i64, i32) -> i64
+    gpu.printf "Hello, world\n"
+    gpu.return
+  }
+
+
+  // CHECK-LABEL: func @test_printf
+  // CHECK: (%[[ARG0:.*]]: i32)
+  gpu.func @test_printf(%arg0: i32) {
+    // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%[[CST0]]) : (i64) -> i64
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
+    // CHECK-NEXT: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][%[[CST1]], %[[CST1]]] : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(11 : i64) : i64
+    // CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT: %[[DESC1:.*]] = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISNTLAST]]) : (i64, !llvm.ptr<i8>, i64, i32) -> i64
+    // CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64
+    // CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
+    gpu.printf "Hello: %d\n" %arg0 : i32
+    gpu.return
+  }
+}

diff  --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
new file mode 100644
index 0000000000000..8e5af9dff5a31
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl=runtime=OpenCL | FileCheck %s
+
+// With runtime=OpenCL, gpu.printf becomes a plain call to a variadic printf()
+// whose format string is an internal constant global in address space 4.
+gpu.module @test_module {
+  // CHECK: llvm.mlir.global internal constant @[[$PRINT_GLOBAL:[A-Za-z0-9_]+]]("Hello: %d\0A\00")  {addr_space = 4 : i32}
+  // CHECK: llvm.func @printf(!llvm.ptr<i8, 4>, ...) -> i32
+  // CHECK-LABEL: func @test_printf
+  // CHECK: (%[[ARG0:.*]]: i32)
+  gpu.func @test_printf(%arg0: i32) {
+    // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<array<11 x i8>, 4>
+    // CHECK-NEXT: %[[IMM1:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][%[[IMM1]], %[[IMM1]]] : (!llvm.ptr<array<11 x i8>, 4>, i64, i64) -> !llvm.ptr<i8, 4>
+    // CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) : (!llvm.ptr<i8, 4>, i32) -> i32
+    gpu.printf "Hello: %d\n" %arg0 : i32
+    gpu.return
+  }
+}

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 1c5ab143dcc44..fe4015ecb2142 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -112,6 +112,14 @@ module attributes {gpu.container_module} {
       gpu.return
     }
 
+    // CHECK-LABEL: gpu.func @printf_test
+    // CHECK: (%[[ARG0:.*]]: i32)
+    // CHECK: gpu.printf "Value: %d" %[[ARG0]] : i32
+    gpu.func @printf_test(%arg0 : i32) {
+      gpu.printf "Value: %d" %arg0 : i32
+      gpu.return
+    }
+
     // CHECK-LABEL: gpu.func @no_attribution
     // CHECK: {
     gpu.func @no_attribution(%arg0: f32) {

diff  --git a/mlir/test/Integration/GPU/ROCM/printf.mlir b/mlir/test/Integration/GPU/ROCM/printf.mlir
new file mode 100644
index 0000000000000..e476d6d299930
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/printf.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl{index-bitwidth=32 runtime=HIP},gpu-to-hsaco{chip=%chip})' \
+// RUN:   -gpu-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext \
+// RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// Launches a single block of two threads. Each thread prints its x thread id
+// through gpu.printf, which runtime=HIP lowers to the device-library hostcall
+// interface.
+// CHECK: Hello from 0
+// CHECK: Hello from 1
+module attributes {gpu.container_module} {
+    gpu.module @kernels {
+        gpu.func @hello() kernel {
+            %0 = "gpu.thread_id"() {dimension="x"} : () -> (index)
+            gpu.printf "Hello from %d\n" %0 : index
+            gpu.return
+        }
+    }
+
+    func @main() {
+        %c2 = arith.constant 2 : index
+        %c1 = arith.constant 1 : index
+        gpu.launch_func @kernels::@hello
+            blocks in (%c1, %c1, %c1)
+            threads in (%c2, %c1, %c1)
+        return
+    }
+}


        


More information about the Mlir-commits mailing list