[Mlir-commits] [mlir] [MLIR][LLVM] Attach kernel metadata representation to `llvm.func` (PR #101314)

Victor Perez llvmlistbot at llvm.org
Thu Aug 1 03:46:21 PDT 2024


https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/101314

>From 3744c1a3f8b801d7e459e25ee9d31a385fe13cdf Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 31 Jul 2024 10:49:26 +0100
Subject: [PATCH 1/6] [MLIR][LLVM] Attach kernel metadata representation to
 `llvm.func`

Add optional attributes to `llvm.func` representing so-called LLVM
"kernel" metadata:

- [`vec_type_hint`](https://clang.llvm.org/docs/AttributeReference.html#vec-type-hint)
- [`work_group_size_hint`](https://clang.llvm.org/docs/AttributeReference.html#work-group-size-hint)
- [`reqd_work_group_size`](https://clang.llvm.org/docs/AttributeReference.html#reqd-work-group-size)
- [`intel_reqd_sub_group_size`](https://clang.llvm.org/docs/AttributeReference.html#intel-reqd-sub-group-size).

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 11 +++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  7 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 97 +++++++++++++++++++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 75 ++++++++++++++
 mlir/test/Dialect/LLVMIR/func.mlir            | 42 +++++++-
 .../Target/LLVMIR/Import/metadata-kernel.ll   | 34 +++++++
 mlir/test/Target/LLVMIR/llvmir.mlir           | 44 +++++++++
 7 files changed, 308 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/metadata-kernel.ll

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 695a962bcab9b..0ecd2dcacffc1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1071,6 +1071,17 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
 /// Folded into from LLVM::PoisonOp.
 def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 
+//===----------------------------------------------------------------------===//
+// VecTypeHintAttr
+//===----------------------------------------------------------------------===//
+
+/// Represents "vec_type_hint" metadata: a type hint and whether it is signed.
+def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+  let parameters = (ins "TypeAttr":$hint,
+                        DefaultValuedParameter<"bool", "false">:$is_signed);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 260d42185b57f..fde42682d807a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$always_inline,
     OptionalAttr<UnitAttr>:$no_unwind,
     OptionalAttr<UnitAttr>:$will_return,
-    OptionalAttr<UnitAttr>:$optimize_none
+    OptionalAttr<UnitAttr>:$optimize_none,
+    // Kernel metadata
+    OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
+    OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
+    OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
+    OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
   );
 
   let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8b40b7b2df6c7..cb18dc5193352 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,6 +1918,99 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+/// Extract constant integer value from metadata if this is constant. Return
+/// `std::nullopt` otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+  if (!md)
+    return {};
+
+  auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
+  if (!c)
+    return {};
+
+  auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+  if (!ci)
+    return {};
+
+  return ci->getValue().getSExtValue();
+}
+
+/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+                                          ConvertType convertType) {
+  if (!md || md->getNumOperands() != 2)
+    return {};
+
+  auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
+  if (!hintMD)
+    return {};
+  TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+
+  std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+  if (!optIsSigned)
+    return {};
+  bool isSigned = *optIsSigned != 0;
+
+  return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+                                              llvm::MDNode *md) {
+  if (!md)
+    return {};
+  SmallVector<int32_t> vals;
+  for (const llvm::MDOperand &op : md->operands()) {
+    std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+    if (!mdValue)
+      return {};
+    vals.push_back(*mdValue);
+  }
+  return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+  if (!md || md->getNumOperands() != 1)
+    return {};
+  std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
+  if (!val)
+    return {};
+  return builder.getI32IntegerAttr(*val);
+}
+
+/// Process metadata found in kernel functions:
+/// - `vec_type_hint`
+/// - `work_group_size_hint`
+/// - `reqd_work_group_size`
+/// - `intel_reqd_sub_group_size`
+template <typename ConvertType>
+static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
+                                  ConvertType convertType) {
+  Builder builder(funcOp);
+
+  if (VecTypeHintAttr attr = convertVecTypeHint(
+          builder, func->getMetadata("vec_type_hint"), convertType)) {
+    funcOp.setVecTypeHintAttr(attr);
+  }
+
+  if (DenseI32ArrayAttr attr = convertDenseI32Array(
+          builder, func->getMetadata("work_group_size_hint"))) {
+    funcOp.setWorkGroupSizeHintAttr(attr);
+  }
+
+  if (DenseI32ArrayAttr attr = convertDenseI32Array(
+          builder, func->getMetadata("reqd_work_group_size"))) {
+    funcOp.setReqdWorkGroupSizeAttr(attr);
+  }
+
+  if (IntegerAttr attr = convertIntegerMD(
+          builder, func->getMetadata("intel_reqd_sub_group_size"))) {
+    funcOp.setIntelReqdSubGroupSizeAttr(attr);
+  }
+}
+
 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearRegionState();
 
@@ -1966,6 +2059,10 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   // Handle Function attributes.
   processFunctionAttributes(func, funcOp);
 
+  // Handle Kernel Metadata
+  processKernelMetadata(func, funcOp,
+                        [this](llvm::Type *type) { return convertType(type); });
+
   // Convert non-debug metadata by using the dialect interface.
   SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
   func->getAllMetadata(allMetadata);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3016d1846e00f..00ca8bdbf0c23 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
   return success();
 }
 
+/// Return a representation of `value` as metadata.
+static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
+                                                const llvm::APInt &value) {
+  llvm::Constant *constant = llvm::ConstantInt::get(context, value);
+  return llvm::ConstantAsMetadata::get(constant);
+}
+
+/// Return a representation of `value` as an MDNode.
+static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
+                                            const llvm::APInt &value) {
+  return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
+}
+
+/// Return an MDNode encoding `vec_type_hint` metadata.
+static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
+                                                llvm::Type *type,
+                                                bool isSigned) {
+  llvm::Metadata *typeMD =
+      llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
+  llvm::Metadata *isSignedMD =
+      convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
+  return llvm::MDNode::get(context, {typeMD, isSignedMD});
+}
+
+/// Return an MDNode whose operands are the elements of `values`, in order,
+/// each encoded as a signed `i32` constant.
+static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
+                                                 ArrayRef<int32_t> values) {
+  llvm::SmallVector<llvm::Metadata *> mds;
+  for (int32_t value : values)
+    mds.push_back(convertIntegerToMetadata(
+        context, llvm::APInt(32, value, /*isSigned=*/true)));
+  return llvm::MDNode::get(context, mds);
+}
+
 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
 /// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1483,42 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   convertFunctionMemoryAttributes(func, llvmFunc);
 }
 
+/// Converts kernel attributes of `func` to metadata attached to `llvmFunc`.
+template <typename TypeConverter>
+static void convertFunctionKernelAttributes(LLVMFuncOp func,
+                                            llvm::Function *llvmFunc,
+                                            TypeConverter convertType) {
+  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+
+  if (auto vecTypeHint = func.getVecTypeHint()) {
+    Type type = vecTypeHint->getHint().getValue();
+    llvm::Type *llvmType = convertType(type);
+    bool isSigned = vecTypeHint->getIsSigned();
+    llvmFunc->setMetadata(
+        "vec_type_hint",
+        convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
+  }
+
+  if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
+    llvmFunc->setMetadata(
+        "work_group_size_hint",
+        convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
+  }
+
+  if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
+    llvmFunc->setMetadata(
+        "reqd_work_group_size",
+        convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
+  }
+
+  if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
+    llvmFunc->setMetadata(
+        "intel_reqd_sub_group_size",
+        convertIntegerToMDNode(llvmContext,
+                               llvm::APInt(32, *intelReqdSubGroupSize)));
+  }
+}
+
 FailureOr<llvm::AttrBuilder>
 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
                                          DictionaryAttr paramAttrs) {
@@ -1492,6 +1563,10 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert function attributes.
     convertFunctionAttributes(function, llvmFunc);
 
+    // Convert function kernel attributes to metadata
+    convertFunctionKernelAttributes(
+        function, llvmFunc, [this](Type type) { return convertType(type); });
+
     // Convert function_entry_count attribute to metadata.
     if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
       llvmFunc->setEntryCount(entryCount.value());
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 0e29a548de72f..40b4e49f08a3e 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 // RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s -mlir-print-debuginfo | mlir-opt -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
+// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-debuginfo %s | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
 
 module {
@@ -432,3 +432,43 @@ module {
   // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
 }
+
+// -----
+
+// CHECK: @vec_type_hint()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32>
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: @vec_type_hint_signed()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: @vec_type_hint_signed_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: @vec_type_hint_float_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: @vec_type_hint_bfloat_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// -----
+
+// CHECK: @work_group_size_hint()
+// CHECK-SAME: work_group_size_hint = array<i32: 128, 128, 128>
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// -----
+
+// CHECK: @reqd_work_group_size()
+// CHECK-SAME: reqd_work_group_size = array<i32: 128, 256, 128>
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// -----
+
+// CHECK: @intel_reqd_sub_group_size()
+// CHECK-SAME: attributes {intel_reqd_sub_group_size = 32 : i32}
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
new file mode 100644
index 0000000000000..963e40348c795
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
@@ -0,0 +1,34 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK:   llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+declare !vec_type_hint !0 void @vec_type_hint()
+
+; CHECK:   llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+declare !vec_type_hint !1 void @vec_type_hint_signed()
+
+; CHECK:   llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+declare !vec_type_hint !2 void @vec_type_hint_signed_vec()
+
+; CHECK:   llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+declare !vec_type_hint !3 void @vec_type_hint_float_vec()
+
+; CHECK:   llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+declare !vec_type_hint !4 void @vec_type_hint_bfloat_vec()
+
+; CHECK:   llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+declare !work_group_size_hint !5 void @work_group_size_hint()
+
+; CHECK:   llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+declare !reqd_work_group_size !6 void @reqd_work_group_size()
+
+; CHECK:   llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+declare !intel_reqd_sub_group_size !7 void @intel_reqd_sub_group_size()
+
+!0 = !{i32 undef, i32 0}
+!1 = !{i32 undef, i32 1}
+!2 = !{<2 x i32> undef, i32 1}
+!3 = !{<3 x float> undef, i32 0}
+!4 = !{<8 x bfloat> undef, i32 0}
+!5 = !{i32 128, i32 128, i32 128}
+!6 = !{i32 128, i32 256, i32 128}
+!7 = !{i32 32}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..8c2efa3d16f6b 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2549,3 +2549,47 @@ llvm.func @mem_effects_call() {
 // CHECK-SAME: memory(read, inaccessiblemem: write)
 // CHECK: #[[ATTRS_3]]
 // CHECK-SAME: memory(readwrite, argmem: read)
+
+// -----
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT:]] void @vec_type_hint()
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED:]] void @vec_type_hint_signed()
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED_VEC:]] void @vec_type_hint_signed_vec()
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_FVEC:]] void @vec_type_hint_float_vec()
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_BFVEC:]] void @vec_type_hint_bfloat_vec()
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// CHECK: ![[#VEC_TYPE_HINT]] = !{i32 undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED]] = !{i32 undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED_VEC]] = !{<2 x i32> undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_FVEC]] = !{<3 x float> undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_BFVEC]] = !{<8 x bfloat> undef, i32 0}
+
+// -----
+
+// CHECK: declare !work_group_size_hint ![[#WORK_GROUP_SIZE_HINT:]] void @work_group_size_hint()
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// CHECK: ![[#WORK_GROUP_SIZE_HINT]] = !{i32 128, i32 128, i32 128}
+
+// -----
+
+// CHECK: declare !reqd_work_group_size ![[#REQD_WORK_GROUP_SIZE:]] void @reqd_work_group_size()
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// CHECK: ![[#REQD_WORK_GROUP_SIZE]] = !{i32 128, i32 256, i32 128}
+
+// -----
+
+// CHECK: declare !intel_reqd_sub_group_size ![[#INTEL_REQD_SUB_GROUP_SIZE:]] void @intel_reqd_sub_group_size()
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+
+// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}

>From 93d59f4bdf5b0705995e1f3a8af01e3b21a9114d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 31 Jul 2024 17:34:00 +0100
Subject: [PATCH 2/6] Move module import logic to interface

---
 .../mlir/Target/LLVMIR/LLVMImportInterface.h  |   8 +-
 .../include/mlir/Target/LLVMIR/ModuleImport.h |   4 +-
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        | 165 +++++++++++++++++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  97 ----------
 4 files changed, 167 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 74b545490b045..58e5b2c83043d 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -85,7 +85,9 @@ class LLVMImportDialectInterface
   /// Hook for derived dialect interfaces to publish the supported metadata
   /// kinds. As every metadata kind has a unique integer identifier, the
   /// function returns the list of supported metadata identifiers.
-  virtual ArrayRef<unsigned> getSupportedMetadata() const { return {}; }
+  virtual ArrayRef<unsigned> getSupportedMetadata(llvm::LLVMContext &) const {
+    return {};
+  }
 };
 
 /// Interface collection for the import of LLVM IR that dispatches to a concrete
@@ -101,7 +103,7 @@ class LLVMImportInterface
   /// intrinsic and metadata kinds and builds the dispatch tables for the
   /// conversion. Returns failure if multiple dialect interfaces translate the
   /// same LLVM IR intrinsic.
-  LogicalResult initializeImport() {
+  LogicalResult initializeImport(llvm::LLVMContext &llvmContext) {
     for (const LLVMImportDialectInterface &iface : *this) {
       // Verify the supported intrinsics have not been mapped before.
       const auto *intrinsicIt =
@@ -139,7 +141,7 @@ class LLVMImportInterface
       for (unsigned id : iface.getSupportedInstructions())
         instructionToDialect[id] = &iface;
       // Add a mapping for all supported metadata kinds.
-      for (unsigned kind : iface.getSupportedMetadata())
+      for (unsigned kind : iface.getSupportedMetadata(llvmContext))
         metadataToDialect[kind].push_back(iface.getDialect());
     }
 
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 04d098d38155b..df3feb0a3a280 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -53,7 +53,9 @@ class ModuleImport {
   /// dialect interfaces for the supported LLVM IR intrinsics and metadata kinds
   /// and builds the dispatch tables. Returns failure if multiple dialect
   /// interfaces translate the same LLVM IR intrinsic.
-  LogicalResult initializeImportInterface() { return iface.initializeImport(); }
+  LogicalResult initializeImportInterface() {
+    return iface.initializeImport(llvmModule->getContext());
+  }
 
   /// Converts all functions of the LLVM module to MLIR functions.
   LogicalResult convertFunctions();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 06673965245c0..c97eca73aac4f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -32,6 +32,14 @@ using namespace mlir::LLVM::detail;
 
 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
 
+static constexpr StringLiteral vecTypeHintAttrName = "vec_type_hint";
+static constexpr StringLiteral workGroupSizeHintAttrName =
+    "work_group_size_hint";
+static constexpr StringLiteral reqdWorkGroupSizeAttrName =
+    "reqd_work_group_size";
+static constexpr StringLiteral intelReqdSubGroupSizeAttrName =
+    "intel_reqd_sub_group_size";
+
 /// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
 /// intrinsic. Returns false otherwise.
 static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
@@ -70,11 +78,25 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
 
 /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
 /// dialect attributes.
-static ArrayRef<unsigned> getSupportedMetadataImpl() {
-  static const SmallVector<unsigned> convertibleMetadata = {
-      llvm::LLVMContext::MD_prof,         llvm::LLVMContext::MD_tbaa,
-      llvm::LLVMContext::MD_access_group, llvm::LLVMContext::MD_loop,
-      llvm::LLVMContext::MD_noalias,      llvm::LLVMContext::MD_alias_scope};
+///
+/// The kernel metadata kinds below are custom kinds whose IDs are assigned
+/// per `llvm::LLVMContext`, so the list must not be cached in a static
+/// initialized from whichever context is seen first. It is rebuilt on every
+/// call into thread-local storage; the returned `ArrayRef` is only valid
+/// until the next call on the same thread.
+static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
+  static thread_local SmallVector<unsigned> convertibleMetadata;
+  convertibleMetadata = {
+      llvm::LLVMContext::MD_prof,
+      llvm::LLVMContext::MD_tbaa,
+      llvm::LLVMContext::MD_access_group,
+      llvm::LLVMContext::MD_loop,
+      llvm::LLVMContext::MD_noalias,
+      llvm::LLVMContext::MD_alias_scope,
+      context.getMDKindID(vecTypeHintAttrName),
+      context.getMDKindID(workGroupSizeHintAttrName),
+      context.getMDKindID(reqdWorkGroupSizeAttrName),
+      context.getMDKindID(intelReqdSubGroupSizeAttrName)};
   return convertibleMetadata;
 }
 
@@ -226,6 +241,133 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
   return success();
 }
 
+/// Extract a constant integer value from metadata if it is a `ConstantInt`
+/// whose signed value fits in 32 bits. Return `std::nullopt` otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+  auto *c = llvm::dyn_cast_or_null<llvm::ConstantAsMetadata>(md);
+  if (!c)
+    return {};
+
+  auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+  if (!ci || !ci->getValue().isSignedIntN(32))
+    return {};
+
+  return ci->getValue().getSExtValue();
+}
+
+/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+                                          ConvertType convertType) {
+  if (!md || md->getNumOperands() != 2)
+    return {};
+
+  auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
+  if (!hintMD)
+    return {};
+  TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+
+  std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+  if (!optIsSigned)
+    return {};
+  bool isSigned = *optIsSigned != 0;
+
+  return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+                                              llvm::MDNode *md) {
+  if (!md)
+    return {};
+  SmallVector<int32_t> vals;
+  for (const llvm::MDOperand &op : md->operands()) {
+    std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+    if (!mdValue)
+      return {};
+    vals.push_back(*mdValue);
+  }
+  return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+  if (!md || md->getNumOperands() != 1)
+    return {};
+  std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
+  if (!val)
+    return {};
+  return builder.getI32IntegerAttr(*val);
+}
+
+template <typename Parser, typename Setter>
+static LogicalResult setFuncAttr(Builder &builder, llvm::MDNode *node,
+                                 Operation *op, Parser parse, Setter set) {
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  auto attr = parse(node);
+  if (!attr)
+    return failure();
+
+  set(funcOp, attr);
+  return success();
+}
+
+static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
+                                        Operation *op,
+                                        LLVM::ModuleImport &moduleImport) {
+  return setFuncAttr(
+      builder, node, op,
+      [&builder, &moduleImport](llvm::MDNode *node) {
+        return convertVecTypeHint(builder, node,
+                                  [&moduleImport](llvm::Type *type) {
+                                    return moduleImport.convertType(type);
+                                  });
+      },
+      [](LLVM::LLVMFuncOp funcOp, VecTypeHintAttr attr) {
+        funcOp.setVecTypeHintAttr(attr);
+      });
+}
+
+static LogicalResult
+setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
+  return setFuncAttr(
+      builder, node, op,
+      [&builder](llvm::MDNode *node) {
+        return convertDenseI32Array(builder, node);
+      },
+      [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
+        funcOp.setWorkGroupSizeHintAttr(attr);
+      });
+}
+
+static LogicalResult
+setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
+  return setFuncAttr(
+      builder, node, op,
+      [&builder](llvm::MDNode *node) {
+        return convertDenseI32Array(builder, node);
+      },
+      [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
+        funcOp.setReqdWorkGroupSizeAttr(attr);
+      });
+}
+
+static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
+                                                  llvm::MDNode *node,
+                                                  Operation *op) {
+  return setFuncAttr(
+      builder, node, op,
+      [&builder](llvm::MDNode *node) {
+        return convertIntegerMD(builder, node);
+      },
+      [](LLVM::LLVMFuncOp funcOp, IntegerAttr attr) {
+        funcOp.setIntelReqdSubGroupSizeAttr(attr);
+      });
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -261,6 +403,16 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
     if (kind == llvm::LLVMContext::MD_noalias)
       return setNoaliasScopesAttr(node, op, moduleImport);
 
+    llvm::LLVMContext &context = node->getContext();
+    if (kind == context.getMDKindID(vecTypeHintAttrName))
+      return setVecTypeHintAttr(builder, node, op, moduleImport);
+    if (kind == context.getMDKindID(workGroupSizeHintAttrName))
+      return setWorkGroupSizeHintAttr(builder, node, op);
+    if (kind == context.getMDKindID(reqdWorkGroupSizeAttrName))
+      return setReqdWorkGroupSizeAttr(builder, node, op);
+    if (kind == context.getMDKindID(intelReqdSubGroupSizeAttrName))
+      return setIntelReqdSubGroupSizeAttr(builder, node, op);
+
     // A handler for a supported metadata kind is missing.
     llvm_unreachable("unknown metadata type");
   }
@@ -273,8 +425,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
 
   /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
   /// LLVM dialect attributes.
-  ArrayRef<unsigned> getSupportedMetadata() const final {
-    return getSupportedMetadataImpl();
+  ArrayRef<unsigned>
+  getSupportedMetadata(llvm::LLVMContext &context) const final {
+    return getSupportedMetadataImpl(context);
   }
 };
 } // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index cb18dc5193352..8b40b7b2df6c7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,99 +1918,6 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
-/// Extract constant integer value from metadata if this is constant. Return
-/// `std::nullopt` otherwise.
-static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
-  if (!md)
-    return {};
-
-  auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
-  if (!c)
-    return {};
-
-  auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
-  if (!ci)
-    return {};
-
-  return ci->getValue().getSExtValue();
-}
-
-/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
-template <typename ConvertType>
-static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
-                                          ConvertType convertType) {
-  if (!md || md->getNumOperands() != 2)
-    return {};
-
-  auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
-  if (!hintMD)
-    return {};
-  TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
-
-  std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
-  if (!optIsSigned)
-    return {};
-  bool isSigned = *optIsSigned != 0;
-
-  return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
-}
-
-/// Convert an `MDNode` to an MLIR `DenseI32ArrayAttr` if possible.
-static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
-                                              llvm::MDNode *md) {
-  if (!md)
-    return {};
-  SmallVector<int32_t> vals;
-  for (const llvm::MDOperand &op : md->operands()) {
-    std::optional<int32_t> mdValue = parseIntegerMD(op.get());
-    if (!mdValue)
-      return {};
-    vals.push_back(*mdValue);
-  }
-  return builder.getDenseI32ArrayAttr(vals);
-}
-
-/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
-static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
-  if (!md || md->getNumOperands() != 1)
-    return {};
-  std::optional<int32_t> val = parseIntegerMD(md->getOperand(0));
-  if (!val)
-    return {};
-  return builder.getI32IntegerAttr(*val);
-}
-
-/// Process metadata found in kernel functions:
-/// - `vec_type_hint`
-/// - `work_group_size_hint`
-/// - `reqd_work_group_size`
-/// - `intel_reqd_sub_group_size`
-template <typename ConvertType>
-static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
-                                  ConvertType convertType) {
-  Builder builder(funcOp);
-
-  if (VecTypeHintAttr attr = convertVecTypeHint(
-          builder, func->getMetadata("vec_type_hint"), convertType)) {
-    funcOp.setVecTypeHintAttr(attr);
-  }
-
-  if (DenseI32ArrayAttr attr = convertDenseI32Array(
-          builder, func->getMetadata("work_group_size_hint"))) {
-    funcOp.setWorkGroupSizeHintAttr(attr);
-  }
-
-  if (DenseI32ArrayAttr attr = convertDenseI32Array(
-          builder, func->getMetadata("reqd_work_group_size"))) {
-    funcOp.setReqdWorkGroupSizeAttr(attr);
-  }
-
-  if (IntegerAttr attr = convertIntegerMD(
-          builder, func->getMetadata("intel_reqd_sub_group_size"))) {
-    funcOp.setIntelReqdSubGroupSizeAttr(attr);
-  }
-}
-
 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearRegionState();
 
@@ -2059,10 +1966,6 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   // Handle Function attributes.
   processFunctionAttributes(func, funcOp);
 
-  // Handle Kernel Metadata
-  processKernelMetadata(func, funcOp,
-                        [this](llvm::Type *type) { return convertType(type); });
-
   // Convert non-debug metadata by using the dialect interface.
   SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
   func->getAllMetadata(allMetadata);

>From d2534a9425121a58af037823da7ec992b10acb5d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Wed, 31 Jul 2024 17:40:09 +0100
Subject: [PATCH 3/6] Add documentation

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 7 +++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td      | 1 -
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 0ecd2dcacffc1..4858866d2f816 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1077,6 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 
 /// Represents "vec_type_hint" values
 def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+  let summary = "Explicit vectorization compiler hint";
+  let description = [{
+    A hint to the compiler that indicates most operations used in the function
+    are explicitly vectorized using a particular vector type. `$hint` is the
+    vector or scalar type in particular. `$is_signed` can be used with integer
+    types to state whether the type is signed.
+  }];
   let parameters = (ins "TypeAttr":$hint,
                         DefaultValuedParameter<"bool", "false">:$is_signed);
   let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index fde42682d807a..c38a2584c8eec 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1457,7 +1457,6 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$no_unwind,
     OptionalAttr<UnitAttr>:$will_return,
     OptionalAttr<UnitAttr>:$optimize_none,
-    // Kernel metadata
     OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
     OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
     OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,

>From abe03ae56511d95956232cf55e7c9b879e82db34 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 1 Aug 2024 08:53:31 +0100
Subject: [PATCH 4/6] Add clarification

---
 mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 58e5b2c83043d..cc5a77ed35d2b 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -84,8 +84,10 @@ class LLVMImportDialectInterface
 
   /// Hook for derived dialect interfaces to publish the supported metadata
   /// kinds. As every metadata kind has a unique integer identifier, the
-  /// function returns the list of supported metadata identifiers.
-  virtual ArrayRef<unsigned> getSupportedMetadata(llvm::LLVMContext &) const {
+  /// function returns the list of supported metadata identifiers. `ctx` can be
+  /// used to obtain IDs of metadata kinds that do not have a fixed static one.
+  virtual ArrayRef<unsigned>
+  getSupportedMetadata(llvm::LLVMContext &ctx) const {
     return {};
   }
 };

>From e342f072708a42cbeaad9a3631a093f7f99692bb Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 1 Aug 2024 08:58:33 +0100
Subject: [PATCH 5/6] Simplify code

---
 .../Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp   | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index c97eca73aac4f..393a68c010ebf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -256,16 +256,15 @@ static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
 }
 
 /// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
-template <typename ConvertType>
 static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
-                                          ConvertType convertType) {
+                                          ModuleImport &moduleImport) {
   if (!md || md->getNumOperands() != 2)
     return {};
 
   auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
   if (!hintMD)
     return {};
-  TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+  TypeAttr hint = TypeAttr::get(moduleImport.convertType(hintMD->getType()));
 
   std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
   if (!optIsSigned)
@@ -300,14 +299,14 @@ static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
   return builder.getI32IntegerAttr(*val);
 }
 
-template <typename Parser, typename Setter>
+template <typename Encoder, typename Setter>
 static LogicalResult setFuncAttr(Builder &builder, llvm::MDNode *node,
-                                 Operation *op, Parser parse, Setter set) {
+                                 Operation *op, Encoder encode, Setter set) {
   auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
   if (!funcOp)
     return failure();
 
-  auto attr = parse(node);
+  auto attr = encode(node);
   if (!attr)
     return failure();
 
@@ -321,10 +320,7 @@ static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
   return setFuncAttr(
       builder, node, op,
       [&builder, &moduleImport](llvm::MDNode *node) {
-        return convertVecTypeHint(builder, node,
-                                  [&moduleImport](llvm::Type *type) {
-                                    return moduleImport.convertType(type);
-                                  });
+        return convertVecTypeHint(builder, node, moduleImport);
       },
       [](LLVM::LLVMFuncOp funcOp, VecTypeHintAttr attr) {
         funcOp.setVecTypeHintAttr(attr);

>From 27f9dcccdad1cf48889ac8fd780edb7350277cec Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 1 Aug 2024 11:46:03 +0100
Subject: [PATCH 6/6] Apply suggestions

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |   1 -
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        | 120 ++++++++----------
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  23 ++--
 3 files changed, 66 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 4858866d2f816..529c458ce1254 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1075,7 +1075,6 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 // VecTypeHintAttr
 //===----------------------------------------------------------------------===//
 
-/// Represents "vec_type_hint" values
 def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
   let summary = "Explicit vectorization compiler hint";
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 393a68c010ebf..d76722a24c4fc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -32,12 +32,10 @@ using namespace mlir::LLVM::detail;
 
 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
 
-static constexpr StringLiteral vecTypeHintAttrName = "vec_type_hint";
-static constexpr StringLiteral workGroupSizeHintAttrName =
-    "work_group_size_hint";
-static constexpr StringLiteral reqdWorkGroupSizeAttrName =
-    "reqd_work_group_size";
-static constexpr StringLiteral intelReqdSubGroupSizeAttrName =
+static constexpr StringLiteral vecTypeHintMDName = "vec_type_hint";
+static constexpr StringLiteral workGroupSizeHintMDName = "work_group_size_hint";
+static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
+static constexpr StringLiteral intelReqdSubGroupSizeMDName =
     "intel_reqd_sub_group_size";
 
 /// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
@@ -86,10 +84,10 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
       llvm::LLVMContext::MD_loop,
       llvm::LLVMContext::MD_noalias,
       llvm::LLVMContext::MD_alias_scope,
-      context.getMDKindID(vecTypeHintAttrName),
-      context.getMDKindID(workGroupSizeHintAttrName),
-      context.getMDKindID(reqdWorkGroupSizeAttrName),
-      context.getMDKindID(intelReqdSubGroupSizeAttrName)};
+      context.getMDKindID(vecTypeHintMDName),
+      context.getMDKindID(workGroupSizeHintMDName),
+      context.getMDKindID(reqdWorkGroupSizeMDName),
+      context.getMDKindID(intelReqdSubGroupSizeMDName)};
   return convertibleMetadata;
 }
 
@@ -241,21 +239,22 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
   return success();
 }
 
-/// Extract constant integer value from metadata if this is constant. Return
-/// `std::nullopt` otherwise.
+/// Extracts an integer from the provided metadata `md` if possible. Returns
+/// nullopt otherwise.
 static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
-  auto *c = llvm::dyn_cast_or_null<llvm::ConstantAsMetadata>(md);
-  if (!c)
+  auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
+  if (!constant)
     return {};
 
-  auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
-  if (!ci)
+  auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue());
+  if (!intConstant)
     return {};
 
-  return ci->getValue().getSExtValue();
+  return intConstant->getValue().getSExtValue();
 }
 
-/// Convert an `MDNode` to an LLVM dialect `VecTypeHintAttr` if possible.
+/// Converts the provided metadata node `md` to an LLVM dialect VecTypeHintAttr
+/// if possible.
 static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
                                           ModuleImport &moduleImport) {
   if (!md || md->getNumOperands() != 2)
@@ -299,69 +298,62 @@ static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
   return builder.getI32IntegerAttr(*val);
 }
 
-template <typename Encoder, typename Setter>
-static LogicalResult setFuncAttr(Builder &builder, llvm::MDNode *node,
-                                 Operation *op, Encoder encode, Setter set) {
+static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
+                                        Operation *op,
+                                        LLVM::ModuleImport &moduleImport) {
   auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
   if (!funcOp)
     return failure();
 
-  auto attr = encode(node);
+  VecTypeHintAttr attr = convertVecTypeHint(builder, node, moduleImport);
   if (!attr)
     return failure();
 
-  set(funcOp, attr);
+  funcOp.setVecTypeHintAttr(attr);
   return success();
 }
 
-static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
-                                        Operation *op,
-                                        LLVM::ModuleImport &moduleImport) {
-  return setFuncAttr(
-      builder, node, op,
-      [&builder, &moduleImport](llvm::MDNode *node) {
-        return convertVecTypeHint(builder, node, moduleImport);
-      },
-      [](LLVM::LLVMFuncOp funcOp, VecTypeHintAttr attr) {
-        funcOp.setVecTypeHintAttr(attr);
-      });
-}
-
 static LogicalResult
 setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
-  return setFuncAttr(
-      builder, node, op,
-      [&builder](llvm::MDNode *node) {
-        return convertDenseI32Array(builder, node);
-      },
-      [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
-        funcOp.setWorkGroupSizeHintAttr(attr);
-      });
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
+  if (!attr)
+    return failure();
+
+  funcOp.setWorkGroupSizeHintAttr(attr);
+  return success();
 }
 
 static LogicalResult
 setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
-  return setFuncAttr(
-      builder, node, op,
-      [&builder](llvm::MDNode *node) {
-        return convertDenseI32Array(builder, node);
-      },
-      [](LLVM::LLVMFuncOp funcOp, DenseI32ArrayAttr attr) {
-        funcOp.setReqdWorkGroupSizeAttr(attr);
-      });
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
+  if (!attr)
+    return failure();
+
+  funcOp.setReqdWorkGroupSizeAttr(attr);
+  return success();
 }
 
 static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
                                                   llvm::MDNode *node,
                                                   Operation *op) {
-  return setFuncAttr(
-      builder, node, op,
-      [&builder](llvm::MDNode *node) {
-        return convertIntegerMD(builder, node);
-      },
-      [](LLVM::LLVMFuncOp funcOp, IntegerAttr attr) {
-        funcOp.setIntelReqdSubGroupSizeAttr(attr);
-      });
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  IntegerAttr attr = convertIntegerMD(builder, node);
+  if (!attr)
+    return failure();
+
+  funcOp.setIntelReqdSubGroupSizeAttr(attr);
+  return success();
 }
 
 namespace {
@@ -400,13 +392,13 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
       return setNoaliasScopesAttr(node, op, moduleImport);
 
     llvm::LLVMContext &context = node->getContext();
-    if (kind == context.getMDKindID(vecTypeHintAttrName))
+    if (kind == context.getMDKindID(vecTypeHintMDName))
       return setVecTypeHintAttr(builder, node, op, moduleImport);
-    if (kind == context.getMDKindID(workGroupSizeHintAttrName))
+    if (kind == context.getMDKindID(workGroupSizeHintMDName))
       return setWorkGroupSizeHintAttr(builder, node, op);
-    if (kind == context.getMDKindID(reqdWorkGroupSizeAttrName))
+    if (kind == context.getMDKindID(reqdWorkGroupSizeMDName))
       return setReqdWorkGroupSizeAttr(builder, node, op);
-    if (kind == context.getMDKindID(intelReqdSubGroupSizeAttrName))
+    if (kind == context.getMDKindID(intelReqdSubGroupSizeMDName))
       return setIntelReqdSubGroupSizeAttr(builder, node, op);
 
     // A handler for a supported metadata kind is missing.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 00ca8bdbf0c23..b1c345d48a204 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1271,11 +1271,10 @@ static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
   return llvm::MDNode::get(context, {typeMD, isSignedMD});
 }
 
-/// Return an MDNode with a tuple given by the values in the input integer array
-/// attribute.
+/// Return an MDNode with a tuple given by the values in `values`.
 static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
                                                  ArrayRef<int32_t> values) {
-  llvm::SmallVector<llvm::Metadata *> mds;
+  SmallVector<llvm::Metadata *> mds;
   llvm::transform(values, std::back_inserter(mds), [&context](int32_t value) {
     return convertIntegerToMetadata(context, llvm::APInt(32, value));
   });
@@ -1484,36 +1483,35 @@ static void convertFunctionAttributes(LLVMFuncOp func,
 }
 
 /// Converts function attributes from `func` and attaches them to `llvmFunc`.
-template <typename TypeConverter>
 static void convertFunctionKernelAttributes(LLVMFuncOp func,
                                             llvm::Function *llvmFunc,
-                                            TypeConverter convertType) {
+                                            ModuleTranslation &translation) {
   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
 
   if (auto vecTypeHint = func.getVecTypeHint()) {
     Type type = vecTypeHint->getHint().getValue();
-    llvm::Type *llvmType = convertType(type);
+    llvm::Type *llvmType = translation.convertType(type);
     bool isSigned = vecTypeHint->getIsSigned();
     llvmFunc->setMetadata(
-        "vec_type_hint",
+        func.getVecTypeHintAttrName(),
         convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
   }
 
   if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
     llvmFunc->setMetadata(
-        "work_group_size_hint",
+        func.getWorkGroupSizeHintAttrName(),
         convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
   }
 
   if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
     llvmFunc->setMetadata(
-        "reqd_work_group_size",
+        func.getReqdWorkGroupSizeAttrName(),
         convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
   }
 
   if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
     llvmFunc->setMetadata(
-        "intel_reqd_sub_group_size",
+        func.getIntelReqdSubGroupSizeAttrName(),
         convertIntegerToMDNode(llvmContext,
                                llvm::APInt(32, *intelReqdSubGroupSize)));
   }
@@ -1563,9 +1561,8 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert function attributes.
     convertFunctionAttributes(function, llvmFunc);
 
-    // Convert function kernel attributes to metadata
-    convertFunctionKernelAttributes(
-        function, llvmFunc, [this](Type type) { return convertType(type); });
+    // Convert function kernel attributes to metadata.
+    convertFunctionKernelAttributes(function, llvmFunc, *this);
 
     // Convert function_entry_count attribute to metadata.
     if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())



More information about the Mlir-commits mailing list