[Mlir-commits] [mlir] 346c4a8 - [mlir][NVVM] Disallow results on kernel functions (#96399)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 23 00:50:05 PDT 2024


Author: Matthias Springer
Date: 2024-06-23T09:50:01+02:00
New Revision: 346c4a88afedcef3da40f68c83f0a5b3e0ac61ea

URL: https://github.com/llvm/llvm-project/commit/346c4a88afedcef3da40f68c83f0a5b3e0ac61ea
DIFF: https://github.com/llvm/llvm-project/commit/346c4a88afedcef3da40f68c83f0a5b3e0ac61ea.diff

LOG: [mlir][NVVM] Disallow results on kernel functions (#96399)

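Functions that have the `nvvm.kernel` attribute must have zero results. The NVVM dialect's attribute verifier now emits the error "kernel function cannot have results" when such a function declares any result types.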

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 94197e473ce01..3d6a911f36541 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -214,7 +214,8 @@ void MmaOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
 
   // Print the types of the operands and result.
-  p << " : " << "(";
+  p << " : "
+    << "(";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
@@ -955,7 +956,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << "},";
   // Need to map read/write registers correctly.
   regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+  ss << " $" << (regCnt) << ","
+     << " $" << (regCnt + 1) << ","
+     << " p";
   if (getTypeD() != WGMMATypes::s32) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
@@ -1053,10 +1056,14 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
   StringAttr attrName = attr.getName();
   // Kernel function attribute should be attached to functions.
   if (attrName == NVVMDialect::getKernelFuncAttrName()) {
-    if (!isa<LLVM::LLVMFuncOp>(op)) {
+    auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+    if (!funcOp) {
       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                              << "' attribute attached to unexpected op";
     }
+    if (!funcOp.getResultTypes().empty()) {
+      return op->emitError() << "kernel function cannot have results";
+    }
   }
   // If maxntid and reqntid exist, it must be an array with max 3 dim
   if (attrName == NVVMDialect::getMaxntidAttrName() ||

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a8ae4d97888c9..26ba80cba6ed5 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -574,3 +574,10 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
 llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
   llvm.return
 }
+
+// -----
+
+// expected-error @below{{kernel function cannot have results}}
+llvm.func @kernel_with_result(%i: i32) -> i32 attributes {nvvm.kernel}  {
+  llvm.return %i : i32
+}


        


More information about the Mlir-commits mailing list