[Mlir-commits] [mlir] 4c4876c - [mlir] Use target-specific GPU kernel attributes in lowering pipelines

Alex Zinenko llvmlistbot at llvm.org
Fri Feb 12 05:09:32 PST 2021


Author: Alex Zinenko
Date: 2021-02-12T14:09:24+01:00
New Revision: 4c4876c314577e253a198ca3868b26fd35ec8a6e

URL: https://github.com/llvm/llvm-project/commit/4c4876c314577e253a198ca3868b26fd35ec8a6e
DIFF: https://github.com/llvm/llvm-project/commit/4c4876c314577e253a198ca3868b26fd35ec8a6e.diff

LOG: [mlir] Use target-specific GPU kernel attributes in lowering pipelines

Until now, the GPU translation to NVVM or ROCDL intrinsics relied on the
presence of the generic `gpu.kernel` attribute to attach additional LLVM IR
metadata to the relevant functions. This would be problematic if each dialect
were to handle the conversion of its own options, which is the intended
direction for the translation infrastructure. Introduce `nvvm.kernel` and
`rocdl.kernel` in addition to `gpu.kernel` and base translation on these new
attributes instead.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D96591

Added: 
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/lib/Conversion/GPUCommon/CMakeLists.txt
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
    mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
    mlir/test/Target/nvvmir.mlir
    mlir/test/Target/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index b94cb5ce86b6..de7fd01fe1ca 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -24,6 +24,12 @@ def NVVM_Dialect : Dialect {
   let name = "nvvm";
   let cppNamespace = "::mlir::NVVM";
   let dependentDialects = ["LLVM::LLVMDialect"];
+
+  let extraClassDeclaration = [{
+    /// Get the name of the attribute used to annotate external kernel
+    /// functions.
+    static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index e073126450d6..d5eec3cebb4a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -24,6 +24,12 @@ def ROCDL_Dialect : Dialect {
   let name = "rocdl";
   let cppNamespace = "::mlir::ROCDL";
   let dependentDialects = ["LLVM::LLVMDialect"];
+
+  let extraClassDeclaration = [{
+    /// Get the name of the attribute used to annotate external kernel
+    /// functions.
+    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
index f3c6eca87d8a..825bed600aba 100644
--- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
@@ -17,6 +17,7 @@ endif()
 add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
   ConvertLaunchFuncToRuntimeCalls.cpp
   ConvertKernelFuncToBlob.cpp
+  GPUOpsLowering.cpp
 
   DEPENDS
   MLIRConversionPassIncGen

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
new file mode 100644
index 000000000000..c0fde0e2eecb
--- /dev/null
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -0,0 +1,155 @@
+//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "GPUOpsLowering.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+/// Lowers a `gpu.func` into an `llvm.func`.
+///
+/// Workgroup attributions become internal `LLVM::GlobalOp` buffers in the
+/// GPU workgroup address space, and private attributions become
+/// `LLVM::AllocaOp`s in `allocaAddrSpace`; both are exposed to the body as
+/// memref descriptors. Kernels additionally receive the dialect-specific
+/// `kernelAttributeName` unit attribute next to `gpu.kernel`. Returns failure
+/// if the signature or an attribution element type cannot be converted.
+LogicalResult
+GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
+                                   ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  assert(operands.empty() && "func op is not expected to have operands");
+  Location loc = gpuFuncOp.getLoc();
+
+  // Convert the signature first so that failure is reported before any IR is
+  // created. This also records how proper arguments map to converted ones.
+  TypeConverter::SignatureConversion signatureConversion(
+      gpuFuncOp.front().getNumArguments());
+  Type funcType = getTypeConverter()->convertFunctionSignature(
+      gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
+  if (!funcType)
+    return rewriter.notifyMatchFailure(gpuFuncOp,
+                                       "cannot convert function signature");
+
+  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
+  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
+  for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
+    Value attribution = en.value();
+
+    auto type = attribution.getType().dyn_cast<MemRefType>();
+    assert(type && type.hasStaticShape() && "unexpected type in attribution");
+
+    uint64_t numElements = type.getNumElements();
+
+    Type elementType = typeConverter->convertType(type.getElementType());
+    if (!elementType)
+      return rewriter.notifyMatchFailure(
+          gpuFuncOp, "cannot convert workgroup attribution element type");
+    auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
+    std::string name = std::string(
+        llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
+    auto globalOp = rewriter.create<LLVM::GlobalOp>(
+        loc, arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name,
+        /*value=*/Attribute(), gpu::GPUDialect::getWorkgroupAddressSpace());
+    workgroupBuffers.push_back(globalOp);
+  }
+
+  // Create the new function operation. Only copy those attributes that are
+  // not specific to function modeling.
+  SmallVector<NamedAttribute, 4> attributes;
+  for (const auto &attr : gpuFuncOp.getAttrs()) {
+    if (attr.first == SymbolTable::getSymbolAttrName() ||
+        attr.first == impl::getTypeAttrName() ||
+        attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
+      continue;
+    attributes.push_back(attr);
+  }
+  // Add a dialect specific kernel attribute in addition to GPU kernel
+  // attribute. The former is necessary for further translation while the
+  // latter is expected by gpu.launch_func.
+  if (gpuFuncOp.isKernel())
+    attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
+  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+      loc, gpuFuncOp.getName(), funcType, LLVM::Linkage::External, attributes);
+
+  {
+    // Insert operations that correspond to converted workgroup and private
+    // memory attributions to the body of the function. This must operate on
+    // the original function, before the body region is inlined in the new
+    // function to maintain the relation between block arguments and the
+    // parent operation that assigns their semantics.
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // Rewrite workgroup memory attributions to addresses of global buffers.
+    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
+    unsigned numProperArguments = gpuFuncOp.getNumArguments();
+    auto i32Type = IntegerType::get(rewriter.getContext(), 32);
+
+    Value zero = nullptr;
+    if (!workgroupBuffers.empty())
+      zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
+                                               rewriter.getI32IntegerAttr(0));
+    for (auto en : llvm::enumerate(workgroupBuffers)) {
+      LLVM::GlobalOp global = en.value();
+      Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
+      auto elementType =
+          global.getType().cast<LLVM::LLVMArrayType>().getElementType();
+      Value memory = rewriter.create<LLVM::GEPOp>(
+          loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()),
+          address, ArrayRef<Value>{zero, zero});
+
+      // Build a memref descriptor pointing to the buffer to plug with the
+      // existing memref infrastructure. This may use more registers than
+      // otherwise necessary given that memref sizes are fixed, but we can try
+      // and canonicalize that away later.
+      Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
+      auto type = attribution.getType().cast<MemRefType>();
+      auto descr = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), type, memory);
+      signatureConversion.remapInput(numProperArguments + en.index(), descr);
+    }
+
+    // Rewrite private memory attributions to alloca'ed buffers.
+    unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
+    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
+    for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
+      Value attribution = en.value();
+      auto type = attribution.getType().dyn_cast<MemRefType>();
+      assert(type && type.hasStaticShape() && "unexpected type in attribution");
+
+      // The address space of private allocations is target-specific and is
+      // supplied by the code populating this pattern (e.g. 0 for NVVM, which
+      // does not support `alloca`s in addrspace(5), and 5 for ROCDL).
+      Type elementType = typeConverter->convertType(type.getElementType());
+      if (!elementType)
+        return rewriter.notifyMatchFailure(
+            gpuFuncOp, "cannot convert private attribution element type");
+      auto ptrType = LLVM::LLVMPointerType::get(elementType, allocaAddrSpace);
+      Value numElements = rewriter.create<LLVM::ConstantOp>(
+          loc, int64Ty, rewriter.getI64IntegerAttr(type.getNumElements()));
+      Value allocated = rewriter.create<LLVM::AllocaOp>(
+          loc, ptrType, numElements, /*alignment=*/0);
+      auto descr = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), type, allocated);
+      signatureConversion.remapInput(
+          numProperArguments + numWorkgroupAttributions + en.index(), descr);
+    }
+  }
+
+  // Move the region to the new function, update the entry block signature.
+  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
+                              llvmFuncOp.end());
+  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
+                                         &signatureConversion)))
+    return failure();
+
+  rewriter.eraseOp(gpuFuncOp);
+  return success();
+}

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 95215d0f4a6e..b3a5d4078b88 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -11,145 +11,26 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 
-template <unsigned AllocaAddrSpace>
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
-  using ConvertOpToLLVMPattern<gpu::GPUFuncOp>::ConvertOpToLLVMPattern;
+  GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace,
+                    Identifier kernelAttributeName)
+      : ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
+        allocaAddrSpace(allocaAddrSpace),
+        kernelAttributeName(kernelAttributeName) {}
 
   LogicalResult
   matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    assert(operands.empty() && "func op is not expected to have operands");
-    Location loc = gpuFuncOp.getLoc();
-
-    SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
-    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
-    for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
-      Value attribution = en.value();
-
-      auto type = attribution.getType().dyn_cast<MemRefType>();
-      assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
-      uint64_t numElements = type.getNumElements();
-
-      auto elementType = typeConverter->convertType(type.getElementType())
-                             .template cast<Type>();
-      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
-      std::string name = std::string(
-          llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
-      auto globalOp = rewriter.create<LLVM::GlobalOp>(
-          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
-          LLVM::Linkage::Internal, name, /*value=*/Attribute(),
-          gpu::GPUDialect::getWorkgroupAddressSpace());
-      workgroupBuffers.push_back(globalOp);
-    }
-
-    // Rewrite the original GPU function to an LLVM function.
-    auto funcType = typeConverter->convertType(gpuFuncOp.getType())
-                        .template cast<LLVM::LLVMPointerType>()
-                        .getElementType();
-
-    // Remap proper input types.
-    TypeConverter::SignatureConversion signatureConversion(
-        gpuFuncOp.front().getNumArguments());
-    getTypeConverter()->convertFunctionSignature(
-        gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
-
-    // Create the new function operation. Only copy those attributes that are
-    // not specific to function modeling.
-    SmallVector<NamedAttribute, 4> attributes;
-    for (const auto &attr : gpuFuncOp.getAttrs()) {
-      if (attr.first == SymbolTable::getSymbolAttrName() ||
-          attr.first == impl::getTypeAttrName() ||
-          attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
-        continue;
-      attributes.push_back(attr);
-    }
-    auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
-        LLVM::Linkage::External, attributes);
+                  ConversionPatternRewriter &rewriter) const override;
 
-    {
-      // Insert operations that correspond to converted workgroup and private
-      // memory attributions to the body of the function. This must operate on
-      // the original function, before the body region is inlined in the new
-      // function to maintain the relation between block arguments and the
-      // parent operation that assigns their semantics.
-      OpBuilder::InsertionGuard guard(rewriter);
+private:
+  /// The address space to use for `alloca`s in private memory.
+  unsigned allocaAddrSpace;
 
-      // Rewrite workgroup memory attributions to addresses of global buffers.
-      rewriter.setInsertionPointToStart(&gpuFuncOp.front());
-      unsigned numProperArguments = gpuFuncOp.getNumArguments();
-      auto i32Type = IntegerType::get(rewriter.getContext(), 32);
-
-      Value zero = nullptr;
-      if (!workgroupBuffers.empty())
-        zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
-                                                 rewriter.getI32IntegerAttr(0));
-      for (auto en : llvm::enumerate(workgroupBuffers)) {
-        LLVM::GlobalOp global = en.value();
-        Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
-        auto elementType =
-            global.getType().cast<LLVM::LLVMArrayType>().getElementType();
-        Value memory = rewriter.create<LLVM::GEPOp>(
-            loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()),
-            address, ArrayRef<Value>{zero, zero});
-
-        // Build a memref descriptor pointing to the buffer to plug with the
-        // existing memref infrastructure. This may use more registers than
-        // otherwise necessary given that memref sizes are fixed, but we can try
-        // and canonicalize that away later.
-        Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
-        auto type = attribution.getType().cast<MemRefType>();
-        auto descr = MemRefDescriptor::fromStaticShape(
-            rewriter, loc, *getTypeConverter(), type, memory);
-        signatureConversion.remapInput(numProperArguments + en.index(), descr);
-      }
-
-      // Rewrite private memory attributions to alloca'ed buffers.
-      unsigned numWorkgroupAttributions =
-          gpuFuncOp.getNumWorkgroupAttributions();
-      auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
-      for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
-        Value attribution = en.value();
-        auto type = attribution.getType().cast<MemRefType>();
-        assert(type && type.hasStaticShape() &&
-               "unexpected type in attribution");
-
-        // Explicitly drop memory space when lowering private memory
-        // attributions since NVVM models it as `alloca`s in the default
-        // memory space and does not support `alloca`s with addrspace(5).
-        auto ptrType = LLVM::LLVMPointerType::get(
-            typeConverter->convertType(type.getElementType())
-                .template cast<Type>(),
-            AllocaAddrSpace);
-        Value numElements = rewriter.create<LLVM::ConstantOp>(
-            gpuFuncOp.getLoc(), int64Ty,
-            rewriter.getI64IntegerAttr(type.getNumElements()));
-        Value allocated = rewriter.create<LLVM::AllocaOp>(
-            gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
-        auto descr = MemRefDescriptor::fromStaticShape(
-            rewriter, loc, *getTypeConverter(), type, allocated);
-        signatureConversion.remapInput(
-            numProperArguments + numWorkgroupAttributions + en.index(), descr);
-      }
-    }
-
-    // Move the region to the new function, update the entry block signature.
-    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
-                                llvmFuncOp.end());
-    if (failed(rewriter.convertRegionTypes(
-            &llvmFuncOp.getBody(), *typeConverter, &signatureConversion)))
-      return failure();
-
-    rewriter.eraseOp(gpuFuncOp);
-    return success();
-  }
+  /// Dialect-specific unit attribute added to kernels alongside `gpu.kernel`.
+  Identifier kernelAttributeName;
 };
 
 struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {

diff  --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
index a0e8d6d21dd0..50f9e6a295e0 100644
--- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms
 
   LINK_LIBS PUBLIC
   MLIRGPU
+  MLIRGPUToGPURuntimeTransforms
   MLIRLLVMIR
   MLIRNVVMIR
   MLIRPass

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 8424937d7cf0..05b77dba8d11 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -167,11 +167,16 @@ void mlir::populateGpuToNVVMConversionPatterns(
                                           NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
               GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                           NVVM::GridDimYOp, NVVM::GridDimZOp>,
-              GPUShuffleOpLowering, GPUReturnOpLowering,
-              // Explicitly drop memory space when lowering private memory
-              // attributions since NVVM models it as `alloca`s in the default
-              // memory space and does not support `alloca`s with addrspace(5).
-              GPUFuncOpLowering<0>>(converter);
+              GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+
+  // Explicitly drop memory space when lowering private memory
+  // attributions since NVVM models it as `alloca`s in the default
+  // memory space and does not support `alloca`s with addrspace(5).
+  patterns.insert<GPUFuncOpLowering>(
+      converter, /*allocaAddrSpace=*/0,
+      Identifier::get(NVVM::NVVMDialect::getKernelFuncAttrName(),
+                      &converter.getContext()));
+
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
                                                 "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<math::AtanOp>>(converter, "__nv_atanf",

diff  --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 0871b27aadda..ad270645648e 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
 
   LINK_LIBS PUBLIC
   MLIRGPU
+  MLIRGPUToGPURuntimeTransforms
   MLIRLLVMIR
   MLIRROCDLIR
   MLIRPass

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 19953e5def53..c29cfea2274d 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -103,7 +103,11 @@ void mlir::populateGpuToROCDLConversionPatterns(
                                   ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
       GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
                                   ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
-      GPUFuncOpLowering<5>, GPUReturnOpLowering>(converter);
+      GPUReturnOpLowering>(converter);
+  patterns.insert<GPUFuncOpLowering>(
+      converter, /*allocaAddrSpace=*/5,
+      Identifier::get(ROCDL::ROCDLDialect::getKernelFuncAttrName(),
+                      &converter.getContext()));
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__ocml_fabs_f32",
                                                 "__ocml_fabs_f64");
   patterns.insert<OpToFuncCallLowering<math::AtanOp>>(

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
index 5d152e14c45b..668d9d95c150 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
@@ -77,7 +77,8 @@ mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
   // function as a kernel.
   for (auto func :
        ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
-    if (!gpu::GPUDialect::isKernel(func))
+    if (!func->getAttrOfType<UnitAttr>(
+            NVVM::NVVMDialect::getKernelFuncAttrName()))
       continue;
 
     auto *llvmFunc = llvmModule->getFunction(func.getName());

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
index c091c72c7702..c415787b2acb 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
@@ -87,7 +87,7 @@ mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
   for (auto func :
        ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
     if (!func->getAttrOfType<UnitAttr>(
-            gpu::GPUDialect::getKernelFuncAttrName()))
+            ROCDL::ROCDLDialect::getKernelFuncAttrName()))
       continue;
 
     auto *llvmFunc = llvmModule->getFunction(func.getName());

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 9c7ef468fbb7..1ade02e67a5b 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -423,3 +423,15 @@ gpu.module @test_module {
     std.return %result32, %result64 : f32, f64
   }
 }
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: @kernel_func
+  // CHECK: attributes
+  // CHECK: gpu.kernel
+  // CHECK: nvvm.kernel
+  gpu.func @kernel_func() kernel {
+    gpu.return
+  }
+}

diff  --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index bbbc2613cf6a..39ced31a7239 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -340,3 +340,15 @@ gpu.module @test_module {
     std.return %result32, %result64 : f32, f64
   }
 }
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: @kernel_func
+  // CHECK: attributes
+  // CHECK: gpu.kernel
+  // CHECK: rocdl.kernel
+  gpu.func @kernel_func() kernel {
+    gpu.return
+  }
+}

diff  --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir
index 08aaa07b12a2..3065188925f8 100644
--- a/mlir/test/Target/nvvmir.mlir
+++ b/mlir/test/Target/nvvmir.mlir
@@ -75,7 +75,7 @@ llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
 
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
-llvm.func @kernel_func() attributes {gpu.kernel} {
+llvm.func @kernel_func() attributes {nvvm.kernel} {
   llvm.return
 }
 

diff  --git a/mlir/test/Target/rocdl.mlir b/mlir/test/Target/rocdl.mlir
index 1f4b8b03c81b..27884c738539 100644
--- a/mlir/test/Target/rocdl.mlir
+++ b/mlir/test/Target/rocdl.mlir
@@ -29,7 +29,7 @@ llvm.func @rocdl_special_regs() -> i32 {
   llvm.return %1 : i32
 }
 
-llvm.func @kernel_func() attributes {gpu.kernel} {
+llvm.func @kernel_func() attributes {rocdl.kernel} {
   // CHECK-LABEL: amdgpu_kernel void @kernel_func
   llvm.return
 }


        


More information about the Mlir-commits mailing list