[Mlir-commits] [mlir] [MLIR][LLVM] Attach kernel metadata representation to `llvm.func` (PR #101314)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 31 03:06:47 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Victor Perez (victor-eds)

<details>
<summary>Changes</summary>

Add optional attributes to `llvm.func` that represent LLVM's so-called "kernel" function metadata:

- [`vec_type_hint`](https://clang.llvm.org/docs/AttributeReference.html#vec-type-hint)
- [`work_group_size_hint`](https://clang.llvm.org/docs/AttributeReference.html#work-group-size-hint)
- [`reqd_work_group_size`](https://clang.llvm.org/docs/AttributeReference.html#reqd-work-group-size)
- [`intel_reqd_sub_group_size`](https://clang.llvm.org/docs/AttributeReference.html#intel-reqd-sub-group-size)

---
Full diff: https://github.com/llvm/llvm-project/pull/101314.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+11) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+6-1) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+97) 
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+75) 
- (modified) mlir/test/Dialect/LLVMIR/func.mlir (+41-1) 
- (added) mlir/test/Target/LLVMIR/Import/metadata-kernel.ll (+34) 
- (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+44) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 695a962bcab9b..0ecd2dcacffc1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1071,6 +1071,17 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
 /// Folded into from LLVM::PoisonOp.
 def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 
+//===----------------------------------------------------------------------===//
+// VecTypeHintAttr
+//===----------------------------------------------------------------------===//
+
+/// `vec_type_hint` kernel metadata: a hint type and an is-signed flag.
+def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+  let parameters = (ins "TypeAttr":$hint,
+                        DefaultValuedParameter<"bool", "false">:$is_signed);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 260d42185b57f..fde42682d807a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$always_inline,
     OptionalAttr<UnitAttr>:$no_unwind,
     OptionalAttr<UnitAttr>:$will_return,
-    OptionalAttr<UnitAttr>:$optimize_none
+    OptionalAttr<UnitAttr>:$optimize_none,
+    // Kernel metadata (mapped to LLVM function metadata of the same name).
+    OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
+    OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
+    OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
+    OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
   );
 
   let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8b40b7b2df6c7..cb18dc5193352 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1918,6 +1918,99 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
       builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
 }
 
+/// Extract a constant integer from `md` if it wraps a `ConstantInt` whose
+/// value fits in a signed 32-bit integer. Return `std::nullopt` otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+  if (!md)
+    return {};
+
+  auto *c = dyn_cast<llvm::ConstantAsMetadata>(md);
+  if (!c)
+    return {};
+
+  // Reject wider values instead of silently truncating them to 32 bits.
+  auto *ci = dyn_cast<llvm::ConstantInt>(c->getValue());
+  if (!ci || !ci->getValue().isSignedIntN(32))
+    return {};
+  return ci->getValue().getSExtValue();
+}
+
+/// Convert `!{<hint-typed value>, <is_signed int>}` to a `VecTypeHintAttr`.
+template <typename ConvertType>
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *md,
+                                          ConvertType convertType) {
+  if (!md || md->getNumOperands() != 2)
+    return {};
+
+  auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(md->getOperand(0).get());
+  if (!hintMD)
+    return {};
+  TypeAttr hint = TypeAttr::get(convertType(hintMD->getType()));
+
+  std::optional<int32_t> optIsSigned = parseIntegerMD(md->getOperand(1).get());
+  if (!optIsSigned)
+    return {};
+  bool isSigned = *optIsSigned != 0;
+
+  return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Convert an `MDNode` of integer operands to a `DenseI32ArrayAttr`.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+                                              llvm::MDNode *md) {
+  if (!md)
+    return {};
+  SmallVector<int32_t> vals;
+  for (const llvm::MDOperand &op : md->operands()) {
+    std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+    if (!mdValue)
+      return {};
+    vals.push_back(*mdValue);
+  }
+  return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert a single-operand integer `MDNode` to an i32 `IntegerAttr`.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *md) {
+  if (!md || md->getNumOperands() != 1)
+    return {};
+  std::optional<int32_t> val = parseIntegerMD(md->getOperand(0).get());
+  if (!val)
+    return {};
+  return builder.getI32IntegerAttr(*val);
+}
+
+/// Convert well-formed kernel metadata on `func` to attributes on `funcOp`:
+/// - `vec_type_hint`
+/// - `work_group_size_hint`
+/// - `reqd_work_group_size`
+/// - `intel_reqd_sub_group_size`
+template <typename ConvertType>
+static void processKernelMetadata(llvm::Function *func, LLVMFuncOp funcOp,
+                                  ConvertType convertType) {
+  Builder builder(funcOp);
+
+  if (VecTypeHintAttr attr = convertVecTypeHint(
+          builder, func->getMetadata("vec_type_hint"), convertType)) {
+    funcOp.setVecTypeHintAttr(attr);
+  }
+
+  if (DenseI32ArrayAttr attr = convertDenseI32Array(
+          builder, func->getMetadata("work_group_size_hint"))) {
+    funcOp.setWorkGroupSizeHintAttr(attr);
+  }
+
+  if (DenseI32ArrayAttr attr = convertDenseI32Array(
+          builder, func->getMetadata("reqd_work_group_size"))) {
+    funcOp.setReqdWorkGroupSizeAttr(attr);
+  }
+
+  if (IntegerAttr attr = convertIntegerMD(
+          builder, func->getMetadata("intel_reqd_sub_group_size"))) {
+    funcOp.setIntelReqdSubGroupSizeAttr(attr);
+  }
+}
+
 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearRegionState();
 
@@ -1966,6 +2059,10 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   // Handle Function attributes.
   processFunctionAttributes(func, funcOp);
 
+  // Handle kernel metadata (`vec_type_hint`, work/sub-group sizes).
+  processKernelMetadata(func, funcOp,
+                        [this](llvm::Type *type) { return convertType(type); });
+
   // Convert non-debug metadata by using the dialect interface.
   SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
   func->getAllMetadata(allMetadata);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3016d1846e00f..00ca8bdbf0c23 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1247,6 +1247,41 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
   return success();
 }
 
+/// Return a representation of `value` as metadata.
+static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
+                                                const llvm::APInt &value) {
+  llvm::Constant *constant = llvm::ConstantInt::get(context, value);
+  return llvm::ConstantAsMetadata::get(constant);
+}
+
+/// Return a representation of `value` as an MDNode.
+static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
+                                            const llvm::APInt &value) {
+  return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
+}
+
+/// Return an MDNode encoding `vec_type_hint` metadata.
+static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
+                                                llvm::Type *type,
+                                                bool isSigned) {
+  llvm::Metadata *typeMD =
+      llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
+  llvm::Metadata *isSignedMD =
+      convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
+  return llvm::MDNode::get(context, {typeMD, isSignedMD});
+}
+
+/// Return an MDNode tuple with one i32 constant per entry of `values`. Entries
+/// are built as signed `APInt`s so negative values need no implicit truncation.
+static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
+                                                 ArrayRef<int32_t> values) {
+  llvm::SmallVector<llvm::Metadata *> mds;
+  for (int32_t value : values)
+    mds.push_back(convertIntegerToMetadata(
+        context, llvm::APInt(32, value, /*isSigned=*/true)));
+  return llvm::MDNode::get(context, mds);
+}
+
 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
 /// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1483,42 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   convertFunctionMemoryAttributes(func, llvmFunc);
 }
 
+/// Converts kernel attributes of `func` to metadata attached to `llvmFunc`.
+template <typename TypeConverter>
+static void convertFunctionKernelAttributes(LLVMFuncOp func,
+                                            llvm::Function *llvmFunc,
+                                            TypeConverter convertType) {
+  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+
+  if (auto vecTypeHint = func.getVecTypeHint()) {
+    Type type = vecTypeHint->getHint().getValue();
+    llvm::Type *llvmType = convertType(type);
+    bool isSigned = vecTypeHint->getIsSigned();
+    llvmFunc->setMetadata(
+        "vec_type_hint",
+        convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
+  }
+
+  if (auto workGroupSizeHint = func.getWorkGroupSizeHint()) {
+    llvmFunc->setMetadata(
+        "work_group_size_hint",
+        convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
+  }
+
+  if (auto reqdWorkGroupSize = func.getReqdWorkGroupSize()) {
+    llvmFunc->setMetadata(
+        "reqd_work_group_size",
+        convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
+  }
+
+  if (auto intelReqdSubGroupSize = func.getIntelReqdSubGroupSize()) {
+    llvmFunc->setMetadata(
+        "intel_reqd_sub_group_size",
+        convertIntegerToMDNode(llvmContext,
+                               llvm::APInt(32, *intelReqdSubGroupSize)));
+  }
+}
+
 FailureOr<llvm::AttrBuilder>
 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
                                          DictionaryAttr paramAttrs) {
@@ -1492,6 +1563,10 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert function attributes.
     convertFunctionAttributes(function, llvmFunc);
 
+    // Convert kernel attributes (e.g. `vec_type_hint`) to function metadata.
+    convertFunctionKernelAttributes(
+        function, llvmFunc, [this](Type type) { return convertType(type); });
+
     // Convert function_entry_count attribute to metadata.
     if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
       llvmFunc->setEntryCount(entryCount.value());
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 0e29a548de72f..40b4e49f08a3e 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 // RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s -mlir-print-debuginfo | mlir-opt -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
+// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-debuginfo %s | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
 
 module {
@@ -432,3 +432,43 @@ module {
   // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
 }
+
+// -----
+
+// CHECK: @vec_type_hint()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32>
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: @vec_type_hint_signed()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: @vec_type_hint_signed_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: @vec_type_hint_float_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: @vec_type_hint_bfloat_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// -----
+
+// CHECK: @work_group_size_hint()
+// CHECK-SAME: work_group_size_hint = array<i32: 128, 128, 128>
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// -----
+
+// CHECK: @reqd_work_group_size()
+// CHECK-SAME: reqd_work_group_size = array<i32: 128, 256, 128>
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// -----
+
+// CHECK: @intel_reqd_sub_group_size()
+// CHECK-SAME: {intel_reqd_sub_group_size = 32 : i32}
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
new file mode 100644
index 0000000000000..963e40348c795
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
@@ -0,0 +1,34 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK:   llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+declare !vec_type_hint !0 void @vec_type_hint()
+
+; CHECK:   llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+declare !vec_type_hint !1 void @vec_type_hint_signed()
+
+; CHECK:   llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+declare !vec_type_hint !2 void @vec_type_hint_signed_vec()
+
+; CHECK:   llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+declare !vec_type_hint !3 void @vec_type_hint_float_vec()
+
+; CHECK:   llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+declare !vec_type_hint !4 void @vec_type_hint_bfloat_vec()
+
+; CHECK:   llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+declare !work_group_size_hint !5 void @work_group_size_hint()
+
+; CHECK:   llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+declare !reqd_work_group_size !6 void @reqd_work_group_size()
+
+; CHECK:   llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+declare !intel_reqd_sub_group_size !7 void @intel_reqd_sub_group_size()
+
+!0 = !{i32 undef, i32 0}
+!1 = !{i32 undef, i32 1}
+!2 = !{<2 x i32> undef, i32 1}
+!3 = !{<3 x float> undef, i32 0}
+!4 = !{<8 x bfloat> undef, i32 0}
+!5 = !{i32 128, i32 128, i32 128}
+!6 = !{i32 128, i32 256, i32 128}
+!7 = !{i32 32}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..8c2efa3d16f6b 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2549,3 +2549,47 @@ llvm.func @mem_effects_call() {
 // CHECK-SAME: memory(read, inaccessiblemem: write)
 // CHECK: #[[ATTRS_3]]
 // CHECK-SAME: memory(readwrite, argmem: read)
+
+// -----
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT:]] void @vec_type_hint()
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED:]] void @vec_type_hint_signed()
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED_VEC:]] void @vec_type_hint_signed_vec()
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_FVEC:]] void @vec_type_hint_float_vec()
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_BFVEC:]] void @vec_type_hint_bfloat_vec()
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// CHECK: ![[#VEC_TYPE_HINT]] = !{i32 undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED]] = !{i32 undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED_VEC]] = !{<2 x i32> undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_FVEC]] = !{<3 x float> undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_BFVEC]] = !{<8 x bfloat> undef, i32 0}
+
+// -----
+
+// CHECK: declare !work_group_size_hint ![[#WORK_GROUP_SIZE_HINT:]] void @work_group_size_hint()
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// CHECK: ![[#WORK_GROUP_SIZE_HINT]] = !{i32 128, i32 128, i32 128}
+
+// -----
+
+// CHECK: declare !reqd_work_group_size ![[#REQD_WORK_GROUP_SIZE:]] void @reqd_work_group_size()
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// CHECK: ![[#REQD_WORK_GROUP_SIZE]] = !{i32 128, i32 256, i32 128}
+
+// -----
+
+// CHECK: declare !intel_reqd_sub_group_size ![[#INTEL_REQD_SUB_GROUP_SIZE:]] void @intel_reqd_sub_group_size()
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+
+// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}

``````````

</details>


https://github.com/llvm/llvm-project/pull/101314


More information about the Mlir-commits mailing list