[Mlir-commits] [llvm] [mlir] [mlir][nvvm]Add support for grid_constant attribute on LLVM function arguments (PR #78228)
Rishi Surendran
llvmlistbot at llvm.org
Thu Jan 18 20:47:38 PST 2024
https://github.com/rishisurendran updated https://github.com/llvm/llvm-project/pull/78228
>From 28cdab8050cdd37eb1c8c7de1a7782efe66563ed Mon Sep 17 00:00:00 2001
From: rsurendran <rsurendran at nvidia.com>
Date: Mon, 15 Jan 2024 15:36:14 -0800
Subject: [PATCH 1/2] Add support for adding 'nvvm.grid_constant' attribute to
LLVM function arguments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 13 +++++
.../Target/LLVMIR/LLVMTranslationInterface.h | 26 +++++++++
.../mlir/Target/LLVMIR/ModuleTranslation.h | 3 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 28 +++++++++
mlir/lib/Target/LLVMIR/AttrKindDetail.h | 13 +++++
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 57 +++++++++++++++++++
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 56 ++++++++++--------
mlir/test/Dialect/LLVMIR/nvvm.mlir | 26 +++++++++
mlir/test/Target/LLVMIR/nvvmir.mlir | 17 ++++++
9 files changed, 214 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..1fc5ee2c32bd492 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
/// Get the name of the attribute used to annotate max number of
/// registers that can be allocated per thread.
static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+ /// Get the name of the attribute used to annotate kernel arguments that
+ /// are grid constants.
+ static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+ /// Verify an attribute from this dialect on the argument at 'argIndex' for
+ /// the region at 'regionIndex' on the given operation. Returns failure if
+ /// the verification failed, success otherwise. This hook may optionally be
+ /// invoked from any operation containing a region.
+ LogicalResult verifyRegionArgAttribute(Operation *,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute) override;
}];
let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 19991a6f89d80fa..55358ebc6e86efc 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -13,6 +13,7 @@
#ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
#define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
namespace mlir {
namespace LLVM {
class ModuleTranslation;
+class LLVMFuncOp;
} // namespace LLVM
/// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}
+
+ /// Hook for derived dialect interface to translate or act on a derived
+ /// dialect attribute attached to argument `argIdx` of `function`, or to
+ /// its results when `argIdx` is -1. Called once the llvm::Function exists.
+ virtual LogicalResult
+ convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+ NamedAttribute attr,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ return success();
+ }
};
/// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
}
return success();
}
+
+ /// Dispatches an argument or result attribute of `function` to the
+ /// translation interface of the dialect that owns the attribute name.
+ virtual LogicalResult
+ convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const {
+ if (const LLVMTranslationDialectInterface *iface =
+ getInterfaceFor(attribute.getNameDialect())) {
+ return iface->convertParameterAttr(function, argIdx, attribute,
+ moduleTranslation);
+ }
+ return success();
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d6b03aca28d24d5..61cf30a123b0c7d 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -327,7 +327,8 @@ class ModuleTranslation {
ArrayRef<llvm::Instruction *> instructions);
/// Translates parameter attributes and adds them to the returned AttrBuilder.
- llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+ FailureOr<llvm::AttrBuilder>
+ convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
/// Original and translated module.
Operation *mlirModule;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc02..dc7816318131e41 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
return success();
}
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+ unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute argAttr) {
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+
+ bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+ auto attrName = argAttr.getName();
+ if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ if (!isKernel)
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute must be present only on kernel arguments.";
+ if (!llvm::isa<UnitAttr>(argAttr.getValue()))
+ return op->emitError()
+ << "'" << attrName << "' must be a unit attribute.";
+ if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute requires the argument to also have attribute '"
+ << LLVM::LLVMDialect::getByValAttrName() << "'.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 7f81777886f56eb..55a364856bd6f99 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
return kindNamePairs;
}
+static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+getAttrNameToKindMapping() {
+ static auto attrNameToKindMapping = []() {
+ static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+ nameKindMap;
+ for (auto kindNamePair : getAttrKindToNameMapping()) {
+ nameKindMap.insert({kindNamePair.second, kindNamePair.first});
+ }
+ return nameKindMap;
+ }();
+ return attrNameToKindMapping;
+}
+
} // namespace detail
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 45eb8402a7344f4..5e1712527d70151 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
}
return success();
}
+
+ LogicalResult
+ convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+
+ llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+ llvm::Function *llvmFunc =
+ moduleTranslation.lookupFunction(funcOp.getName());
+ auto nvvmAnnotations =
+ moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
+
+ if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ llvm::MDNode *gridConstantMetaData = nullptr;
+
+ // Check if a 'grid_constant' metadata node exists for the given function
+ for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
+ auto opnd = nvvmAnnotations->getOperand(i);
+ if (opnd->getNumOperands() == 3 &&
+ opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
+ opnd->getOperand(1) ==
+ llvm::MDString::get(llvmContext, "grid_constant")) {
+ gridConstantMetaData = opnd;
+ break;
+ }
+ }
+
+ // 'grid_constant' is a function-level metadata node with a list of
+ // integers, where each integer n denotes that the nth parameter has the
+ // grid_constant annotation (numbering from 1). The hook runs once per
+ // parameter, so the indices are accumulated into a single node.
+ llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
+ if (gridConstantMetaData == nullptr) {
+ // No node yet for this function: create one holding this index only.
+ SmallVector<llvm::Metadata *> gridConstMetadata = {
+ llvm::ValueAsMetadata::getConstant(
+ llvm::ConstantInt::get(i32, argIdx + 1))};
+ llvm::Metadata *llvmMetadata[] = {
+ llvm::ValueAsMetadata::get(llvmFunc),
+ llvm::MDString::get(llvmContext, "grid_constant"),
+ llvm::MDNode::get(llvmContext, gridConstMetadata)};
+ llvm::MDNode *llvmMetadataNode =
+ llvm::MDNode::get(llvmContext, llvmMetadata);
+ nvvmAnnotations->addOperand(llvmMetadataNode);
+ } else {
+ // Append argIdx + 1 to the 'grid_constant' argument list
+ if (auto argList =
+ dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
+ auto clonedArgList = argList->clone();
+ clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
+ llvm::ConstantInt::get(i32, argIdx + 1))));
+ gridConstantMetaData->replaceOperandWith(
+ 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
+ }
+ }
+ }
+ return success();
+ }
};
} // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2763a0fdd62aba1..574dbfa177b9bb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->setMemoryEffects(newMemEffects);
}
-llvm::AttrBuilder
-ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
+ DictionaryAttr paramAttrs) {
llvm::AttrBuilder attrBuilder(llvmModule->getContext());
-
- for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- Attribute attr = paramAttrs.get(mlirName);
- // Skip attributes that are not present.
- if (!attr)
- continue;
-
- // NOTE: C++17 does not support capturing structured bindings.
- llvm::Attribute::AttrKind llvmKindCap = llvmKind;
-
- llvm::TypeSwitch<Attribute>(attr)
- .Case<TypeAttr>([&](auto typeAttr) {
- attrBuilder.addTypeAttr(llvmKindCap,
- convertType(typeAttr.getValue()));
- })
- .Case<IntegerAttr>([&](auto intAttr) {
- attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
- })
- .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
+ const auto &attrNameToKindMapping = getAttrNameToKindMapping();
+
+ for (auto namedAttr : paramAttrs) {
+ auto it = attrNameToKindMapping.find(namedAttr.getName());
+ if (it != attrNameToKindMapping.end()) {
+ llvm::Attribute::AttrKind llvmKind = it->second;
+
+ llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+ .Case<TypeAttr>([&](auto typeAttr) {
+ attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
+ })
+ .Case<IntegerAttr>([&](auto intAttr) {
+ attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+ })
+ .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
+ } else if (namedAttr.getNameDialect()) {
+ if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
+ return failure();
+ }
}
return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert result attributes.
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
- llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(function, /*argIdx=*/-1, resultAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ llvmFunc->addRetAttrs(*attrBuilder);
}
// Convert argument attributes.
for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
- llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
- llvmArg.addAttrs(attrBuilder);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(function, argIdx, argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ llvmArg.addAttrs(*attrBuilder);
}
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index ce483ddab22a0ee..0369f45ca6a0156 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
}
+
+// CHECK: llvm.func @kernel_func({{.*}}nvvm.grid_constant
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..6dc47d08fc5c812 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
llvm.return
}
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !{{[0-9]+}} = !{ptr @kernel_func, !"grid_constant", ![[GCARGS:[0-9]+]]}
+// CHECK: ![[GCARGS]] = !{i32 1}
+// CHECK: !{{[0-9]+}} = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
+
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !{{[0-9]+}} = !{ptr @kernel_func, !"grid_constant", ![[GCARGS:[0-9]+]]}
+// CHECK: ![[GCARGS]] = !{i32 1, i32 3}
+// CHECK: !{{[0-9]+}} = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+ llvm.return
+}
>From bccf82b181e3d448811f1e5f58b5110d0dd193a3 Mon Sep 17 00:00:00 2001
From: rsurendran <rsurendran at nvidia.com>
Date: Thu, 18 Jan 2024 20:46:06 -0800
Subject: [PATCH 2/2] Address review comments
---
llvm/include/llvm/IR/Metadata.h | 10 +++++-----
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 4 ++--
.../Target/LLVMIR/LLVMTranslationInterface.h | 2 +-
.../mlir/Target/LLVMIR/ModuleTranslation.h | 1 +
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 18 ++++++++++--------
mlir/lib/Target/LLVMIR/AttrKindDetail.h | 5 +++--
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 7 +++----
7 files changed, 25 insertions(+), 22 deletions(-)
diff --git a/llvm/include/llvm/IR/Metadata.h b/llvm/include/llvm/IR/Metadata.h
index 4498423c4c460d9..b38cd6a2fc458e1 100644
--- a/llvm/include/llvm/IR/Metadata.h
+++ b/llvm/include/llvm/IR/Metadata.h
@@ -1701,7 +1701,7 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
explicit NamedMDNode(const Twine &N);
- template <class T1, class T2> class op_iterator_impl {
+ template <class T1> class op_iterator_impl {
friend class NamedMDNode;
const NamedMDNode *Node = nullptr;
@@ -1711,10 +1711,10 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
public:
using iterator_category = std::bidirectional_iterator_tag;
- using value_type = T2;
+ using value_type = T1;
using difference_type = std::ptrdiff_t;
using pointer = value_type *;
- using reference = value_type &;
+ using reference = value_type;
op_iterator_impl() = default;
@@ -1775,12 +1775,12 @@ class NamedMDNode : public ilist_node<NamedMDNode> {
// ---------------------------------------------------------------------------
// Operand Iterator interface...
//
- using op_iterator = op_iterator_impl<MDNode *, MDNode>;
+ using op_iterator = op_iterator_impl<MDNode *>;
op_iterator op_begin() { return op_iterator(this, 0); }
op_iterator op_end() { return op_iterator(this, getNumOperands()); }
- using const_op_iterator = op_iterator_impl<const MDNode *, MDNode>;
+ using const_op_iterator = op_iterator_impl<const MDNode *>;
const_op_iterator op_begin() const { return const_op_iterator(this, 0); }
const_op_iterator op_end() const { return const_op_iterator(this, getNumOperands()); }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1fc5ee2c32bd492..159411a450308f3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -68,10 +68,10 @@ def NVVM_Dialect : Dialect {
/// the region at 'regionIndex' on the given operation. Returns failure if
/// the verification failed, success otherwise. This hook may optionally be
/// invoked from any operation containing a region.
- LogicalResult verifyRegionArgAttribute(Operation *,
+ LogicalResult verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
- NamedAttribute) override;
+ NamedAttribute argAttr) override;
}];
let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 55358ebc6e86efc..8bc0cad0d701bcf 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -114,7 +114,10 @@ class LLVMTranslationInterface
return iface->convertParameterAttr(function, argIdx, attribute,
moduleTranslation);
}
- return success();
+ // No registered dialect interface can handle this attribute. Report it
+ // so the translation does not fail without any diagnostic.
+ return function.emitError() << "unhandled parameter attribute '"
+ << attribute.getName().getValue() << "'";
}
};
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 61cf30a123b0c7d..fb4392eb223c7f6 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -327,6 +327,7 @@ class ModuleTranslation {
ArrayRef<llvm::Instruction *> instructions);
/// Translates parameter attributes and adds them to the returned AttrBuilder.
+ /// Returns failure if any translation failed. argIdx is -1 for results.
FailureOr<llvm::AttrBuilder>
convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dc7816318131e41..024e3bdd0c260d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1086,20 +1086,22 @@ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
return success();
bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
- auto attrName = argAttr.getName();
+ StringAttr attrName = argAttr.getName();
if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
- if (!isKernel)
+ if (!isKernel) {
return op->emitError()
<< "'" << attrName
- << "' attribute must be present only on kernel arguments.";
- if (!llvm::isa<UnitAttr>(argAttr.getValue()))
- return op->emitError()
- << "'" << attrName << "' must be a unit attribute.";
- if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+ << "' attribute must be present only on kernel arguments";
+ }
+ if (!isa<UnitAttr>(argAttr.getValue())) {
+ return op->emitError() << "'" << attrName << "' must be a unit attribute";
+ }
+ if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
return op->emitError()
<< "'" << attrName
<< "' attribute requires the argument to also have attribute '"
- << LLVM::LLVMDialect::getByValAttrName() << "'.";
+ << LLVM::LLVMDialect::getByValAttrName() << "'";
+ }
}
return success();
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 55a364856bd6f99..b01858ea814380d 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,11 +59,12 @@ getAttrKindToNameMapping() {
return kindNamePairs;
}
-static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+/// Returns a map from LLVM dialect attribute names to LLVM IR attribute
+/// kinds. The map is built once and returned by reference to avoid copies.
+static const llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind> &
getAttrNameToKindMapping() {
static auto attrNameToKindMapping = []() {
- static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
- nameKindMap;
+ llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind> nameKindMap;
for (auto kindNamePair : getAttrKindToNameMapping()) {
nameKindMap.insert({kindNamePair.second, kindNamePair.first});
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 5e1712527d70151..ea9fe2635461f23 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -209,15 +209,20 @@ class NVVMDialectLLVMIRTranslationInterface
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
- auto nvvmAnnotations =
+ llvm::NamedMDNode *nvvmAnnotations =
moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+ // Result attributes arrive with argIdx == -1. grid_constant only applies
+ // to arguments, so reject it instead of recording parameter index 0.
+ if (argIdx < 0)
+ return funcOp.emitError()
+ << "'" << attribute.getName().getValue()
+ << "' attribute is only supported on function arguments";
llvm::MDNode *gridConstantMetaData = nullptr;
// Check if a 'grid_constant' metadata node exists for the given function
- for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
- auto opnd = nvvmAnnotations->getOperand(i);
+ for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
if (opnd->getNumOperands() == 3 &&
opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
opnd->getOperand(1) ==
@@ -248,7 +253,7 @@ class NVVMDialectLLVMIRTranslationInterface
// Append argIdx + 1 to the 'grid_constant' argument list
if (auto argList =
dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
- auto clonedArgList = argList->clone();
+ llvm::TempMDTuple clonedArgList = argList->clone();
clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
llvm::ConstantInt::get(i32, argIdx + 1))));
gridConstantMetaData->replaceOperandWith(
More information about the Mlir-commits
mailing list