[Mlir-commits] [mlir] 17de468 - [mlir][llvm] Add llvm.target_features features attribute (#71510)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 5 03:29:35 PST 2023


Author: Benjamin Maxwell
Date: 2023-12-05T11:29:31Z
New Revision: 17de468df1af6479f31bb8c02973e01702f7b240

URL: https://github.com/llvm/llvm-project/commit/17de468df1af6479f31bb8c02973e01702f7b240
DIFF: https://github.com/llvm/llvm-project/commit/17de468df1af6479f31bb8c02973e01702f7b240.diff

LOG: [mlir][llvm] Add llvm.target_features features attribute (#71510)

This patch adds a target_features attribute (TargetFeaturesAttr) to the
LLVM dialect to allow setting and querying the target features in use on
a function.

The motivation for this comes from the Arm SME dialect where we would
like a convenient way to check what variants of an operation are
available based on the CPU features.

Intended usage:

The target_features attribute is populated manually or by a pass:

```mlir
func.func @example() attributes {
   target_features = #llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
} {
 // ...
}
```

Then within a later rewrite the attribute can be checked, and used to
make lowering decisions.

```c++
// Finds the "target_features" attribute on the parent
// FunctionOpInterface.
auto targetFeatures = LLVM::TargetFeaturesAttr::featuresAt(op);

// Check a feature.
// Returns false if targetFeatures is null or the feature is not in
// the list.
if (!targetFeatures.contains("+sme-f64f64"))
    return failure();
```

For now, this is rather simple: it just checks whether the exact feature
string is in the list, though it could be extended to handle implied
features using information from LLVM.

Added: 
    mlir/test/Target/LLVMIR/Import/target-features.ll
    mlir/test/Target/LLVMIR/target-features.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index acbf88807b10c..6975b18ab7f81 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -933,4 +933,68 @@ def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
     "IntegerAttr":$maxRange);
   let assemblyFormat = "`<` struct(params) `>`";
 }
+
+//===----------------------------------------------------------------------===//
+// TargetFeaturesAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features">
+{
+  let summary = "LLVM target features attribute";
+
+  let description = [{
+    Represents the LLVM target features as a list that can be checked within
+    passes/rewrites.
+
+    Example:
+    ```mlir
+    #llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
+    ```
+
+    Then within a pass or rewrite the features active at an op can be queried:
+
+    ```c++
+    auto targetFeatures = LLVM::TargetFeaturesAttr::featuresAt(op);
+
+    if (!targetFeatures.contains("+sme-f64f64"))
+      return failure();
+    ```
+  }];
+
+  let parameters = (ins OptionalArrayRefParameter<"StringAttr">:$features);
+
+  let builders = [
+    TypeBuilder<(ins "::llvm::StringRef":$features)>,
+    TypeBuilder<(ins "::llvm::ArrayRef<::llvm::StringRef>":$features)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Checks if a feature is contained within the features list.
+    /// Note: Using a StringAttr allows doing pointer-comparisons.
+    bool contains(::mlir::StringAttr feature) const;
+    bool contains(::llvm::StringRef feature) const;
+
+    bool nullOrEmpty() const {
+      // Checks if this attribute is null, or the features are empty.
+      return !bool(*this) || getFeatures().empty();
+    }
+
+    /// Returns the list of features as an LLVM-compatible string.
+    std::string getFeaturesString() const;
+
+    /// Finds the target features on the parent FunctionOpInterface.
+    /// Note: This assumes the attribute name matches the return value of
+    /// `getAttributeName()`.
+    static TargetFeaturesAttr featuresAt(Operation* op);
+
+    /// Canonical name for this attribute within MLIR.
+    static constexpr StringLiteral getAttributeName() {
+      return StringLiteral("target_features");
+    }
+  }];
+
+  let assemblyFormat = "`<` `[` (`]`) : ($features^ `]`)? `>`";
+  let genVerifyDecl = 1;
+}
+
 #endif // LLVMIR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 83e71f071dec0..92460fa06f530 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1394,7 +1394,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
     OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range,
-    OptionalAttr<FramePointerKindAttr>:$frame_pointer
+    OptionalAttr<FramePointerKindAttr>:$frame_pointer,
+    OptionalAttr<LLVM_TargetFeaturesAttr>:$target_features
   );
 
   let regions = (region AnyRegion:$body);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index e2342670508ce..645a45dd96bef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/BinaryFormat/Dwarf.h"
@@ -183,3 +184,67 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
     i++;
   });
 }
+
+//===----------------------------------------------------------------------===//
+// TargetFeaturesAttr
+//===----------------------------------------------------------------------===//
+
+TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
+                                           llvm::ArrayRef<StringRef> features) {
+  return Base::get(context,
+                   llvm::map_to_vector(features, [&](StringRef feature) {
+                     return StringAttr::get(context, feature);
+                   }));
+}
+
+TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
+                                           StringRef targetFeatures) {
+  SmallVector<StringRef> features;
+  targetFeatures.split(features, ',', /*MaxSplit=*/-1,
+                       /*KeepEmpty=*/false);
+  return get(context, features);
+}
+
+LogicalResult
+TargetFeaturesAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                           llvm::ArrayRef<StringAttr> features) {
+  for (StringAttr featureAttr : features) {
+    if (!featureAttr || featureAttr.empty())
+      return emitError() << "target features can not be null or empty";
+    auto feature = featureAttr.strref();
+    if (feature[0] != '+' && feature[0] != '-')
+      return emitError() << "target features must start with '+' or '-'";
+    if (feature.contains(','))
+      return emitError() << "target features can not contain ','";
+  }
+  return success();
+}
+
+bool TargetFeaturesAttr::contains(StringAttr feature) const {
+  if (nullOrEmpty())
+    return false;
+  // Note: Using StringAttr does pointer comparisons.
+  return llvm::is_contained(getFeatures(), feature);
+}
+
+bool TargetFeaturesAttr::contains(StringRef feature) const {
+  if (nullOrEmpty())
+    return false;
+  return llvm::is_contained(getFeatures(), feature);
+}
+
+std::string TargetFeaturesAttr::getFeaturesString() const {
+  std::string featuresString;
+  llvm::raw_string_ostream ss(featuresString);
+  llvm::interleave(
+      getFeatures(), ss, [&](auto &feature) { ss << feature.strref(); }, ",");
+  return ss.str();
+}
+
+TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
+  auto parentFunction = op->getParentOfType<FunctionOpInterface>();
+  if (!parentFunction)
+    return {};
+  return parentFunction.getOperation()->getAttrOfType<TargetFeaturesAttr>(
+      getAttributeName());
+}

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index a7548ea9b9849..4bdffa572e31a 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1627,6 +1627,7 @@ static constexpr std::array ExplicitAttributes{
     StringLiteral("aarch64_pstate_za_new"),
     StringLiteral("vscale_range"),
     StringLiteral("frame-pointer"),
+    StringLiteral("target-features"),
 };
 
 static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
@@ -1717,6 +1718,12 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
                                  stringRefFramePointerKind)
                                  .value()));
   }
+
+  if (llvm::Attribute attr = func->getFnAttribute("target-features");
+      attr.isStringAttribute()) {
+    funcOp.setTargetFeaturesAttr(
+        LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
+  }
 }
 
 DictionaryAttr

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 5e94fa6acf7b9..d6afe354178d6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -968,6 +968,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   if (func.getArmNewZa())
     llvmFunc->addFnAttr("aarch64_pstate_za_new");
 
+  if (auto targetFeatures = func.getTargetFeatures())
+    llvmFunc->addFnAttr("target-features", targetFeatures->getFeaturesString());
+
   if (auto attr = func.getVscaleRange())
     llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
         getLLVMContext(), attr->getMinRange().getInt(),

diff  --git a/mlir/test/Target/LLVMIR/Import/target-features.ll b/mlir/test/Target/LLVMIR/Import/target-features.ll
new file mode 100644
index 0000000000000..d3feeda85691a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/target-features.ll
@@ -0,0 +1,9 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: llvm.func @target_features()
+; CHECK-SAME: #llvm.target_features<["+sme", "+sme-f64f64", "+sve"]>
+define void @target_features() #0 {
+  ret void
+}
+
+attributes #0 = { "target-features"="+sme,+sme-f64f64,+sve" }

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 0def5895fb330..521d94c45890a 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -261,6 +261,27 @@ llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> {
 
 // -----
 
+// expected-error @below{{target features can not contain ','}}
+llvm.func @invalid_target_feature() attributes { target_features = #llvm.target_features<["+bad,feature", "+test"]> }
+{
+}
+
+// -----
+
+// expected-error @below{{target features must start with '+' or '-'}}
+llvm.func @missing_target_feature_prefix() attributes { target_features = #llvm.target_features<["sme"]> }
+{
+}
+
+// -----
+
+// expected-error @below{{target features can not be null or empty}}
+llvm.func @empty_target_feature() attributes { target_features = #llvm.target_features<["", "+sve"]> }
+{
+}
+
+// -----
+
 llvm.comdat @__llvm_comdat {
   llvm.comdat_selector @foo any
 }

diff  --git a/mlir/test/Target/LLVMIR/target-features.mlir b/mlir/test/Target/LLVMIR/target-features.mlir
new file mode 100644
index 0000000000000..7a69a2c788978
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-features.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @target_features
+// CHECK: attributes #{{.*}} = { "target-features"="+sme,+sve,+sme-f64f64" }
+llvm.func @target_features() attributes {
+  target_features = #llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
+} {
+  llvm.return
+}


        


More information about the Mlir-commits mailing list