[Mlir-commits] [mlir] 6d2bbba - [MLIR][LLVM] Attach kernel metadata representation to `llvm.func` (#101314)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 2 02:53:36 PDT 2024


Author: Victor Perez
Date: 2024-08-02T11:53:30+02:00
New Revision: 6d2bbba187cd8fdc3e6e46cb753d4d9c6c276103

URL: https://github.com/llvm/llvm-project/commit/6d2bbba187cd8fdc3e6e46cb753d4d9c6c276103
DIFF: https://github.com/llvm/llvm-project/commit/6d2bbba187cd8fdc3e6e46cb753d4d9c6c276103.diff

LOG: [MLIR][LLVM] Attach kernel metadata representation to `llvm.func` (#101314)

Add optional attributes to `llvm.func` representing LLVM so-called
"kernel" metadata:

- [`vec_type_hint`](https://clang.llvm.org/docs/AttributeReference.html#vec-type-hint)
- [`work_group_size_hint`](https://clang.llvm.org/docs/AttributeReference.html#work-group-size-hint)
- [`reqd_work_group_size`](https://clang.llvm.org/docs/AttributeReference.html#reqd-work-group-size)
- [`intel_reqd_sub_group_size`](https://clang.llvm.org/docs/AttributeReference.html#intel-reqd-sub-group-size)

---------

Signed-off-by: Victor Perez <victor.perez at codeplay.com>

Added: 
    mlir/test/Target/LLVMIR/Import/metadata-kernel.ll

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
    mlir/include/mlir/Target/LLVMIR/ModuleImport.h
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 695a962bcab9b..529c458ce1254 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1071,6 +1071,23 @@ def LLVM_UndefAttr : LLVM_Attr<"Undef", "undef">;
 /// Folded into from LLVM::PoisonOp.
 def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
 
+//===----------------------------------------------------------------------===//
+// VecTypeHintAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_VecTypeHintAttr : LLVM_Attr<"VecTypeHint", "vec_type_hint"> {
+  let summary = "Explicit vectorization compiler hint";
+  let description = [{
+    A hint to the compiler that indicates most operations used in the function
+    are explicitly vectorized using a particular vector type. `$hint` is that
+    vector (or scalar) type. `$is_signed` (default `false`) can be used with
+    integer types to state whether the type is signed.
+  }];
+  let parameters = (ins "TypeAttr":$hint,
+                        DefaultValuedParameter<"bool", "false">:$is_signed);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 260d42185b57f..c38a2584c8eec 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,11 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$always_inline,
     OptionalAttr<UnitAttr>:$no_unwind,
     OptionalAttr<UnitAttr>:$will_return,
-    OptionalAttr<UnitAttr>:$optimize_none
+    OptionalAttr<UnitAttr>:$optimize_none,
+    OptionalAttr<LLVM_VecTypeHintAttr>:$vec_type_hint,
+    OptionalAttr<DenseI32ArrayAttr>:$work_group_size_hint,
+    OptionalAttr<DenseI32ArrayAttr>:$reqd_work_group_size,
+    OptionalAttr<I32Attr>:$intel_reqd_sub_group_size
   );
 
   let regions = (region AnyRegion:$body);

diff  --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
index 74b545490b045..cc5a77ed35d2b 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
@@ -84,8 +84,12 @@ class LLVMImportDialectInterface
 
   /// Hook for derived dialect interfaces to publish the supported metadata
   /// kinds. As every metadata kind has a unique integer identifier, the
-  /// function returns the list of supported metadata identifiers.
-  virtual ArrayRef<unsigned> getSupportedMetadata() const { return {}; }
+  /// function returns the list of supported metadata identifiers. `ctx` can be
+  /// used to obtain IDs of metadata kinds that do not have a fixed static one.
+  virtual ArrayRef<unsigned>
+  getSupportedMetadata(llvm::LLVMContext &ctx) const {
+    return {};
+  }
 };
 
 /// Interface collection for the import of LLVM IR that dispatches to a concrete
@@ -101,7 +105,7 @@ class LLVMImportInterface
   /// intrinsic and metadata kinds and builds the dispatch tables for the
   /// conversion. Returns failure if multiple dialect interfaces translate the
   /// same LLVM IR intrinsic.
-  LogicalResult initializeImport() {
+  LogicalResult initializeImport(llvm::LLVMContext &llvmContext) {
     for (const LLVMImportDialectInterface &iface : *this) {
       // Verify the supported intrinsics have not been mapped before.
       const auto *intrinsicIt =
@@ -139,7 +143,7 @@ class LLVMImportInterface
       for (unsigned id : iface.getSupportedInstructions())
         instructionToDialect[id] = &iface;
       // Add a mapping for all supported metadata kinds.
-      for (unsigned kind : iface.getSupportedMetadata())
+      for (unsigned kind : iface.getSupportedMetadata(llvmContext))
         metadataToDialect[kind].push_back(iface.getDialect());
     }
 

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 04d098d38155b..df3feb0a3a280 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -53,7 +53,9 @@ class ModuleImport {
   /// dialect interfaces for the supported LLVM IR intrinsics and metadata kinds
   /// and builds the dispatch tables. Returns failure if multiple dialect
   /// interfaces translate the same LLVM IR intrinsic.
-  LogicalResult initializeImportInterface() { return iface.initializeImport(); }
+  LogicalResult initializeImportInterface() {
+    return iface.initializeImport(llvmModule->getContext());
+  }
 
   /// Converts all functions of the LLVM module to MLIR functions.
   LogicalResult convertFunctions();

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 06673965245c0..d034e576dfc57 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -32,6 +32,12 @@ using namespace mlir::LLVM::detail;
 
 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
 
+static constexpr StringLiteral vecTypeHintMDName = "vec_type_hint";
+static constexpr StringLiteral workGroupSizeHintMDName = "work_group_size_hint";
+static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
+static constexpr StringLiteral intelReqdSubGroupSizeMDName =
+    "intel_reqd_sub_group_size";
+
 /// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
 /// intrinsic. Returns false otherwise.
 static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
@@ -70,11 +76,18 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
 
 /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
 /// dialect attributes.
-static ArrayRef<unsigned> getSupportedMetadataImpl() {
+static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
   static const SmallVector<unsigned> convertibleMetadata = {
-      llvm::LLVMContext::MD_prof,         llvm::LLVMContext::MD_tbaa,
-      llvm::LLVMContext::MD_access_group, llvm::LLVMContext::MD_loop,
-      llvm::LLVMContext::MD_noalias,      llvm::LLVMContext::MD_alias_scope};
+      llvm::LLVMContext::MD_prof,
+      llvm::LLVMContext::MD_tbaa,
+      llvm::LLVMContext::MD_access_group,
+      llvm::LLVMContext::MD_loop,
+      llvm::LLVMContext::MD_noalias,
+      llvm::LLVMContext::MD_alias_scope,
+      context.getMDKindID(vecTypeHintMDName),
+      context.getMDKindID(workGroupSizeHintMDName),
+      context.getMDKindID(reqdWorkGroupSizeMDName),
+      context.getMDKindID(intelReqdSubGroupSizeMDName)};
   return convertibleMetadata;
 }
 
@@ -226,6 +239,128 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
   return success();
 }
 
+/// Extracts an integer from the provided metadata `md` if possible. Returns
+/// nullopt otherwise.
+static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
+  auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
+  if (!constant)
+    return {};
+
+  auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue());
+  if (!intConstant)
+    return {};
+
+  return intConstant->getValue().getSExtValue();
+}
+
+/// Converts the provided metadata node `node` to an LLVM dialect
+/// VecTypeHintAttr if possible.
+static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *node,
+                                          ModuleImport &moduleImport) {
+  if (!node || node->getNumOperands() != 2)
+    return {};
+
+  auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(node->getOperand(0).get());
+  if (!hintMD)
+    return {};
+  TypeAttr hint = TypeAttr::get(moduleImport.convertType(hintMD->getType()));
+
+  std::optional<int32_t> optIsSigned =
+      parseIntegerMD(node->getOperand(1).get());
+  if (!optIsSigned)
+    return {};
+  bool isSigned = *optIsSigned != 0;
+
+  return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
+}
+
+/// Converts the provided metadata node `node` to an MLIR DenseI32ArrayAttr if
+/// possible.
+static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
+                                              llvm::MDNode *node) {
+  if (!node)
+    return {};
+  SmallVector<int32_t> vals;
+  for (const llvm::MDOperand &op : node->operands()) {
+    std::optional<int32_t> mdValue = parseIntegerMD(op.get());
+    if (!mdValue)
+      return {};
+    vals.push_back(*mdValue);
+  }
+  return builder.getDenseI32ArrayAttr(vals);
+}
+
+/// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
+static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *node) {
+  if (!node || node->getNumOperands() != 1)
+    return {};
+  std::optional<int32_t> val = parseIntegerMD(node->getOperand(0));
+  if (!val)
+    return {};
+  return builder.getI32IntegerAttr(*val);
+}
+
+static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
+                                        Operation *op,
+                                        LLVM::ModuleImport &moduleImport) {
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  VecTypeHintAttr attr = convertVecTypeHint(builder, node, moduleImport);
+  if (!attr)
+    return failure();
+
+  funcOp.setVecTypeHintAttr(attr);
+  return success();
+}
+
+static LogicalResult
+setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
+  if (!attr)
+    return failure();
+
+  funcOp.setWorkGroupSizeHintAttr(attr);
+  return success();
+}
+
+static LogicalResult
+setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
+  if (!attr)
+    return failure();
+
+  funcOp.setReqdWorkGroupSizeAttr(attr);
+  return success();
+}
+
+/// Converts the given intel required subgroup size metadata node to an MLIR
+/// attribute and attaches it to the imported operation if the translation
+/// succeeds. Returns failure otherwise.
+static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
+                                                  llvm::MDNode *node,
+                                                  Operation *op) {
+  auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
+  if (!funcOp)
+    return failure();
+
+  IntegerAttr attr = convertIntegerMD(builder, node);
+  if (!attr)
+    return failure();
+
+  funcOp.setIntelReqdSubGroupSizeAttr(attr);
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -261,6 +396,16 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
     if (kind == llvm::LLVMContext::MD_noalias)
       return setNoaliasScopesAttr(node, op, moduleImport);
 
+    llvm::LLVMContext &context = node->getContext();
+    if (kind == context.getMDKindID(vecTypeHintMDName))
+      return setVecTypeHintAttr(builder, node, op, moduleImport);
+    if (kind == context.getMDKindID(workGroupSizeHintMDName))
+      return setWorkGroupSizeHintAttr(builder, node, op);
+    if (kind == context.getMDKindID(reqdWorkGroupSizeMDName))
+      return setReqdWorkGroupSizeAttr(builder, node, op);
+    if (kind == context.getMDKindID(intelReqdSubGroupSizeMDName))
+      return setIntelReqdSubGroupSizeAttr(builder, node, op);
+
     // A handler for a supported metadata kind is missing.
     llvm_unreachable("unknown metadata type");
   }
@@ -273,8 +418,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
 
   /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
   /// LLVM dialect attributes.
-  ArrayRef<unsigned> getSupportedMetadata() const final {
-    return getSupportedMetadataImpl();
+  ArrayRef<unsigned>
+  getSupportedMetadata(llvm::LLVMContext &context) const final {
+    return getSupportedMetadataImpl(context);
   }
 };
 } // namespace

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 8b40b7b2df6c7..9e45cf352940b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -500,10 +500,10 @@ LogicalResult ModuleImport::convertLinkerOptionsMetadata() {
     if (named.getName() != "llvm.linker.options")
       continue;
     // llvm.linker.options operands are lists of strings.
-    for (const llvm::MDNode *md : named.operands()) {
+    for (const llvm::MDNode *node : named.operands()) {
       SmallVector<StringRef> options;
-      options.reserve(md->getNumOperands());
-      for (const llvm::MDOperand &option : md->operands())
+      options.reserve(node->getNumOperands());
+      for (const llvm::MDOperand &option : node->operands())
         options.push_back(cast<llvm::MDString>(option)->getString());
       builder.create<LLVM::LinkerOptionsOp>(mlirModule.getLoc(),
                                             builder.getStrArrayAttr(options));

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3016d1846e00f..b468228ea78b7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1247,6 +1247,44 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc,
   return success();
 }
 
+/// Return a representation of `value` as metadata.
+static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context,
+                                                const llvm::APInt &value) {
+  llvm::Constant *constant = llvm::ConstantInt::get(context, value);
+  return llvm::ConstantAsMetadata::get(constant);
+}
+
+/// Return a representation of `value` as an MDNode.
+static llvm::MDNode *convertIntegerToMDNode(llvm::LLVMContext &context,
+                                            const llvm::APInt &value) {
+  return llvm::MDNode::get(context, convertIntegerToMetadata(context, value));
+}
+
+/// Return an MDNode encoding `vec_type_hint` metadata: an `undef` value of
+/// `type` followed by an i32 flag that is 1 if `isSigned` and 0 otherwise.
+static llvm::MDNode *convertVecTypeHintToMDNode(llvm::LLVMContext &context,
+                                                llvm::Type *type,
+                                                bool isSigned) {
+  llvm::Metadata *typeMD =
+      llvm::ConstantAsMetadata::get(llvm::UndefValue::get(type));
+  llvm::Metadata *isSignedMD =
+      convertIntegerToMetadata(context, llvm::APInt(32, isSigned ? 1 : 0));
+  return llvm::MDNode::get(context, {typeMD, isSignedMD});
+}
+
+/// Return an MDNode with a tuple of i32 constants given by `values`. APInts
+/// are built with `isSigned` set so that negative entries are treated as
+/// signed 32-bit values rather than as out-of-range unsigned ones.
+static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context,
+                                                 ArrayRef<int32_t> values) {
+  SmallVector<llvm::Metadata *> mdValues;
+  mdValues.reserve(values.size());
+  for (int32_t value : values)
+    mdValues.push_back(convertIntegerToMetadata(
+        context, llvm::APInt(32, value, /*isSigned=*/true)));
+  return llvm::MDNode::get(context, mdValues);
+}
+
 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
 /// to be an array attribute containing either string attributes, treated as
@@ -1448,6 +1486,46 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   convertFunctionMemoryAttributes(func, llvmFunc);
 }
 
+/// Converts the kernel attributes of `func` (vec_type_hint,
+/// work_group_size_hint, reqd_work_group_size and intel_reqd_sub_group_size)
+/// to metadata and attaches them to `llvmFunc`.
+static void convertFunctionKernelAttributes(LLVMFuncOp func,
+                                            llvm::Function *llvmFunc,
+                                            ModuleTranslation &translation) {
+  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+
+  if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) {
+    Type type = vecTypeHint.getHint().getValue();
+    llvm::Type *llvmType = translation.convertType(type);
+    bool isSigned = vecTypeHint.getIsSigned();
+    llvmFunc->setMetadata(
+        func.getVecTypeHintAttrName(),
+        convertVecTypeHintToMDNode(llvmContext, llvmType, isSigned));
+  }
+
+  if (std::optional<ArrayRef<int32_t>> workGroupSizeHint =
+          func.getWorkGroupSizeHint()) {
+    llvmFunc->setMetadata(
+        func.getWorkGroupSizeHintAttrName(),
+        convertIntegerArrayToMDNode(llvmContext, *workGroupSizeHint));
+  }
+
+  if (std::optional<ArrayRef<int32_t>> reqdWorkGroupSize =
+          func.getReqdWorkGroupSize()) {
+    llvmFunc->setMetadata(
+        func.getReqdWorkGroupSizeAttrName(),
+        convertIntegerArrayToMDNode(llvmContext, *reqdWorkGroupSize));
+  }
+
+  if (std::optional<uint32_t> intelReqdSubGroupSize =
+          func.getIntelReqdSubGroupSize()) {
+    llvmFunc->setMetadata(
+        func.getIntelReqdSubGroupSizeAttrName(),
+        convertIntegerToMDNode(llvmContext,
+                               llvm::APInt(32, *intelReqdSubGroupSize)));
+  }
+}
+
 FailureOr<llvm::AttrBuilder>
 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
                                          DictionaryAttr paramAttrs) {
@@ -1492,6 +1570,9 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert function attributes.
     convertFunctionAttributes(function, llvmFunc);
 
+    // Convert function kernel attributes to metadata.
+    convertFunctionKernelAttributes(function, llvmFunc, *this);
+
     // Convert function_entry_count attribute to metadata.
     if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
       llvmFunc->setEntryCount(entryCount.value());

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 0e29a548de72f..40b4e49f08a3e 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 // RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-op-generic %s | FileCheck %s --check-prefix=GENERIC
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s -mlir-print-debuginfo | mlir-opt -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
+// RUN: mlir-opt -split-input-file -verify-diagnostics -mlir-print-debuginfo %s | mlir-opt -split-input-file -mlir-print-debuginfo | FileCheck %s --check-prefix=LOCINFO
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-LLVM
 
 module {
@@ -432,3 +432,43 @@ module {
   // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
 }
+
+// -----
+
+// CHECK: @vec_type_hint()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32>
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: @vec_type_hint_signed()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: @vec_type_hint_signed_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: @vec_type_hint_float_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: @vec_type_hint_bfloat_vec()
+// CHECK-SAME: vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// -----
+
+// CHECK: @work_group_size_hint()
+// CHECK-SAME: work_group_size_hint = array<i32: 128, 128, 128>
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// -----
+
+// CHECK: @reqd_work_group_size_hint()
+// CHECK-SAME: reqd_work_group_size = array<i32: 128, 256, 128>
+llvm.func @reqd_work_group_size_hint() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// -----
+
+// CHECK: @intel_reqd_sub_group_size_hint()
+// CHECK-SAME: attributes {intel_reqd_sub_group_size = 32 : i32}
+llvm.func @intel_reqd_sub_group_size_hint() attributes {intel_reqd_sub_group_size = 32 : i32}

diff  --git a/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
new file mode 100644
index 0000000000000..963e40348c795
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-kernel.ll
@@ -0,0 +1,34 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK:   llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+declare !vec_type_hint !0 void @vec_type_hint()
+
+; CHECK:   llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+declare !vec_type_hint !1 void @vec_type_hint_signed()
+
+; CHECK:   llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+declare !vec_type_hint !2 void @vec_type_hint_signed_vec()
+
+; CHECK:   llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+declare !vec_type_hint !3 void @vec_type_hint_float_vec()
+
+; CHECK:   llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+declare !vec_type_hint !4 void @vec_type_hint_bfloat_vec()
+
+; CHECK:   llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+declare !work_group_size_hint !5 void @work_group_size_hint()
+
+; CHECK:   llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+declare !reqd_work_group_size !6 void @reqd_work_group_size()
+
+; CHECK:   llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+declare !intel_reqd_sub_group_size !7 void @intel_reqd_sub_group_size()
+
+!0 = !{i32 undef, i32 0}
+!1 = !{i32 undef, i32 1}
+!2 = !{<2 x i32> undef, i32 1}
+!3 = !{<3 x float> undef, i32 0}
+!4 = !{<8 x bfloat> undef, i32 0}
+!5 = !{i32 128, i32 128, i32 128}
+!6 = !{i32 128, i32 256, i32 128}
+!7 = !{i32 32}

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 82256f753abdd..fbdf725f3ec17 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2552,3 +2552,47 @@ llvm.func @mem_effects_call() {
 // CHECK-SAME: memory(read, inaccessiblemem: write)
 // CHECK: #[[ATTRS_3]]
 // CHECK-SAME: memory(readwrite, argmem: read)
+
+// -----
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT:]] void @vec_type_hint()
+llvm.func @vec_type_hint() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED:]] void @vec_type_hint_signed()
+llvm.func @vec_type_hint_signed() attributes {vec_type_hint = #llvm.vec_type_hint<hint = i32, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_SIGNED_VEC:]] void @vec_type_hint_signed_vec()
+llvm.func @vec_type_hint_signed_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<2xi32>, is_signed = true>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_FVEC:]] void @vec_type_hint_float_vec()
+llvm.func @vec_type_hint_float_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<3xf32>>}
+
+// CHECK: declare !vec_type_hint ![[#VEC_TYPE_HINT_BFVEC:]] void @vec_type_hint_bfloat_vec()
+llvm.func @vec_type_hint_bfloat_vec() attributes {vec_type_hint = #llvm.vec_type_hint<hint = vector<8xbf16>>}
+
+// CHECK: ![[#VEC_TYPE_HINT]] = !{i32 undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED]] = !{i32 undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_SIGNED_VEC]] = !{<2 x i32> undef, i32 1}
+// CHECK: ![[#VEC_TYPE_HINT_FVEC]] = !{<3 x float> undef, i32 0}
+// CHECK: ![[#VEC_TYPE_HINT_BFVEC]] = !{<8 x bfloat> undef, i32 0}
+
+// -----
+
+// CHECK: declare !work_group_size_hint ![[#WORK_GROUP_SIZE_HINT:]] void @work_group_size_hint()
+llvm.func @work_group_size_hint() attributes {work_group_size_hint = array<i32: 128, 128, 128>}
+
+// CHECK: ![[#WORK_GROUP_SIZE_HINT]] = !{i32 128, i32 128, i32 128}
+
+// -----
+
+// CHECK: declare !reqd_work_group_size ![[#REQD_WORK_GROUP_SIZE:]] void @reqd_work_group_size()
+llvm.func @reqd_work_group_size() attributes {reqd_work_group_size = array<i32: 128, 256, 128>}
+
+// CHECK: ![[#REQD_WORK_GROUP_SIZE]] = !{i32 128, i32 256, i32 128}
+
+// -----
+
+// CHECK: declare !intel_reqd_sub_group_size ![[#INTEL_REQD_SUB_GROUP_SIZE:]] void @intel_reqd_sub_group_size()
+llvm.func @intel_reqd_sub_group_size() attributes {intel_reqd_sub_group_size = 32 : i32}
+
+// CHECK: ![[#INTEL_REQD_SUB_GROUP_SIZE]] = !{i32 32}


        


More information about the Mlir-commits mailing list