[Mlir-commits] [mlir] [mlir][llvm] Add llvm.target_features features attribute (PR #71510)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Nov 28 08:46:31 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/71510

>From 0501976205156fba01f07d8f984568a61c29fb51 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 Nov 2023 10:38:52 +0000
Subject: [PATCH 1/3] [mlir][llvm] Add llvm.target_features features attribute

This patch adds a target_features attribute (TargetFeaturesAttr) to the
LLVM dialect to allow setting and querying the features in use on a function.

The features are stored as a sorted list rather than a plain string to
allow efficiently checking a function's features.

The motivation for this comes from the Arm SME dialect where we would
like a convenient way to check what variants of an operation are
available based on the CPU features.

Intended usage:

The target_features attribute is populated manually or by a pass:

```mlir
func.func @example() attributes {
   target_features = #llvm.target_features<"+sme,+sve,+sme-f64f64">
} {
 // ...
}
```

Then within a later rewrite the attribute can be checked, and used to
make lowering decisions.

```c++
// Finds the "target_features" attribute on the parent
// FunctionOpInterface.
auto targetFeatures = LLVM::TargetFeaturesAttr::featuresAt(op);

// Check a feature.
// Returns false if targetFeatures is null or the feature is not in
// the list.
if (!targetFeatures.contains("+sme-f64f64"))
    return failure();
```

For now, this is rather simple: it just checks whether the exact feature is
in the list, though it could be extended to handle implied features using
information from LLVM.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 53 ++++++++++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  | 25 +++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 69 +++++++++++++++++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  6 ++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  3 +
 .../Target/LLVMIR/Import/target-features.ll   |  9 +++
 mlir/test/Target/LLVMIR/target-features.mlir  |  7 ++
 8 files changed, 174 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/target-features.ll
 create mode 100644 mlir/test/Target/LLVMIR/target-features.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 9e35bf1ba977725..651e055b00a1bb3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -871,4 +871,57 @@ def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
     "IntegerAttr":$maxRange);
   let assemblyFormat = "`<` struct(params) `>`";
 }
+
+//===----------------------------------------------------------------------===//
+// TargetFeaturesAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features"> {
+  let summary = "LLVM target features attribute";
+
+  let description = [{
+    Represents the LLVM target features in a manner that is efficient to query.
+
+    Example:
+    ```mlir
+    #llvm.target_features<"+sme,+sve,+sme-f64f64">
+    ```
+
+    Then within a pass or rewrite the features active at an op can be queried:
+
+    ```c++
+    auto targetFeatures = LLVM::TargetFeaturesAttr::featuresAt(op);
+
+    if (!targetFeatures.contains("+sme-f64f64"))
+      return failure();
+    ```
+  }];
+
+  let parameters = (ins
+    ArrayRefOfSelfAllocationParameter<"TargetFeature", "">: $features);
+
+  let builders = [
+    TypeBuilder<(ins "::llvm::ArrayRef<TargetFeature>": $features)>,
+    TypeBuilder<(ins "StringRef": $features)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Checks if a feature is contained within the features list.
+    bool contains(TargetFeature) const;
+    bool contains(StringRef feature) const {
+      return contains(TargetFeature{feature});
+    }
+
+    /// Returns the list of features as an LLVM-compatible string.
+    std::string getFeaturesString() const;
+
+    /// Finds the target features on the parent FunctionOpInterface.
+    /// Note: This assumes the attribute is called "target_features".
+    static TargetFeaturesAttr featuresAt(Operation* op);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index c370bfa2b733d65..68a3d1f74175bda 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -74,6 +74,31 @@ class TBAANodeAttr : public Attribute {
   }
 };
 
+/// This struct represents a single LLVM target feature.
+struct TargetFeature {
+  StringRef feature;
+
+  // Support allocating this struct into MLIR storage to ensure the feature
+  // string remains valid.
+  TargetFeature allocateInto(TypeStorageAllocator &alloc) const {
+    return TargetFeature{alloc.copyInto(feature)};
+  }
+
+  operator StringRef() const { return feature; }
+
+  bool operator==(TargetFeature const &other) const {
+    return feature == other.feature;
+  }
+
+  bool operator<(TargetFeature const &other) const {
+    return feature < other.feature;
+  }
+};
+
+inline llvm::hash_code hash_value(const TargetFeature &feature) {
+  return llvm::hash_value(feature.feature);
+}
+
 // Inline the LLVM generated Linkage enum and utility.
 // This is only necessary to isolate the "enum generated code" from the
 // attribute definition itself.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 88f4f81735372b9..cfaa75a9253b3b2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1391,7 +1391,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
-    OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range
+    OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range,
+    OptionalAttr<LLVM_TargetFeaturesAttr>:$target_features
   );
 
   let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index 3d45ab8fac4d705..e97c30487210cd8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -14,10 +14,13 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/BinaryFormat/Dwarf.h"
+#include <algorithm>
 #include <optional>
+#include <set>
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -109,3 +112,69 @@ bool MemoryEffectsAttr::isReadWrite() {
     return false;
   return true;
 }
+
+//===----------------------------------------------------------------------===//
+// TargetFeaturesAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TargetFeaturesAttr::parse(mlir::AsmParser &parser, mlir::Type) {
+  std::string targetFeatures;
+  if (parser.parseLess() || parser.parseString(&targetFeatures) ||
+      parser.parseGreater())
+    return {};
+  return get(parser.getContext(), targetFeatures);
+}
+
+void TargetFeaturesAttr::print(mlir::AsmPrinter &printer) const {
+  printer << "<\"";
+  llvm::interleave(
+      getFeatures(), printer,
+      [&](auto &feature) { printer << StringRef(feature); }, ",");
+  printer << "\">";
+}
+
+TargetFeaturesAttr
+TargetFeaturesAttr::get(MLIRContext *context,
+                        llvm::ArrayRef<TargetFeature> featuresRef) {
+  // Sort and de-duplicate target features.
+  std::set<TargetFeature> features(featuresRef.begin(), featuresRef.end());
+  return Base::get(context, llvm::to_vector(features));
+}
+
+TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
+                                           StringRef targetFeatures) {
+  SmallVector<StringRef> features;
+  StringRef{targetFeatures}.split(features, ',', /*MaxSplit=*/-1,
+                                  /*KeepEmpty=*/false);
+  return get(context, llvm::map_to_vector(features, [](StringRef feature) {
+               return TargetFeature{feature};
+             }));
+}
+
+bool TargetFeaturesAttr::contains(TargetFeature feature) const {
+  if (!bool(*this))
+    return false; // Allows checking null target features.
+  ArrayRef<TargetFeature> features = getFeatures();
+  // Note: The attribute getter ensures the feature list is sorted.
+  return std::binary_search(features.begin(), features.end(), feature);
+}
+
+std::string TargetFeaturesAttr::getFeaturesString() const {
+  std::string features;
+  bool first = true;
+  for (TargetFeature feature : getFeatures()) {
+    if (!first)
+      features += ",";
+    features += StringRef(feature);
+    first = false;
+  }
+  return features;
+}
+
+TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
+  auto parentFunction = op->getParentOfType<FunctionOpInterface>();
+  if (!parentFunction)
+    return {};
+  return parentFunction.getOperation()->getAttrOfType<TargetFeaturesAttr>(
+      "target_features");
+}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 9cdc1f45d38fa59..da42bf0147b2a34 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1614,6 +1614,7 @@ static constexpr std::array ExplicitAttributes{
     StringLiteral("aarch64_pstate_sm_body"),
     StringLiteral("aarch64_pstate_za_new"),
     StringLiteral("vscale_range"),
+    StringLiteral("target-features"),
 };
 
 static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
@@ -1694,6 +1695,11 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
         context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
         IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0))));
   }
+  if (llvm::Attribute attr = func->getFnAttribute("target-features");
+      attr.isStringAttribute()) {
+    funcOp.setTargetFeaturesAttr(
+        LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
+  }
 }
 
 DictionaryAttr
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 322843e65627603..4ec5dfc02d3dddf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -942,6 +942,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   if (func.getArmNewZa())
     llvmFunc->addFnAttr("aarch64_pstate_za_new");
 
+  if (auto targetFeatures = func.getTargetFeatures())
+    llvmFunc->addFnAttr("target-features", targetFeatures->getFeaturesString());
+
   if (auto attr = func.getVscaleRange())
     llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
         getLLVMContext(), attr->getMinRange().getInt(),
diff --git a/mlir/test/Target/LLVMIR/Import/target-features.ll b/mlir/test/Target/LLVMIR/Import/target-features.ll
new file mode 100644
index 000000000000000..39e9a1204d3e022
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/target-features.ll
@@ -0,0 +1,9 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: llvm.func @target_features()
+; CHECK-SAME: #llvm.target_features<"+sme,+sme-f64f64,+sve">
+define void @target_features() #0 {
+  ret void
+}
+
+attributes #0 = { "target-features"="+sme,+sme-f64f64,+sve" }
diff --git a/mlir/test/Target/LLVMIR/target-features.mlir b/mlir/test/Target/LLVMIR/target-features.mlir
new file mode 100644
index 000000000000000..02c07d27ca3cd84
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-features.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @target_features
+// CHECK: attributes #{{.*}} = { "target-features"="+sme,+sme-f64f64,+sve" }
+llvm.func @target_features() attributes { target_features = #llvm.target_features<"+sme,+sve,+sme-f64f64"> } {
+  llvm.return
+}

>From 2f2892de6fb104a498d10eeb2a5183c6287a6425 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 Nov 2023 12:21:26 +0000
Subject: [PATCH 2/3] Fixups

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 11 +++++++---
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 20 ++++++++-----------
 2 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 651e055b00a1bb3..ef912ab9fad69ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -898,11 +898,11 @@ def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features"> {
   }];
 
   let parameters = (ins
-    ArrayRefOfSelfAllocationParameter<"TargetFeature", "">: $features);
+    ArrayRefOfSelfAllocationParameter<"TargetFeature", "">:$features);
 
   let builders = [
-    TypeBuilder<(ins "::llvm::ArrayRef<TargetFeature>": $features)>,
-    TypeBuilder<(ins "StringRef": $features)>
+    TypeBuilder<(ins "::llvm::ArrayRef<TargetFeature>":$features)>,
+    TypeBuilder<(ins "::llvm::StringRef":$features)>
   ];
 
   let extraClassDeclaration = [{
@@ -918,6 +918,11 @@ def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features"> {
     /// Finds the target features on the parent FunctionOpInterface.
     /// Note: This assumes the attribute is called "target_features".
     static TargetFeaturesAttr featuresAt(Operation* op);
+
+    /// Canonical name for this attribute within MLIR.
+    static constexpr StringLiteral attributeName() {
+      return StringLiteral("target_features");
+    }
   }];
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index e97c30487210cd8..f510f949f9cfe4b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -144,8 +144,8 @@ TargetFeaturesAttr::get(MLIRContext *context,
 TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
                                            StringRef targetFeatures) {
   SmallVector<StringRef> features;
-  StringRef{targetFeatures}.split(features, ',', /*MaxSplit=*/-1,
-                                  /*KeepEmpty=*/false);
+  targetFeatures.split(features, ',', /*MaxSplit=*/-1,
+                       /*KeepEmpty=*/false);
   return get(context, llvm::map_to_vector(features, [](StringRef feature) {
                return TargetFeature{feature};
              }));
@@ -160,15 +160,11 @@ bool TargetFeaturesAttr::contains(TargetFeature feature) const {
 }
 
 std::string TargetFeaturesAttr::getFeaturesString() const {
-  std::string features;
-  bool first = true;
-  for (TargetFeature feature : getFeatures()) {
-    if (!first)
-      features += ",";
-    features += StringRef(feature);
-    first = false;
-  }
-  return features;
+  std::string featuresString;
+  llvm::raw_string_ostream ss(featuresString);
+  llvm::interleave(
+      getFeatures(), ss, [&](auto &feature) { ss << StringRef(feature); }, ",");
+  return ss.str();
 }
 
 TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
@@ -176,5 +172,5 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
   if (!parentFunction)
     return {};
   return parentFunction.getOperation()->getAttrOfType<TargetFeaturesAttr>(
-      "target_features");
+      attributeName());
 }

>From 17b30172ff2886f2cdbf8b94d135b95a511ec3ce Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 8 Nov 2023 10:28:53 +0000
Subject: [PATCH 3/3] Switch to using a list of StringAttr

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 29 ++++----
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h  | 25 -------
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      | 68 +++++++++----------
 .../Target/LLVMIR/Import/target-features.ll   |  2 +-
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir   | 21 ++++++
 mlir/test/Target/LLVMIR/target-features.mlir  |  6 +-
 6 files changed, 77 insertions(+), 74 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index ef912ab9fad69ea..59542e1932fbbae 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -876,15 +876,17 @@ def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
 // TargetFeaturesAttr
 //===----------------------------------------------------------------------===//
 
-def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features"> {
+def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features">
+{
   let summary = "LLVM target features attribute";
 
   let description = [{
-    Represents the LLVM target features in a manner that is efficient to query.
+    Represents the LLVM target features as a list that can be checked within
+    passes/rewrites.
 
     Example:
     ```mlir
-    #llvm.target_features<"+sme,+sve,+sme-f64f64">
+    #llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
     ```
 
     Then within a pass or rewrite the features active at an op can be queried:
@@ -897,19 +899,22 @@ def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features"> {
     ```
   }];
 
-  let parameters = (ins
-    ArrayRefOfSelfAllocationParameter<"TargetFeature", "">:$features);
+  let parameters = (ins OptionalArrayRefParameter<"StringAttr">:$features);
 
   let builders = [
-    TypeBuilder<(ins "::llvm::ArrayRef<TargetFeature>":$features)>,
-    TypeBuilder<(ins "::llvm::StringRef":$features)>
+    TypeBuilder<(ins "::llvm::StringRef":$features)>,
+    TypeBuilder<(ins "::llvm::ArrayRef<::llvm::StringRef>":$features)>
   ];
 
   let extraClassDeclaration = [{
     /// Checks if a feature is contained within the features list.
-    bool contains(TargetFeature) const;
-    bool contains(StringRef feature) const {
-      return contains(TargetFeature{feature});
+    /// Note: Using a StringAttr allows doing pointer-comparisons.
+    bool contains(::mlir::StringAttr feature) const;
+    bool contains(::llvm::StringRef feature) const;
+
+    bool nullOrEmpty() const {
+      // Checks if this attribute is falsy (null), or the features are empty.
+      return !bool(*this) || getFeatures().empty();
     }
 
     /// Returns the list of features as an LLVM-compatible string.
@@ -925,8 +930,8 @@ def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features"> {
     }
   }];
 
-  let hasCustomAssemblyFormat = 1;
-  let skipDefaultBuilders = 1;
+  let assemblyFormat = "`<` `[` (`]`) : ($features^ `]`)? `>`";
+  let genVerifyDecl = 1;
 }
 
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
index 68a3d1f74175bda..c370bfa2b733d65 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h
@@ -74,31 +74,6 @@ class TBAANodeAttr : public Attribute {
   }
 };
 
-/// This struct represents a single LLVM target feature.
-struct TargetFeature {
-  StringRef feature;
-
-  // Support allocating this struct into MLIR storage to ensure the feature
-  // string remains valid.
-  TargetFeature allocateInto(TypeStorageAllocator &alloc) const {
-    return TargetFeature{alloc.copyInto(feature)};
-  }
-
-  operator StringRef() const { return feature; }
-
-  bool operator==(TargetFeature const &other) const {
-    return feature == other.feature;
-  }
-
-  bool operator<(TargetFeature const &other) const {
-    return feature < other.feature;
-  }
-};
-
-inline llvm::hash_code hash_value(const TargetFeature &feature) {
-  return llvm::hash_value(feature.feature);
-}
-
 // Inline the LLVM generated Linkage enum and utility.
 // This is only necessary to isolate the "enum generated code" from the
 // attribute definition itself.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index f510f949f9cfe4b..eda28e616e94a60 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -18,9 +18,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/BinaryFormat/Dwarf.h"
-#include <algorithm>
 #include <optional>
-#include <set>
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -117,28 +115,12 @@ bool MemoryEffectsAttr::isReadWrite() {
 // TargetFeaturesAttr
 //===----------------------------------------------------------------------===//
 
-Attribute TargetFeaturesAttr::parse(mlir::AsmParser &parser, mlir::Type) {
-  std::string targetFeatures;
-  if (parser.parseLess() || parser.parseString(&targetFeatures) ||
-      parser.parseGreater())
-    return {};
-  return get(parser.getContext(), targetFeatures);
-}
-
-void TargetFeaturesAttr::print(mlir::AsmPrinter &printer) const {
-  printer << "<\"";
-  llvm::interleave(
-      getFeatures(), printer,
-      [&](auto &feature) { printer << StringRef(feature); }, ",");
-  printer << "\">";
-}
-
-TargetFeaturesAttr
-TargetFeaturesAttr::get(MLIRContext *context,
-                        llvm::ArrayRef<TargetFeature> featuresRef) {
-  // Sort and de-duplicate target features.
-  std::set<TargetFeature> features(featuresRef.begin(), featuresRef.end());
-  return Base::get(context, llvm::to_vector(features));
+TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
+                                           llvm::ArrayRef<StringRef> features) {
+  return Base::get(context,
+                   llvm::map_to_vector(features, [&](StringRef feature) {
+                     return StringAttr::get(context, feature);
+                   }));
 }
 
 TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
@@ -146,24 +128,42 @@ TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
   SmallVector<StringRef> features;
   targetFeatures.split(features, ',', /*MaxSplit=*/-1,
                        /*KeepEmpty=*/false);
-  return get(context, llvm::map_to_vector(features, [](StringRef feature) {
-               return TargetFeature{feature};
-             }));
+  return get(context, features);
 }
 
-bool TargetFeaturesAttr::contains(TargetFeature feature) const {
-  if (!bool(*this))
-    return false; // Allows checking null target features.
-  ArrayRef<TargetFeature> features = getFeatures();
-  // Note: The attribute getter ensures the feature list is sorted.
-  return std::binary_search(features.begin(), features.end(), feature);
+LogicalResult
+TargetFeaturesAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                           llvm::ArrayRef<StringAttr> features) {
+  for (StringAttr featureAttr : features) {
+    if (!featureAttr || featureAttr.empty())
+      return emitError() << "target features can not be null or empty";
+    auto feature = featureAttr.strref();
+    if (feature[0] != '+' && feature[0] != '-')
+      return emitError() << "target features must start with '+' or '-'";
+    if (feature.contains(','))
+      return emitError() << "target features can not contain ','";
+  }
+  return success();
+}
+
+bool TargetFeaturesAttr::contains(StringAttr feature) const {
+  if (nullOrEmpty())
+    return false;
+  // Note: Using StringAttr does pointer comparisons.
+  return llvm::is_contained(getFeatures(), feature);
+}
+
+bool TargetFeaturesAttr::contains(StringRef feature) const {
+  if (nullOrEmpty())
+    return false;
+  return llvm::is_contained(getFeatures(), feature);
 }
 
 std::string TargetFeaturesAttr::getFeaturesString() const {
   std::string featuresString;
   llvm::raw_string_ostream ss(featuresString);
   llvm::interleave(
-      getFeatures(), ss, [&](auto &feature) { ss << StringRef(feature); }, ",");
+      getFeatures(), ss, [&](auto &feature) { ss << feature.strref(); }, ",");
   return ss.str();
 }
 
diff --git a/mlir/test/Target/LLVMIR/Import/target-features.ll b/mlir/test/Target/LLVMIR/Import/target-features.ll
index 39e9a1204d3e022..d3feeda85691a26 100644
--- a/mlir/test/Target/LLVMIR/Import/target-features.ll
+++ b/mlir/test/Target/LLVMIR/Import/target-features.ll
@@ -1,7 +1,7 @@
 ; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
 
 ; CHECK-LABEL: llvm.func @target_features()
-; CHECK-SAME: #llvm.target_features<"+sme,+sme-f64f64,+sve">
+; CHECK-SAME: #llvm.target_features<["+sme", "+sme-f64f64", "+sve"]>
 define void @target_features() #0 {
   ret void
 }
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index e531c3cb4e24093..7fba35d8eb81c7b 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -245,6 +245,27 @@ llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> {
 
 // -----
 
+// expected-error @below{{target features can not contain ','}}
+llvm.func @invalid_target_feature() attributes { target_features = #llvm.target_features<["+bad,feature", "+test"]> }
+{
+}
+
+// -----
+
+// expected-error @below{{target features must start with '+' or '-'}}
+llvm.func @missing_target_feature_prefix() attributes { target_features = #llvm.target_features<["sme"]> }
+{
+}
+
+// -----
+
+// expected-error @below{{target features can not be null or empty}}
+llvm.func @empty_target_feature() attributes { target_features = #llvm.target_features<["", "+sve"]> }
+{
+}
+
+// -----
+
 llvm.comdat @__llvm_comdat {
   llvm.comdat_selector @foo any
 }
diff --git a/mlir/test/Target/LLVMIR/target-features.mlir b/mlir/test/Target/LLVMIR/target-features.mlir
index 02c07d27ca3cd84..7a69a2c78897809 100644
--- a/mlir/test/Target/LLVMIR/target-features.mlir
+++ b/mlir/test/Target/LLVMIR/target-features.mlir
@@ -1,7 +1,9 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 // CHECK-LABEL: define void @target_features
-// CHECK: attributes #{{.*}} = { "target-features"="+sme,+sme-f64f64,+sve" }
-llvm.func @target_features() attributes { target_features = #llvm.target_features<"+sme,+sve,+sme-f64f64"> } {
+// CHECK: attributes #{{.*}} = { "target-features"="+sme,+sve,+sme-f64f64" }
+llvm.func @target_features() attributes {
+  target_features = #llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
+} {
   llvm.return
 }



More information about the Mlir-commits mailing list