[llvm] [mlir] [mlir][nvvm] Add support for grid_constant attribute on LLVM function arguments (PR #78228)

Rishi Surendran via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 28 18:28:18 PST 2024


https://github.com/rishisurendran updated https://github.com/llvm/llvm-project/pull/78228

>From 28cdab8050cdd37eb1c8c7de1a7782efe66563ed Mon Sep 17 00:00:00 2001
From: rsurendran <rsurendran at nvidia.com>
Date: Mon, 15 Jan 2024 15:36:14 -0800
Subject: [PATCH 1/4] Add support for adding 'nvvm.grid_constant' attribute to
 LLVM function arguments

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 13 +++++
 .../Target/LLVMIR/LLVMTranslationInterface.h  | 26 +++++++++
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 28 +++++++++
 mlir/lib/Target/LLVMIR/AttrKindDetail.h       | 13 +++++
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  | 57 +++++++++++++++++++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 56 ++++++++++--------
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 26 +++++++++
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 17 ++++++
 9 files changed, 214 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..1fc5ee2c32bd49 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
     /// Get the name of the attribute used to annotate max number of
     /// registers that can be allocated per thread.
     static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+    /// Get the name of the attribute used to annotate kernel arguments that
+    /// are grid constants.
+    static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+    /// Verify an attribute from this dialect on the argument at 'argIndex' for
+    /// the region at 'regionIndex' on the given operation. Returns failure if
+    /// the verification failed, success otherwise. This hook may optionally be
+    /// invoked from any operation containing a region.
+    LogicalResult verifyRegionArgAttribute(Operation *,
+                                           unsigned regionIndex,
+                                           unsigned argIndex,
+                                           NamedAttribute) override;
   }];
 
   let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 19991a6f89d80f..55358ebc6e86ef 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
 #define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
 namespace mlir {
 namespace LLVM {
 class ModuleTranslation;
+class LLVMFuncOp;
 } // namespace LLVM
 
 /// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
                  LLVM::ModuleTranslation &moduleTranslation) const {
     return success();
   }
+
+  /// Hook for derived dialect interface to translate or act on a derived
+  /// dialect attribute that appears on a function parameter. This gets called
+  /// after the function operation has been translated.
+  virtual LogicalResult
+  convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+                       NamedAttribute attr,
+                       LLVM::ModuleTranslation &moduleTranslation) const {
+    return success();
+  }
 };
 
 /// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
     }
     return success();
   }
+
+  /// Acts on the given function operation using the interface implemented by
+  /// the dialect of one of the function parameter attributes.
+  virtual LogicalResult
+  convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+                       NamedAttribute attribute,
+                       LLVM::ModuleTranslation &moduleTranslation) const {
+    if (const LLVMTranslationDialectInterface *iface =
+            getInterfaceFor(attribute.getNameDialect())) {
+      return iface->convertParameterAttr(function, argIdx, attribute,
+                                         moduleTranslation);
+    }
+    return success();
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d6b03aca28d24d..61cf30a123b0c7 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -327,7 +327,8 @@ class ModuleTranslation {
                            ArrayRef<llvm::Instruction *> instructions);
 
   /// Translates parameter attributes and adds them to the returned AttrBuilder.
-  llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+  FailureOr<llvm::AttrBuilder>
+  convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 
   /// Original and translated module.
   Operation *mlirModule;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..dc7816318131e4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
   return success();
 }
 
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+                                                    unsigned regionIndex,
+                                                    unsigned argIndex,
+                                                    NamedAttribute argAttr) {
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+
+  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+  auto attrName = argAttr.getName();
+  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+    if (!isKernel)
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute must be present only on kernel arguments.";
+    if (!llvm::isa<UnitAttr>(argAttr.getValue()))
+      return op->emitError()
+             << "'" << attrName << "' must be a unit attribute.";
+    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute requires the argument to also have attribute '"
+             << LLVM::LLVMDialect::getByValAttrName() << "'.";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 7f81777886f56e..55a364856bd6f9 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
   return kindNamePairs;
 }
 
+static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+getAttrNameToKindMapping() {
+  static auto attrNameToKindMapping = []() {
+    static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+        nameKindMap;
+    for (auto kindNamePair : getAttrKindToNameMapping()) {
+      nameKindMap.insert({kindNamePair.second, kindNamePair.first});
+    }
+    return nameKindMap;
+  }();
+  return attrNameToKindMapping;
+}
+
 } // namespace detail
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 45eb8402a7344f..5e1712527d7015 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
     }
     return success();
   }
+
+  LogicalResult
+  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
+                       LLVM::ModuleTranslation &moduleTranslation) const final {
+
+    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+    llvm::Function *llvmFunc =
+        moduleTranslation.lookupFunction(funcOp.getName());
+    auto nvvmAnnotations =
+        moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
+
+    if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+      llvm::MDNode *gridConstantMetaData = nullptr;
+
+      // Check if a 'grid_constant' metadata node exists for the given function
+      for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
+        auto opnd = nvvmAnnotations->getOperand(i);
+        if (opnd->getNumOperands() == 3 &&
+            opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
+            opnd->getOperand(1) ==
+                llvm::MDString::get(llvmContext, "grid_constant")) {
+          gridConstantMetaData = opnd;
+          break;
+        }
+      }
+
+      // 'grid_constant' is a function-level metadata node with a list of
+      // integers, where each integer n denotes that the nth parameter has the
+      // grid_constant annotation (numbering from 1). This requires aggregating
+      // the indices of the individual parameters that have this attribute.
+      llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
+      if (gridConstantMetaData == nullptr) {
+        // Create a new 'grid_constant' metadata node
+        SmallVector<llvm::Metadata *> gridConstMetadata = {
+            llvm::ValueAsMetadata::getConstant(
+                llvm::ConstantInt::get(i32, argIdx + 1))};
+        llvm::Metadata *llvmMetadata[] = {
+            llvm::ValueAsMetadata::get(llvmFunc),
+            llvm::MDString::get(llvmContext, "grid_constant"),
+            llvm::MDNode::get(llvmContext, gridConstMetadata)};
+        llvm::MDNode *llvmMetadataNode =
+            llvm::MDNode::get(llvmContext, llvmMetadata);
+        nvvmAnnotations->addOperand(llvmMetadataNode);
+      } else {
+        // Append argIdx + 1 to the 'grid_constant' argument list
+        if (auto argList =
+                dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
+          auto clonedArgList = argList->clone();
+          clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
+              llvm::ConstantInt::get(i32, argIdx + 1))));
+          gridConstantMetaData->replaceOperandWith(
+              2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
+        }
+      }
+    }
+    return success();
+  }
 };
 } // namespace
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2763a0fdd62aba..574dbfa177b9bb 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   llvmFunc->setMemoryEffects(newMemEffects);
 }
 
-llvm::AttrBuilder
-ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
+                                         DictionaryAttr paramAttrs) {
   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
-
-  for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
-    Attribute attr = paramAttrs.get(mlirName);
-    // Skip attributes that are not present.
-    if (!attr)
-      continue;
-
-    // NOTE: C++17 does not support capturing structured bindings.
-    llvm::Attribute::AttrKind llvmKindCap = llvmKind;
-
-    llvm::TypeSwitch<Attribute>(attr)
-        .Case<TypeAttr>([&](auto typeAttr) {
-          attrBuilder.addTypeAttr(llvmKindCap,
-                                  convertType(typeAttr.getValue()));
-        })
-        .Case<IntegerAttr>([&](auto intAttr) {
-          attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
-        })
-        .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
+  auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+  for (auto namedAttr : paramAttrs) {
+    auto it = attrNameToKindMapping.find(namedAttr.getName());
+    if (it != attrNameToKindMapping.end()) {
+      llvm::Attribute::AttrKind llvmKind = it->second;
+
+      llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+          .Case<TypeAttr>([&](auto typeAttr) {
+            attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
+          })
+          .Case<IntegerAttr>([&](auto intAttr) {
+            attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+          })
+          .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
+    } else if (namedAttr.getNameDialect()) {
+      if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
+        return failure();
+    }
   }
 
   return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert result attributes.
     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
       DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
-      llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
+      FailureOr<llvm::AttrBuilder> attrBuilder =
+          convertParameterAttrs(function, -1, resultAttrs);
+      if (failed(attrBuilder))
+        return failure();
+      llvmFunc->addRetAttrs(*attrBuilder);
     }
 
     // Convert argument attributes.
     for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
       if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
-        llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
-        llvmArg.addAttrs(attrBuilder);
+        FailureOr<llvm::AttrBuilder> attrBuilder =
+            convertParameterAttrs(function, argIdx, argAttrs);
+        if (failed(attrBuilder))
+          return failure();
+        llvmArg.addAttrs(*attrBuilder);
       }
     }
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index ce483ddab22a0e..0369f45ca6a015 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
 
 gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
 }
+
+// CHECK-LABEL : nvvm.grid_constant
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f6..6dc47d08fc5c81 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
   llvm.return
 }
 
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1, i32 3}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}

>From bccf82b181e3d448811f1e5f58b5110d0dd193a3 Mon Sep 17 00:00:00 2001
From: rsurendran <rsurendran at nvidia.com>
Date: Thu, 18 Jan 2024 20:46:06 -0800
Subject: [PATCH 2/4] Address review comments

---
 llvm/include/llvm/IR/Metadata.h                | 10 +++++-----
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td    |  4 ++--
 .../Target/LLVMIR/LLVMTranslationInterface.h   |  2 +-
 .../mlir/Target/LLVMIR/ModuleTranslation.h     |  1 +
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp     | 18 ++++++++++--------
 mlir/lib/Target/LLVMIR/AttrKindDetail.h        |  5 +++--
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp   |  7 +++----
 7 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/llvm/include/llvm/IR/Metadata.h b/llvm/include/llvm/IR/Metadata.h
index 4498423c4c460d..b38cd6a2fc458e 100644
--- a/llvm/include/llvm/IR/Metadata.h
+++ b/llvm/include/llvm/IR/Metadata.h
@@ -1701,7 +1701,7 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
 
   explicit NamedMDNode(const Twine &N);
 
-  template <class T1, class T2> class op_iterator_impl {
+  template <class T1> class op_iterator_impl {
     friend class NamedMDNode;
 
     const NamedMDNode *Node = nullptr;
@@ -1711,10 +1711,10 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
 
   public:
     using iterator_category = std::bidirectional_iterator_tag;
-    using value_type = T2;
+    using value_type = T1;
     using difference_type = std::ptrdiff_t;
     using pointer = value_type *;
-    using reference = value_type &;
+    using reference = value_type;
 
     op_iterator_impl() = default;
 
@@ -1775,12 +1775,12 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
   // ---------------------------------------------------------------------------
   // Operand Iterator interface...
   //
-  using op_iterator = op_iterator_impl<MDNode *, MDNode>;
+  using op_iterator = op_iterator_impl<MDNode *>;
 
   op_iterator op_begin() { return op_iterator(this, 0); }
   op_iterator op_end()   { return op_iterator(this, getNumOperands()); }
 
-  using const_op_iterator = op_iterator_impl<const MDNode *, MDNode>;
+  using const_op_iterator = op_iterator_impl<const MDNode *>;
 
   const_op_iterator op_begin() const { return const_op_iterator(this, 0); }
   const_op_iterator op_end()   const { return const_op_iterator(this, getNumOperands()); }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1fc5ee2c32bd49..159411a450308f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -68,10 +68,10 @@ def NVVM_Dialect : Dialect {
     /// the region at 'regionIndex' on the given operation. Returns failure if
     /// the verification failed, success otherwise. This hook may optionally be
     /// invoked from any operation containing a region.
-    LogicalResult verifyRegionArgAttribute(Operation *,
+    LogicalResult verifyRegionArgAttribute(Operation *op,
                                            unsigned regionIndex,
                                            unsigned argIndex,
-                                           NamedAttribute) override;
+                                           NamedAttribute argAttr) override;
   }];
 
   let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 55358ebc6e86ef..8bc0cad0d701bc 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -114,7 +114,7 @@ class LLVMTranslationInterface
       return iface->convertParameterAttr(function, argIdx, attribute,
                                          moduleTranslation);
     }
-    return success();
+    return failure();
   }
 };
 
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 61cf30a123b0c7..fb4392eb223c7f 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -327,6 +327,7 @@ class ModuleTranslation {
                            ArrayRef<llvm::Instruction *> instructions);
 
   /// Translates parameter attributes and adds them to the returned AttrBuilder.
+  /// Returns failure if any of the translations failed.
   FailureOr<llvm::AttrBuilder>
   convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dc7816318131e4..024e3bdd0c260d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1086,20 +1086,22 @@ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
     return success();
 
   bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
-  auto attrName = argAttr.getName();
+  StringAttr attrName = argAttr.getName();
   if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
-    if (!isKernel)
+    if (!isKernel) {
       return op->emitError()
              << "'" << attrName
-             << "' attribute must be present only on kernel arguments.";
-    if (!llvm::isa<UnitAttr>(argAttr.getValue()))
-      return op->emitError()
-             << "'" << attrName << "' must be a unit attribute.";
-    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+             << "' attribute must be present only on kernel arguments";
+    }
+    if (!isa<UnitAttr>(argAttr.getValue())) {
+      return op->emitError() << "'" << attrName << "' must be a unit attribute";
+    }
+    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
       return op->emitError()
              << "'" << attrName
              << "' attribute requires the argument to also have attribute '"
-             << LLVM::LLVMDialect::getByValAttrName() << "'.";
+             << LLVM::LLVMDialect::getByValAttrName() << "'";
+    }
   }
 
   return success();
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 55a364856bd6f9..b01858ea814380 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,11 +59,12 @@ getAttrKindToNameMapping() {
   return kindNamePairs;
 }
 
+/// Returns a dense map from LLVM attribute names to their kinds in the LLVM IR
+/// dialect.
 static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
 getAttrNameToKindMapping() {
   static auto attrNameToKindMapping = []() {
-    static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
-        nameKindMap;
+    llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind> nameKindMap;
     for (auto kindNamePair : getAttrKindToNameMapping()) {
       nameKindMap.insert({kindNamePair.second, kindNamePair.first});
     }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 5e1712527d7015..ea9fe2635461f2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -209,15 +209,14 @@ class NVVMDialectLLVMIRTranslationInterface
     llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
     llvm::Function *llvmFunc =
         moduleTranslation.lookupFunction(funcOp.getName());
-    auto nvvmAnnotations =
+    llvm::NamedMDNode *nvvmAnnotations =
         moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
 
     if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
       llvm::MDNode *gridConstantMetaData = nullptr;
 
       // Check if a 'grid_constant' metadata node exists for the given function
-      for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
-        auto opnd = nvvmAnnotations->getOperand(i);
+      for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
         if (opnd->getNumOperands() == 3 &&
             opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
             opnd->getOperand(1) ==
@@ -248,7 +247,7 @@ class NVVMDialectLLVMIRTranslationInterface
         // Append argIdx + 1 to the 'grid_constant' argument list
         if (auto argList =
                 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
-          auto clonedArgList = argList->clone();
+          llvm::TempMDTuple clonedArgList = argList->clone();
           clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
               llvm::ConstantInt::get(i32, argIdx + 1))));
           gridConstantMetaData->replaceOperandWith(

>From af6966e3c6df5030aa2b01367ba8201b2287eda1 Mon Sep 17 00:00:00 2001
From: rsurendran <rsurendran at nvidia.com>
Date: Thu, 25 Jan 2024 13:06:22 -0800
Subject: [PATCH 3/4] Remove braces from single statement if

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 024e3bdd0c260d..8920a2b23f9fa0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1093,9 +1093,8 @@ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
              << "'" << attrName
              << "' attribute must be present only on kernel arguments";
     }
-    if (!isa<UnitAttr>(argAttr.getValue())) {
+    if (!isa<UnitAttr>(argAttr.getValue()))
       return op->emitError() << "'" << attrName << "' must be a unit attribute";
-    }
     if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
       return op->emitError()
              << "'" << attrName

>From 132fa6fbe3ff529b117c36dda43d7b3af87a5874 Mon Sep 17 00:00:00 2001
From: rsurendran <rsurendran at nvidia.com>
Date: Sun, 28 Jan 2024 18:26:58 -0800
Subject: [PATCH 4/4] Emit warning instead of returning failure for unhandled
 dialect attributes

---
 mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 8bc0cad0d701bc..4a8ee06f5dcc70 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -114,7 +114,9 @@ class LLVMTranslationInterface
       return iface->convertParameterAttr(function, argIdx, attribute,
                                          moduleTranslation);
     }
-    return failure();
+    function.emitWarning("Unhandled parameter attribute '" +
+                         attribute.getName().str() + "'");
+    return success();
   }
 };
 



More information about the llvm-commits mailing list