[Mlir-commits] [mlir] Add support for MLIR to llvm vscale attribute (PR #67012)

Mats Petersson llvmlistbot at llvm.org
Mon Sep 25 03:51:53 PDT 2023


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/67012

>From bf5dd93dfc94ace1bb8405f8e94ea9b1d42d91b8 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 21 Sep 2023 12:37:49 +0100
Subject: [PATCH] [MLIR] Add support for vscale_range to LLVM dialect

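The vscale_range attribute is used for scalable vector functionality,
such as the Arm Scalable Vector Extension (SVE), to specify the range
of values that vscale can take, and hence the possible sizes of
scalable vector operations (RISC-V's vector extension has a similar
concept).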

This patch adds basic support for the vscale_range attribute to
LLVM::LLVMFuncOp, along with the conversion needed to translate it to
LLVM IR and to import it from LLVM IR into the LLVM dialect.

This attribute is intended to be set by higher-level MLIR code, based
either on command-line options passed to the compiler or on compiler
directives (e.g. pragmas or function attributes in the source code)
that indicate the desired range.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 10 +++++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  3 ++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 21 ++++++++++++++++++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 11 +++++++++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  5 +++++
 mlir/test/Dialect/LLVMIR/func.mlir            |  6 ++++++
 mlir/test/Target/LLVMIR/Import/vscale.ll      |  7 +++++++
 mlir/test/Target/LLVMIR/vscale.mlir           |  7 +++++++
 8 files changed, 67 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/vscale.ll
 create mode 100644 mlir/test/Target/LLVMIR/vscale.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2443b23e42e43ce..4a9bb5355d1c7ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -850,4 +850,14 @@ def LLVM_TBAATagArrayAttr
   let constBuilderCall = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// VScaleRangeAttr - (minRange, maxRange) for the llvm.func vscale_range attr
+//===----------------------------------------------------------------------===//
+
+def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
+  let parameters =  (ins
+    "IntegerAttr":$minRange,
+    "IntegerAttr":$maxRange);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e4486eb36e51a1f..10bc8afeefa1777 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1393,7 +1393,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
-    OptionalAttr<I64Attr>:$alignment
+    OptionalAttr<I64Attr>:$alignment,
+    OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range
   );
 
   let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2e0b2582eee9629..d5f2f47bc147e6c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2342,6 +2342,19 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
+  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
+    int64_t minRange, maxRange;
+    if (parser.parseLParen() || parser.parseInteger(minRange) ||
+        parser.parseComma() || parser.parseInteger(maxRange) ||
+        parser.parseRParen())
+      return failure();
+    auto intTy = IntegerType::get(parser.getContext(), 32);
+    result.addAttribute(
+        getVscaleRangeAttrName(result.name),
+        LLVM::VScaleRangeAttr::get(parser.getContext(),
+                                   IntegerAttr::get(intTy, minRange),
+                                   IntegerAttr::get(intTy, maxRange)));
+  }
   // Parse the optional comdat selector.
   if (succeeded(parser.parseOptionalKeyword("comdat"))) {
     SymbolRefAttr comdat;
@@ -2398,6 +2411,11 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
 
+  // Print vscale range if present
+  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
+    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
+      << vscale->getMaxRange().getInt() << ')';
+
   // Print the optional comdat selector.
   if (auto comdat = getComdat())
     p << " comdat(" << *comdat << ')';
@@ -2406,7 +2424,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
       p, *this,
       {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
        getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
-       getComdatAttrName(), getUnnamedAddrAttrName()});
+       getComdatAttrName(), getUnnamedAddrAttrName(),
+       getVscaleRangeAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 3d93332c4c56796..672d81bd3fd9129 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1569,8 +1569,9 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
 
     // Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
     // explicit attribute.
+    // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body")
+        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1610,6 +1611,14 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+  llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
+  if (attr.isValid()) {
+    MLIRContext *context = funcOp.getContext();
+    auto intTy = IntegerType::get(context, 32);
+    funcOp.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get(
+        context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
+        IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0))));
+  }
 }
 
 DictionaryAttr
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 33c85d85a684cc8..ee73b04e020fd26 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -907,6 +907,11 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (auto attr = func.getVscaleRange())
+    llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+        getLLVMContext(), attr->getMinRange().getInt(),
+        attr->getMaxRange().getInt()));
+
   // First, create all blocks so we can jump to them.
   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
   for (auto &bb : func) {
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index c0da9fc40e16faa..b45e6c4ef897b10 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -225,6 +225,12 @@ module {
   llvm.func @any() comdat(@__llvm_comdat::@any) attributes { dso_local } {
     llvm.return
   }
+
+  llvm.func @vscale_roundtrip() vscale_range(1, 2) {
+    // CHECK: @vscale_roundtrip
+    // CHECK-SAME: vscale_range(1, 2)
+    llvm.return
+  }
 }
 
 // -----
diff --git a/mlir/test/Target/LLVMIR/Import/vscale.ll b/mlir/test/Target/LLVMIR/Import/vscale.ll
new file mode 100644
index 000000000000000..48a52e949bda36a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/vscale.ll
@@ -0,0 +1,7 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+define void @vscale_func() vscale_range(2,8) {
+  ; CHECK: llvm.func @vscale_func()
+  ; CHECK-SAME: vscale_range(2, 8)
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/vscale.mlir b/mlir/test/Target/LLVMIR/vscale.mlir
new file mode 100644
index 000000000000000..644cde73ca89677
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/vscale.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @vscale_func() vscale_range(2,8) {
+  // CHECK-LABEL: define void @vscale_func
+  // CHECK: attributes #{{.*}} = { vscale_range(2,8) }
+  llvm.return
+}



More information about the Mlir-commits mailing list