[Mlir-commits] [mlir] [MLIR][GPU] Support grid constant, byval, byref on gpu.func (PR #172037)

Asher Mancinelli llvmlistbot at llvm.org
Fri Dec 19 08:18:54 PST 2025


https://github.com/ashermancinelli updated https://github.com/llvm/llvm-project/pull/172037

>From d3c583de9b08f0d161aa4bcdf58b4ccb8092bac0 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Fri, 12 Dec 2025 08:09:05 -0800
Subject: [PATCH 1/3] Support grid constant, byval, byref on gpu.func

---
 mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp |  4 ++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp       |  3 ++-
 mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir  | 16 ++++++++++++++++
 3 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index eb662a1b056de..d0a03098044c7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -369,7 +370,10 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     }
 
     if (lowersToPointer) {
+      copyPointerAttribute(mlir::NVVM::NVVMDialect::getGridConstantAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
+      copyPointerAttribute(LLVM::LLVMDialect::getByValAttrName());
+      copyPointerAttribute(LLVM::LLVMDialect::getByRefAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 331d7a244310f..6ec0f54674378 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5448,7 +5448,8 @@ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
   if (!funcOp)
     return success();
 
-  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+  const bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName()) ||
+                        op->hasAttr(gpu::GPUDialect::getKernelFuncAttrName());
   StringAttr attrName = argAttr.getName();
   if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
     if (!isKernel) {
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index f1cc1eb983267..28613ee33cac6 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1149,3 +1149,19 @@ gpu.module @test_module_56 {
     func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
   }
 }
+
+// Check that nvvm.grid_constant is a valid argument attribute on gpu.kernel.
+gpu.module @test_module_57 {
+  // CHECK:       gpu.module @test_module_57
+  // CHECK-LABEL:   llvm.func @test_kernel(
+  // CHECK-SAME:      %[[VAL_0:.*]]: !llvm.ptr {llvm.byval = i64, nvvm.grid_constant}
+  // CHECK-SAME:      %[[VAL_1:.*]]: !llvm.ptr {llvm.byref = i64}
+  // CHECK:           llvm.return
+  // CHECK:         }
+  // CHECK:       }
+  gpu.func @test_kernel(
+      %arg0: !llvm.ptr {nvvm.grid_constant, llvm.byval = i64},
+      %arg1: !llvm.ptr {llvm.byref = i64}) kernel {
+    gpu.return
+  }
+}

>From e0b8653aa74bf8bba65156e6fa27649d92fd44d4 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Fri, 12 Dec 2025 08:24:38 -0800
Subject: [PATCH 2/3] Prefer early-exit

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 33 ++++++++++---------
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |  2 +-
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6ec0f54674378..7f97ca843ce5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5448,24 +5448,25 @@ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
   if (!funcOp)
     return success();
 
+  StringAttr attrName = argAttr.getName();
+  if (attrName != NVVM::NVVMDialect::getGridConstantAttrName())
+    return success();
+
   const bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName()) ||
                         op->hasAttr(gpu::GPUDialect::getKernelFuncAttrName());
-  StringAttr attrName = argAttr.getName();
-  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
-    if (!isKernel) {
-      return op->emitError()
-             << "'" << attrName
-             << "' attribute must be present only on kernel arguments";
-    }
-    if (!isa<UnitAttr>(argAttr.getValue()))
-      return op->emitError() << "'" << attrName << "' must be a unit attribute";
-    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
-      return op->emitError()
-             << "'" << attrName
-             << "' attribute requires the argument to also have attribute '"
-             << LLVM::LLVMDialect::getByValAttrName() << "'";
-    }
-  }
+  if (!isKernel)
+    return op->emitError()
+           << "'" << attrName
+           << "' attribute must be present only on kernel arguments";
+
+  if (!isa<UnitAttr>(argAttr.getValue()))
+    return op->emitError() << "'" << attrName << "' must be a unit attribute";
+
+  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+    return op->emitError()
+           << "'" << attrName
+           << "' attribute requires the argument to also have attribute '"
+           << LLVM::LLVMDialect::getByValAttrName() << "'";
 
   return success();
 }
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 28613ee33cac6..68e6025651a8a 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1150,7 +1150,7 @@ gpu.module @test_module_56 {
   }
 }
 
-// Check that nvvm.grid_constant is a valid argument attribute on gpu.kernel.
+// Check that nvvm.grid_constant, llvm.byref, and llvm.byval are valid argument attributes on gpu.kernel.
 gpu.module @test_module_57 {
   // CHECK:       gpu.module @test_module_57
   // CHECK-LABEL:   llvm.func @test_kernel(

>From 7e3612380ba9d654279ed68fff0c72881573b252 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Fri, 19 Dec 2025 08:18:25 -0800
Subject: [PATCH 3/3] Propagate all arg attrs when the type is unchanged

---
 mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index d0a03098044c7..fc0a1e802a1c7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -11,7 +11,6 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -360,6 +359,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     if (argAttr.empty())
       continue;
 
+    const bool isArgTypeUnchanged =
+        remapping->size == 1 &&
+        llvmFuncOp.getArgument(remapping->inputNo).getType() == argTy;
+    if (isArgTypeUnchanged) {
+      llvmFuncOp.setArgAttrs(remapping->inputNo, argAttr);
+      continue;
+    }
+
     copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
     copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
     copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
@@ -370,10 +377,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     }
 
     if (lowersToPointer) {
-      copyPointerAttribute(mlir::NVVM::NVVMDialect::getGridConstantAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getByValAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getByRefAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
       copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());



More information about the Mlir-commits mailing list