[Mlir-commits] [mlir] [MLIR][LLVM] Attach kernel metadata representation to `llvm.func` (PR #101314)
Victor Perez
llvmlistbot at llvm.org
Thu Aug 1 03:46:21 PDT 2024
https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/101314
>From 3744c1a3f8b801d7e459e25ee9d31a385fe13cdf Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 31 Jul 2024 10:49:26 +0100
Subject: [PATCH 1/6] [MLIR][LLVM] Attach kernel metadata representation to
`llvm.func`
Add optional attributes to `llvm.func` representing LLVM so-called
"kernel" metadata:
- [`vec_type_hint`](https://clang.llvm.org/docs/AttributeReference.html#vec-type-hint)
- [`work_group_size_hint`](https://clang.llvm.org/docs/AttributeReference.html#work-group-size-hint)
- [`reqd_work_group_size`](https://clang.llvm.org/docs/AttributeReference.html#reqd-work-group-size)
- [`intel_reqd_sub_group_size`](https://clang.llvm.org/docs/AttributeReference.html#intel-reqd-sub-group-size).
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
.../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 11 +++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 7 +-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 97 +++++++++++++++++++
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 75 ++++++++++++++
mlir/test/Dialect/LLVMIR/func.mlir | 42 +++++++-
.../Target/LLVMIR/Import/metadata-kernel.ll | 34 +++++++
mlir/test/Target/LLVMIR/llvmir.mlir | 44 +++++++++
7 files changed, 308 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 695a962bcab9b..0ecd2dcacffc1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1071,6 +1071,17 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
/// Folded into from LLVM::PoisonOp.
def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
+//===----------------------------------------------------------------------===//
+// VecTypeHintAttr
+//===----------------------------------------------------------------------===//
+
+/// Represents "vec_type_hint" values
+def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+ let parameters = (ins "TypeAttr":$hint,
+ DefaultValuedParameter<"bool", "false">:$is_signed);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
//===----------------------------------------------------------------------===//
// ZeroAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 260d42185b57f..fde42682d807a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$always_inline,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
- OptionalAttr<UnitAttr>:$optimize_none
+ OptionalAttr<UnitAttr>:$optimize_none,
+ // Kernel metadata
+ OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
+ OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
+ OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
+ OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8b40b7b2df6c7..cb18dc5193352 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,6 +1918,99 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}
+/// Extract constant integer value from metadata if this is constant. Return
+/// `std::nullopt` otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+ if (!md)
+ return {};
+
+ auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
+ if (!c)
+ return {};
+
+ auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+ if (!ci)
+ return {};
+
+ return ci->getValue().getSExtValue();
+}
+
+/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+ ConvertType convertType) {
+ if (!md || md->getNumOperands() != 2)
+ return {};
+
+ auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
+ if (!hintMD)
+ return {};
+ TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+
+ std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+ if (!optIsSigned)
+ return {};
+ bool isSigned = *optIsSigned != 0;
+
+ return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+ llvm::MDNode *md) {
+ if (!md)
+ return {};
+ SmallVector<int32_t> vals;
+ for (const llvm::MDOperand &op : md->operands()) {
+ std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+ if (!mdValue)
+ return {};
+ vals.push_back(*mdValue);
+ }
+ return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+ if (!md || md->getNumOperands() != 1)
+ return {};
+ std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
+ if (!val)
+ return {};
+ return builder.getI32IntegerAttr(*val);
+}
+
+/// Process metadata found in kernel functions:
+/// - `vec_type_hint`
+/// - `work_group_size_hint`
+/// - `reqd_work_group_size`
+/// - `intel_reqd_sub_group_size`
+template <typename ConvertType>
+static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
+ ConvertType convertType) {
+ Builder builder(funcOp);
+
+ if (VecTypeHintAttr attr = convertVecTypeHint(
+ builder, func->getMetadata("vec_type_hint"), convertType)) {
+ funcOp.setVecTypeHintAttr(attr);
+ }
+
+ if (DenseI32ArrayAttr attr = convertDenseI32Array(
+ builder, func->getMetadata("work_group_size_hint"))) {
+ funcOp.setWorkGroupSizeHintAttr(attr);
+ }
+
+ if (DenseI32ArrayAttr attr = convertDenseI32Array(
+ builder, func->getMetadata("reqd_work_group_size"))) {
+ funcOp.setReqdWorkGroupSizeAttr(attr);
+ }
+
+ if (IntegerAttr attr = convertIntegerMD(
+ builder, func->getMetadata("intel_reqd_sub_group_size"))) {
+ funcOp.setIntelReqdSubGroupSizeAttr(attr);
+ }
+}
+
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearRegionState();
@@ -1966,6 +2059,10 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
// Handle Function attributes.
processFunctionAttributes(func, funcOp);
+ // Handle Kernel Metadata
+ processKernelMetadata(func, funcOp,
+ [this](llvm::Type *type) { return convertType(type); });
+
// Convert non-debug metadata by using the dialect interface.
SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
func->getAllMetadata(allMetadata);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3016d1846e00f..00ca8bdbf0c23 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
return success();
}
+/// Return a representation of `value` as metadata.
+static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
+ const llvm::APInt &value) {
+ llvm::Constant *constant = llvm::ConstantInt::get(context, value);
+ return llvm::ConstantAsMetadata::get(constant);
+}
+
+/// Return a representation of `value` as an MDNode.
+static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
+ const llvm::APInt &value) {
+ return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
+}
+
+/// Return an MDNode encoding `vec_type_hint` metadata.
+static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
+ llvm::Type *type,
+ bool isSigned) {
+ llvm::Metadata *typeMD =
+ llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
+ llvm::Metadata *isSignedMD =
+ convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
+ return llvm::MDNode::get(context, {typeMD, isSignedMD});
+}
+
+/// Return an MDNode with a tuple given by the values in the input integer array
+/// attribute.
+/// Return an MDNode with a tuple given by the values in the input integer array
+/// attribute. Each value is emitted as a signed `i32` constant.
+static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
+ ArrayRef<int32_t> values) {
+ llvm::SmallVector<llvm::Metadata *> mds;
+ mds.reserve(values.size());
+ llvm::transform(values, std::back_inserter(mds), [&context](int32_t value) {
+ // `value` is signed: build the APInt as signed so negative values pass the
+ // N-bit range check instead of relying on implicit 64->32 bit truncation.
+ return convertIntegerToMetadata(context,
+ llvm::APInt(32, value, /*isSigned=*/true));
+ });
+ return llvm::MDNode::get(context, mds);
+}
+
/// Attaches the attributes listed in the given array attribute to `llvmFunc`.
/// Reports error to `loc` if any and returns immediately. Expects `attributes`
/// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1483,42 @@ static void convertFunctionAttributes(LLVMFuncOp func,
convertFunctionMemoryAttributes(func, llvmFunc);
}
+/// Converts the kernel-related attributes of `func` (`vec_type_hint`,
+/// `work_group_size_hint`, `reqd_work_group_size` and
+/// `intel_reqd_sub_group_size`) into function metadata of the same name on
+/// `llvmFunc`. `convertType` maps MLIR types to LLVM types and is used for the
+/// `vec_type_hint` type operand. Attributes that are absent are skipped.
+template <typename TypeConverter>
+static void convertFunctionKernelAttributes(LLVMFuncOp func,
+ llvm::Function *llvmFunc,
+ TypeConverter convertType) {
+ llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+
+ // Encoded as `!{<type> undef, i32 <0|1>}`.
+ // NOTE(review): the hint type is not restricted by the attribute definition;
+ // confirm `convertType` handles every type that can appear here.
+ if (auto vecTypeHint = func.getVecTypeHint()) {
+ Type type = vecTypeHint->getHint().getValue();
+ llvm::Type *llvmType = convertType(type);
+ bool isSigned = vecTypeHint->getIsSigned();
+ llvmFunc->setMetadata(
+ "vec_type_hint",
+ convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
+ }
+
+ // Encoded as a tuple of `i32` constants.
+ if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
+ llvmFunc->setMetadata(
+ "work_group_size_hint",
+ convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
+ }
+
+ // Encoded as a tuple of `i32` constants.
+ if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
+ llvmFunc->setMetadata(
+ "reqd_work_group_size",
+ convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
+ }
+
+ // Encoded as a single-operand node `!{i32 N}`.
+ if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
+ llvmFunc->setMetadata(
+ "intel_reqd_sub_group_size",
+ convertIntegerToMDNode(llvmContext,
+ llvm::APInt(32, *intelReqdSubGroupSize)));
+ }
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
DictionaryAttr paramAttrs) {
@@ -1492,6 +1563,10 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert function attributes.
convertFunctionAttributes(function, llvmFunc);
+ // Convert function kernel attributes to metadata
+ convertFunctionKernelAttributes(
+ function, llvmFunc, [this](Type type) { return convertType(type); });
+
// Convert function_entry_count attribute to metadata.
if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
llvmFunc->setEntryCount(entryCount.value());
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 0e29a548de72f..40b4e49f08a3e 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s -mlir-print-debuginfo | mlir-opt -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
+// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-debuginfo %s | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
module {
@@ -432,3 +432,43 @@ module {
// expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}
+
+// -----
+
+// CHECK: @vec_type_hint()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32>
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: @vec_type_hint_signed()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: @vec_type_hint_signed_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: @vec_type_hint_float_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: @vec_type_hint_bfloat_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// -----
+
+// CHECK: @work_group_size_hint()
+// CHECK-SAME: work_group_size_hint = array<i32: 128, 128, 128>
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// -----
+
+// CHECK: @reqd_work_group_size()
+// CHECK-SAME: reqd_work_group_size = array<i32: 128, 256, 128>
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// -----
+
+// CHECK: @intel_reqd_sub_group_size()
+// CHECK-SAME: intel_reqd_sub_group_size = 32 : i32
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
new file mode 100644
index 0000000000000..963e40348c795
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
@@ -0,0 +1,34 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK: llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+declare !vec_type_hint !0 void @vec_type_hint()
+
+; CHECK: llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+declare !vec_type_hint !1 void @vec_type_hint_signed()
+
+; CHECK: llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+declare !vec_type_hint !2 void @vec_type_hint_signed_vec()
+
+; CHECK: llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+declare !vec_type_hint !3 void @vec_type_hint_float_vec()
+
+; CHECK: llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+declare !vec_type_hint !4 void @vec_type_hint_bfloat_vec()
+
+; CHECK: llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+declare !work_group_size_hint !5 void @work_group_size_hint()
+
+; CHECK: llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+declare !reqd_work_group_size !6 void @reqd_work_group_size()
+
+; CHECK: llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+declare !intel_reqd_sub_group_size !7 void @intel_reqd_sub_group_size()
+
+!0 = !{i32 undef, i32 0}
+!1 = !{i32 undef, i32 1}
+!2 = !{<2 x i32> undef, i32 1}
+!3 = !{<3 x float> undef, i32 0}
+!4 = !{<8 x bfloat> undef, i32 0}
+!5 = !{i32 128, i32 128, i32 128}
+!6 = !{i32 128, i32 256, i32 128}
+!7 = !{i32 32}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..8c2efa3d16f6b 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2549,3 +2549,47 @@ llvm.func @mem_effects_call() {
// CHECK-SAME: memory(read, inaccessiblemem: write)
// CHECK: #[[ATTRS_3]]
// CHECK-SAME: memory(readwrite, argmem: read)
+
+// -----
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT:]] void @vec_type_hint()
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED:]] void @vec_type_hint_signed()
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED_VEC:]] void @vec_type_hint_signed_vec()
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_FVEC:]] void @vec_type_hint_float_vec()
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_BFVEC:]] void @vec_type_hint_bfloat_vec()
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// CHECK: ![[#VEC_TYPE_HINT]] = !{i32 undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED]] = !{i32 undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED_VEC]] = !{<2 x i32> undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_FVEC]] = !{<3 x float> undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_BFVEC]] = !{<8 x bfloat> undef, i32 0}
+
+// -----
+
+// CHECK: declare !work_group_size_hint ![[#WORK_GROUP_SIZE_HINT:]] void @work_group_size_hint()
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// CHECK: ![[#WORK_GROUP_SIZE_HINT]] = !{i32 128, i32 128, i32 128}
+
+// -----
+
+// CHECK: declare !reqd_work_group_size ![[#REQD_WORK_GROUP_SIZE:]] void @reqd_work_group_size()
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// CHECK: ![[#REQD_WORK_GROUP_SIZE]] = !{i32 128, i32 256, i32 128}
+
+// -----
+
+// CHECK: declare !intel_reqd_sub_group_size ![[#INTEL_REQD_SUB_GROUP_SIZE:]] void @intel_reqd_sub_group_size()
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+
+// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}
>From 93d59f4bdf5b0705995e1f3a8af01e3b21a9114d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 31 Jul 2024 17:34:00 +0100
Subject: [PATCH 2/6] Move module import logic to interface
---
.../mlir/Target/LLVMIR/LLVMImportInterface.h | 8 +-
.../include/mlir/Target/LLVMIR/ModuleImport.h | 4 +-
.../LLVMIR/LLVMIRToLLVMTranslation.cpp | 165 +++++++++++++++++-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 97 ----------
4 files changed, 167 insertions(+), 107 deletions(-)
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 74b545490b045..58e5b2c83043d 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -85,7 +85,9 @@ class LLVMImportDialectInterface
/// Hook for derived dialect interfaces to publish the supported metadata
/// kinds. As every metadata kind has a unique integer identifier, the
/// function returns the list of supported metadata identifiers.
- virtual ArrayRef<unsigned> getSupportedMetadata() const { return {}; }
+ virtual ArrayRef<unsigned> getSupportedMetadata(llvm::LLVMContext &) const {
+ return {};
+ }
};
/// Interface collection for the import of LLVM IR that dispatches to a concrete
@@ -101,7 +103,7 @@ class LLVMImportInterface
/// intrinsic and metadata kinds and builds the dispatch tables for the
/// conversion. Returns failure if multiple dialect interfaces translate the
/// same LLVM IR intrinsic.
- LogicalResult initializeImport() {
+ LogicalResult initializeImport(llvm::LLVMContext &llvmContext) {
for (const LLVMImportDialectInterface &iface : *this) {
// Verify the supported intrinsics have not been mapped before.
const auto *intrinsicIt =
@@ -139,7 +141,7 @@ class LLVMImportInterface
for (unsigned id : iface.getSupportedInstructions())
instructionToDialect[id] = &iface;
// Add a mapping for all supported metadata kinds.
- for (unsigned kind : iface.getSupportedMetadata())
+ for (unsigned kind : iface.getSupportedMetadata(llvmContext))
metadataToDialect[kind].push_back(iface.getDialect());
}
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 04d098d38155b..df3feb0a3a280 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -53,7 +53,9 @@ class ModuleImport {
/// dialect interfaces for the supported LLVM IR intrinsics and metadata kinds
/// and builds the dispatch tables. Returns failure if multiple dialect
/// interfaces translate the same LLVM IR intrinsic.
- LogicalResult initializeImportInterface() { return iface.initializeImport(); }
+ LogicalResult initializeImportInterface() {
+ return iface.initializeImport(llvmModule->getContext());
+ }
/// Converts all functions of the LLVM module to MLIR functions.
LogicalResult convertFunctions();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 06673965245c0..c97eca73aac4f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -32,6 +32,14 @@ using namespace mlir::LLVM::detail;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
+static constexpr StringLiteral vecTypeHintAttrName = "vec_type_hint";
+static constexpr StringLiteral workGroupSizeHintAttrName =
+ "work_group_size_hint";
+static constexpr StringLiteral reqdWorkGroupSizeAttrName =
+ "reqd_work_group_size";
+static constexpr StringLiteral intelReqdSubGroupSizeAttrName =
+ "intel_reqd_sub_group_size";
+
/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
/// intrinsic. Returns false otherwise.
static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
@@ -70,11 +78,18 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
/// dialect attributes.
+///
+/// The IDs of custom metadata kinds (`vec_type_hint`, ...) are assigned per
+/// `llvm::LLVMContext` in registration order, so they must be recomputed for
+/// the given `context` on every call; caching them in a function-local static
+/// initialized once would hand stale IDs to any later import that uses a
+/// different context. The returned reference is valid until the next call on
+/// the same thread.
-static ArrayRef<unsigned> getSupportedMetadataImpl() {
+static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
- static const SmallVector<unsigned> convertibleMetadata = {
- llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa,
- llvm::LLVMContext::MD_access_group, llvm::LLVMContext::MD_loop,
- llvm::LLVMContext::MD_noalias, llvm::LLVMContext::MD_alias_scope};
+ static thread_local SmallVector<unsigned> convertibleMetadata;
+ convertibleMetadata = {
+ llvm::LLVMContext::MD_prof,
+ llvm::LLVMContext::MD_tbaa,
+ llvm::LLVMContext::MD_access_group,
+ llvm::LLVMContext::MD_loop,
+ llvm::LLVMContext::MD_noalias,
+ llvm::LLVMContext::MD_alias_scope,
+ context.getMDKindID(vecTypeHintAttrName),
+ context.getMDKindID(workGroupSizeHintAttrName),
+ context.getMDKindID(reqdWorkGroupSizeAttrName),
+ context.getMDKindID(intelReqdSubGroupSizeAttrName)};
return convertibleMetadata;
}
@@ -226,6 +241,133 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
return success();
}
+/// Extract a constant integer value from metadata if it is a `ConstantInt`
+/// whose value fits in a signed 32-bit integer. Return `std::nullopt`
+/// otherwise, e.g. for null or non-constant metadata, or for out-of-range
+/// values that would otherwise be silently truncated.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+ auto *c = llvm::dyn_cast_or_null<llvm::ConstantAsMetadata>(md);
+ if (!c)
+ return {};
+
+ auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+ if (!ci)
+ return {};
+
+ // Also guards `getSExtValue`, which asserts for integers wider than 64 bits.
+ const llvm::APInt &value = ci->getValue();
+ if (!value.isSignedIntN(32))
+ return {};
+ return static_cast<int32_t>(value.getSExtValue());
+}
+
+/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+/// Expects exactly two operands: a value whose type is the hint (e.g.
+/// `<2 x i32> undef`) and an integer signedness flag. Returns a null attribute
+/// for malformed metadata or if the hint type cannot be converted.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+ ConvertType convertType) {
+ if (!md || md->getNumOperands() != 2)
+ return {};
+
+ // MDNode operands may be null (e.g. `!{null, i32 0}`); plain `dyn_cast`
+ // would assert on them.
+ auto *hintMD =
+ llvm::dyn_cast_or_null<llvm::ValueAsMetadata>(md->getOperand(0).get());
+ if (!hintMD)
+ return {};
+ Type hintType = convertType(hintMD->getType());
+ if (!hintType)
+ return {};
+ TypeAttr hint = TypeAttr::get(hintType);
+
+ std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+ if (!optIsSigned)
+ return {};
+ bool isSigned = *optIsSigned != 0;
+
+ return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+ llvm::MDNode *md) {
+ if (!md)
+ return {};
+ SmallVector<int32_t> vals;
+ for (const llvm::MDOperand &op : md->operands()) {
+ std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+ if (!mdValue)
+ return {};
+ vals.push_back(*mdValue);
+ }
+ return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+ if (!md || md->getNumOperands() != 1)
+ return {};
+ std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
+ if (!val)
+ return {};
+ return builder.getI32IntegerAttr(*val);
+}
+
+/// Shared helper for the kernel-metadata handlers: if `op` is an
+/// `LLVM::LLVMFuncOp`, converts `node` with `parse` and, when that yields a
+/// non-null attribute, stores it on the function via `set`. Returns failure if
+/// `op` is not a function or `node` could not be converted.
+/// NOTE(review): `builder` is not used in this body; callers capture it in
+/// their `parse` lambdas instead.
+template <typename Parser, typename Setter>
+static LogicalResult setFuncAttr(Builder &builder, llvm::MDNode *node,
+ Operation *op, Parser parse, Setter set) {
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!funcOp)
+ return failure();
+
+ auto attr = parse(node);
+ if (!attr)
+ return failure();
+
+ set(funcOp, attr);
+ return success();
+}
+
+static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
+ Operation *op,
+ LLVM::ModuleImport &moduleImport) {
+ return setFuncAttr(
+ builder, node, op,
+ [&builder, &moduleImport](llvm::MDNode *node) {
+ return convertVecTypeHint(builder, node,
+ [&moduleImport](llvm::Type *type) {
+ return moduleImport.convertType(type);
+ });
+ },
+ [](LLVM::LLVMFuncOp funcOp, VecTypeHintAttr attr) {
+ funcOp.setVecTypeHintAttr(attr);
+ });
+}
+
+static LogicalResult
+setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
+ return setFuncAttr(
+ builder, node, op,
+ [&builder](llvm::MDNode *node) {
+ return convertDenseI32Array(builder, node);
+ },
+ [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
+ funcOp.setWorkGroupSizeHintAttr(attr);
+ });
+}
+
+static LogicalResult
+setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
+ return setFuncAttr(
+ builder, node, op,
+ [&builder](llvm::MDNode *node) {
+ return convertDenseI32Array(builder, node);
+ },
+ [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
+ funcOp.setReqdWorkGroupSizeAttr(attr);
+ });
+}
+
+static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
+ llvm::MDNode *node,
+ Operation *op) {
+ return setFuncAttr(
+ builder, node, op,
+ [&builder](llvm::MDNode *node) {
+ return convertIntegerMD(builder, node);
+ },
+ [](LLVM::LLVMFuncOp funcOp, IntegerAttr attr) {
+ funcOp.setIntelReqdSubGroupSizeAttr(attr);
+ });
+}
+
namespace {
/// Implementation of the dialect interface that converts operations belonging
@@ -261,6 +403,16 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
if (kind == llvm::LLVMContext::MD_noalias)
return setNoaliasScopesAttr(node, op, moduleImport);
+ llvm::LLVMContext &context = node->getContext();
+ if (kind == context.getMDKindID(vecTypeHintAttrName))
+ return setVecTypeHintAttr(builder, node, op, moduleImport);
+ if (kind == context.getMDKindID(workGroupSizeHintAttrName))
+ return setWorkGroupSizeHintAttr(builder, node, op);
+ if (kind == context.getMDKindID(reqdWorkGroupSizeAttrName))
+ return setReqdWorkGroupSizeAttr(builder, node, op);
+ if (kind == context.getMDKindID(intelReqdSubGroupSizeAttrName))
+ return setIntelReqdSubGroupSizeAttr(builder, node, op);
+
// A handler for a supported metadata kind is missing.
llvm_unreachable("unknown metadata type");
}
@@ -273,8 +425,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
/// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
/// LLVM dialect attributes.
- ArrayRef<unsigned> getSupportedMetadata() const final {
- return getSupportedMetadataImpl();
+ ArrayRef<unsigned>
+ getSupportedMetadata(llvm::LLVMContext &context) const final {
+ return getSupportedMetadataImpl(context);
}
};
} // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index cb18dc5193352..8b40b7b2df6c7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,99 +1918,6 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}
-/// Extract constant integer value from metadata if this is constant. Return
-/// `std::nullopt` otherwise.
-static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
- if (!md)
- return {};
-
- auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
- if (!c)
- return {};
-
- auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
- if (!ci)
- return {};
-
- return ci->getValue().getSExtValue();
-}
-
-/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
-template <typename ConvertType>
-static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
- ConvertType convertType) {
- if (!md || md->getNumOperands() != 2)
- return {};
-
- auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
- if (!hintMD)
- return {};
- TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
-
- std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
- if (!optIsSigned)
- return {};
- bool isSigned = *optIsSigned != 0;
-
- return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
-}
-
-/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
-static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
- llvm::MDNode *md) {
- if (!md)
- return {};
- SmallVector<int32_t> vals;
- for (const llvm::MDOperand &op : md->operands()) {
- std::optional<int32_t> mdValue = parseIntegerMD(op.get());
- if (!mdValue)
- return {};
- vals.push_back(*mdValue);
- }
- return builder.getDenseI32ArrayAttr(vals);
-}
-
-/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
-static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
- if (!md || md->getNumOperands() != 1)
- return {};
- std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
- if (!val)
- return {};
- return builder.getI32IntegerAttr(*val);
-}
-
-/// Process metadata found in kernel functions:
-/// - `vec_type_hint`
-/// - `work_group_size_hint`
-/// - `reqd_work_group_size`
-/// - `intel_reqd_sub_group_size`
-template <typename ConvertType>
-static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
- ConvertType convertType) {
- Builder builder(funcOp);
-
- if (VecTypeHintAttr attr = convertVecTypeHint(
- builder, func->getMetadata("vec_type_hint"), convertType)) {
- funcOp.setVecTypeHintAttr(attr);
- }
-
- if (DenseI32ArrayAttr attr = convertDenseI32Array(
- builder, func->getMetadata("work_group_size_hint"))) {
- funcOp.setWorkGroupSizeHintAttr(attr);
- }
-
- if (DenseI32ArrayAttr attr = convertDenseI32Array(
- builder, func->getMetadata("reqd_work_group_size"))) {
- funcOp.setReqdWorkGroupSizeAttr(attr);
- }
-
- if (IntegerAttr attr = convertIntegerMD(
- builder, func->getMetadata("intel_reqd_sub_group_size"))) {
- funcOp.setIntelReqdSubGroupSizeAttr(attr);
- }
-}
-
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearRegionState();
@@ -2059,10 +1966,6 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
// Handle Function attributes.
processFunctionAttributes(func, funcOp);
- // Handle Kernel Metadata
- processKernelMetadata(func, funcOp,
- [this](llvm::Type *type) { return convertType(type); });
-
// Convert non-debug metadata by using the dialect interface.
SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
func->getAllMetadata(allMetadata);
>From d2534a9425121a58af037823da7ec992b10acb5d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 31 Jul 2024 17:40:09 +0100
Subject: [PATCH 3/6] Add documentation
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 7 +++++++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 1 -
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 0ecd2dcacffc1..4858866d2f816 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1077,6 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
/// Represents "vec_type_hint" values
def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+ let summary = "Explicit vectorization compiler hint";
+ let description = [{
+ A hint to the compiler that indicates most operations used in the function
+ are explicitly vectorized using a particular vector type. `$hint` is the
+ vector or scalar type in question. `$is_signed` can be used with integer
+ types to state whether the type is signed.
+ }];
let parameters = (ins "TypeAttr":$hint,
DefaultValuedParameter<"bool", "false">:$is_signed);
let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index fde42682d807a..c38a2584c8eec 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1457,7 +1457,6 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
OptionalAttr<UnitAttr>:$optimize_none,
- // Kernel metadata
OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
>From abe03ae56511d95956232cf55e7c9b879e82db34 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 1 Aug 2024 08:53:31 +0100
Subject: [PATCH 4/6] Add clarification
---
mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 58e5b2c83043d..cc5a77ed35d2b 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -84,8 +84,10 @@ class LLVMImportDialectInterface
/// Hook for derived dialect interfaces to publish the supported metadata
/// kinds. As every metadata kind has a unique integer identifier, the
- /// function returns the list of supported metadata identifiers.
- virtual ArrayRef<unsigned> getSupportedMetadata(llvm::LLVMContext &) const {
+ /// function returns the list of supported metadata identifiers. `ctx` can be
+ /// used to obtain IDs of metadata kinds that do not have a fixed static one.
+ virtual ArrayRef<unsigned>
+ getSupportedMetadata(llvm::LLVMContext &ctx) const {
return {};
}
};
>From e342f072708a42cbeaad9a3631a093f7f99692bb Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 1 Aug 2024 08:58:33 +0100
Subject: [PATCH 5/6] Simplify code
---
.../Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp | 16 ++++++----------
1 file changed, 6 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index c97eca73aac4f..393a68c010ebf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -256,16 +256,15 @@ static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
}
/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
-template <typename ConvertType>
static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
- ConvertType convertType) {
+ ModuleImport &moduleImport) {
if (!md || md->getNumOperands() != 2)
return {};
auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
if (!hintMD)
return {};
- TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+ TypeAttr hint = TypeAttr::get(moduleImport.convertType(hintMD->getType()));
std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
if (!optIsSigned)
@@ -300,14 +299,14 @@ static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
return builder.getI32IntegerAttr(*val);
}
-template <typename Parser, typename Setter>
+template <typename Encoder, typename Setter>
static LogicalResult setFuncAttr(Builder &builder, llvm::MDNode *node,
- Operation *op, Parser parse, Setter set) {
+ Operation *op, Encoder encode, Setter set) {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!funcOp)
return failure();
- auto attr = parse(node);
+ auto attr = encode(node);
if (!attr)
return failure();
@@ -321,10 +320,7 @@ static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
return setFuncAttr(
builder, node, op,
[&builder, &moduleImport](llvm::MDNode *node) {
- return convertVecTypeHint(builder, node,
- [&moduleImport](llvm::Type *type) {
- return moduleImport.convertType(type);
- });
+ return convertVecTypeHint(builder, node, moduleImport);
},
[](LLVM::LLVMFuncOp funcOp, VecTypeHintAttr attr) {
funcOp.setVecTypeHintAttr(attr);
>From 27f9dcccdad1cf48889ac8fd780edb7350277cec Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 1 Aug 2024 11:46:03 +0100
Subject: [PATCH 6/6] Apply suggestions
---
.../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 1 -
.../LLVMIR/LLVMIRToLLVMTranslation.cpp | 120 ++++++++----------
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 23 ++--
3 files changed, 66 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 4858866d2f816..529c458ce1254 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1075,7 +1075,6 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
// VecTypeHintAttr
//===----------------------------------------------------------------------===//
-/// Represents "vec_type_hint" values
def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
let summary = "Explicit vectorization compiler hint";
let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 393a68c010ebf..d76722a24c4fc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -32,12 +32,10 @@ using namespace mlir::LLVM::detail;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
-static constexpr StringLiteral vecTypeHintAttrName = "vec_type_hint";
-static constexpr StringLiteral workGroupSizeHintAttrName =
- "work_group_size_hint";
-static constexpr StringLiteral reqdWorkGroupSizeAttrName =
- "reqd_work_group_size";
-static constexpr StringLiteral intelReqdSubGroupSizeAttrName =
+static constexpr StringLiteral vecTypeHintMDName = "vec_type_hint";
+static constexpr StringLiteral workGroupSizeHintMDName = "work_group_size_hint";
+static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
+static constexpr StringLiteral intelReqdSubGroupSizeMDName =
"intel_reqd_sub_group_size";
/// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
@@ -86,10 +84,10 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
llvm::LLVMContext::MD_loop,
llvm::LLVMContext::MD_noalias,
llvm::LLVMContext::MD_alias_scope,
- context.getMDKindID(vecTypeHintAttrName),
- context.getMDKindID(workGroupSizeHintAttrName),
- context.getMDKindID(reqdWorkGroupSizeAttrName),
- context.getMDKindID(intelReqdSubGroupSizeAttrName)};
+ context.getMDKindID(vecTypeHintMDName),
+ context.getMDKindID(workGroupSizeHintMDName),
+ context.getMDKindID(reqdWorkGroupSizeMDName),
+ context.getMDKindID(intelReqdSubGroupSizeMDName)};
return convertibleMetadata;
}
@@ -241,21 +239,22 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
return success();
}
-/// Extract constant integer value from metadata if this is constant. Return
-/// `std::nullopt` otherwise.
+/// Extracts an integer from the provided metadata `md` if possible. Returns
+/// nullopt otherwise.
static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
- auto *c = llvm::dyn_cast_or_null<llvm::ConstantAsMetadata>(md);
- if (!c)
+ auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
+ if (!constant)
return {};
- auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
- if (!ci)
+ auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue());
+ if (!intConstant)
return {};
- return ci->getValue().getSExtValue();
+ return intConstant->getValue().getSExtValue();
}
-/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+/// Converts the provided metadata node `md` to an LLVM dialect VecTypeHintAttr
+/// if possible.
static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
ModuleImport &moduleImport) {
if (!md || md->getNumOperands() != 2)
@@ -299,69 +298,62 @@ static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
return builder.getI32IntegerAttr(*val);
}
-template <typename Encoder, typename Setter>
-static LogicalResult setFuncAttr(Builder &builder, llvm::MDNode *node,
- Operation *op, Encoder encode, Setter set) {
+static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
+ Operation *op,
+ LLVM::ModuleImport &moduleImport) {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!funcOp)
return failure();
- auto attr = encode(node);
+ VecTypeHintAttr attr = convertVecTypeHint(builder, node, moduleImport);
if (!attr)
return failure();
- set(funcOp, attr);
+ funcOp.setVecTypeHintAttr(attr);
return success();
}
-static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
- Operation *op,
- LLVM::ModuleImport &moduleImport) {
- return setFuncAttr(
- builder, node, op,
- [&builder, &moduleImport](llvm::MDNode *node) {
- return convertVecTypeHint(builder, node, moduleImport);
- },
- [](LLVM::LLVMFuncOp funcOp, VecTypeHintAttr attr) {
- funcOp.setVecTypeHintAttr(attr);
- });
-}
-
static LogicalResult
setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
- return setFuncAttr(
- builder, node, op,
- [&builder](llvm::MDNode *node) {
- return convertDenseI32Array(builder, node);
- },
- [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
- funcOp.setWorkGroupSizeHintAttr(attr);
- });
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!funcOp)
+ return failure();
+
+ DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
+ if (!attr)
+ return failure();
+
+ funcOp.setWorkGroupSizeHintAttr(attr);
+ return success();
}
static LogicalResult
setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
- return setFuncAttr(
- builder, node, op,
- [&builder](llvm::MDNode *node) {
- return convertDenseI32Array(builder, node);
- },
- [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
- funcOp.setReqdWorkGroupSizeAttr(attr);
- });
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!funcOp)
+ return failure();
+
+ DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
+ if (!attr)
+ return failure();
+
+ funcOp.setReqdWorkGroupSizeAttr(attr);
+ return success();
}
static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
llvm::MDNode *node,
Operation *op) {
- return setFuncAttr(
- builder, node, op,
- [&builder](llvm::MDNode *node) {
- return convertIntegerMD(builder, node);
- },
- [](LLVM::LLVMFuncOp funcOp, IntegerAttr attr) {
- funcOp.setIntelReqdSubGroupSizeAttr(attr);
- });
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!funcOp)
+ return failure();
+
+ IntegerAttr attr = convertIntegerMD(builder, node);
+ if (!attr)
+ return failure();
+
+ funcOp.setIntelReqdSubGroupSizeAttr(attr);
+ return success();
}
namespace {
@@ -400,13 +392,13 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
return setNoaliasScopesAttr(node, op, moduleImport);
llvm::LLVMContext &context = node->getContext();
- if (kind == context.getMDKindID(vecTypeHintAttrName))
+ if (kind == context.getMDKindID(vecTypeHintMDName))
return setVecTypeHintAttr(builder, node, op, moduleImport);
- if (kind == context.getMDKindID(workGroupSizeHintAttrName))
+ if (kind == context.getMDKindID(workGroupSizeHintMDName))
return setWorkGroupSizeHintAttr(builder, node, op);
- if (kind == context.getMDKindID(reqdWorkGroupSizeAttrName))
+ if (kind == context.getMDKindID(reqdWorkGroupSizeMDName))
return setReqdWorkGroupSizeAttr(builder, node, op);
- if (kind == context.getMDKindID(intelReqdSubGroupSizeAttrName))
+ if (kind == context.getMDKindID(intelReqdSubGroupSizeMDName))
return setIntelReqdSubGroupSizeAttr(builder, node, op);
// A handler for a supported metadata kind is missing.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 00ca8bdbf0c23..b1c345d48a204 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1271,11 +1271,10 @@ static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
return llvm::MDNode::get(context, {typeMD, isSignedMD});
}
-/// Return an MDNode with a tuple given by the values in the input integer array
-/// attribute.
+/// Return an MDNode with a tuple given by the values in `values`.
static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
ArrayRef<int32_t> values) {
- llvm::SmallVector<llvm::Metadata *> mds;
+ SmallVector<llvm::Metadata *> mds;
llvm::transform(values, std::back_inserter(mds), [&context](int32_t value) {
return convertIntegerToMetadata(context, llvm::APInt(32, value));
});
@@ -1484,36 +1483,35 @@ static void convertFunctionAttributes(LLVMFuncOp func,
}
/// Converts function attributes from `func` and attaches them to `llvmFunc`.
-template <typename TypeConverter>
static void convertFunctionKernelAttributes(LLVMFuncOp func,
llvm::Function *llvmFunc,
- TypeConverter convertType) {
+ ModuleTranslation &translation) {
llvm::LLVMContext &llvmContext = llvmFunc->getContext();
if (auto vecTypeHint = func.getVecTypeHint()) {
Type type = vecTypeHint->getHint().getValue();
- llvm::Type *llvmType = convertType(type);
+ llvm::Type *llvmType = translation.convertType(type);
bool isSigned = vecTypeHint->getIsSigned();
llvmFunc->setMetadata(
- "vec_type_hint",
+ func.getVecTypeHintAttrName(),
convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
}
if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
llvmFunc->setMetadata(
- "work_group_size_hint",
+ func.getWorkGroupSizeHintAttrName(),
convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
}
if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
llvmFunc->setMetadata(
- "reqd_work_group_size",
+ func.getReqdWorkGroupSizeAttrName(),
convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
}
if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
llvmFunc->setMetadata(
- "intel_reqd_sub_group_size",
+ func.getIntelReqdSubGroupSizeAttrName(),
convertIntegerToMDNode(llvmContext,
llvm::APInt(32, *intelReqdSubGroupSize)));
}
@@ -1563,9 +1561,8 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert function attributes.
convertFunctionAttributes(function, llvmFunc);
- // Convert function kernel attributes to metadata
- convertFunctionKernelAttributes(
- function, llvmFunc, [this](Type type) { return convertType(type); });
+ // Convert function kernel attributes to metadata.
+ convertFunctionKernelAttributes(function, llvmFunc, *this);
// Convert function_entry_count attribute to metadata.
if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
More information about the Mlir-commits
mailing list