[Mlir-commits] [mlir] Add support for MLIR to llvm vscale attribute (PR #67012)

Mats Petersson llvmlistbot at llvm.org
Thu Sep 21 06:13:39 PDT 2023


https://github.com/Leporacanthicus created https://github.com/llvm/llvm-project/pull/67012

The vscale_range attribute is used by scalable vector functionality, such as the Arm Scalable Vector Extension (SVE), to select the size of vector operations (and I think RISC-V has something similar).

This patch adds the base support for the vscale_range attribute to the LLVM::FuncOp, and the marshalling for translation to LLVM-IR and import from LLVM-IR to LLVM dialect.

This attribute is intended to be used at higher level MLIR, specified either by command-line options to the compiler or using compiler directives (e.g. pragmas or function attributes in the source code) to indicate the desired range.

>From 03763717bcf879cf4ead297ead070259ee6cdde1 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 21 Sep 2023 12:37:49 +0100
Subject: [PATCH] [MLIR] Add support for vscale_range to LLVM dialect

The vscale_range attribute is used by scalable vector functionality,
such as the Arm Scalable Vector Extension (SVE), to select the size
of vector operations (and I think RISC-V has something similar).

This patch adds the base support for the vscale_range attribute to
the LLVM::FuncOp, and the marshalling for translation to LLVM-IR
and import from LLVM-IR to LLVM dialect.

This attribute is intended to be used at higher level MLIR, specified
either by command-line options to the compiler or using compiler
directives (e.g. pragmas or function attributes in the source code)
to indicate the desired range.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  7 +++++++
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  3 ++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 21 ++++++++++++++++++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 14 ++++++++++++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  6 ++++++
 mlir/test/Target/LLVMIR/Import/vscale.ll      | 10 +++++++++
 mlir/test/Target/LLVMIR/vscale.mlir           |  9 ++++++++
 7 files changed, 67 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/Import/vscale.ll
 create mode 100644 mlir/test/Target/LLVMIR/vscale.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2443b23e42e43ce..322451334f2c5d0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -850,4 +850,11 @@ def LLVM_TBAATagArrayAttr
   let constBuilderCall = ?;
 }
 
+// VScaleRangeAttr: models LLVM's `vscale_range(min, max)` function attribute.
+// An unbounded maximum from LLVM IR is imported as maxRange = 0.
+def VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
+    let parameters =  (ins  "IntegerAttr":$minRange,
+                       "IntegerAttr":$maxRange);
+    let assemblyFormat = "`<` struct(params) `>`";
+}
 #endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c0216d1971e58d2..345e9b072737e76 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1393,7 +1393,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
-    OptionalAttr<I64Attr>:$alignment
+    OptionalAttr<I64Attr>:$alignment,
+    OptionalAttr<VScaleRangeAttr>:$vscale_range
   );
 
   let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1cf91bde28183ac..22d6d1ccc9d429f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2342,6 +2342,19 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
+  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
+    int64_t minRange, maxRange;
+    if (parser.parseLParen() || parser.parseInteger(minRange) ||
+        parser.parseComma() || parser.parseInteger(maxRange) ||
+        parser.parseRParen())
+      return failure();
+    auto intTy = mlir::IntegerType::get(parser.getContext(), 32);
+    result.addAttribute(getVscaleRangeAttrName(result.name),
+                        mlir::LLVM::VScaleRangeAttr::get(
+                            parser.getContext(),
+                            mlir::IntegerAttr::get(intTy, minRange),
+                            mlir::IntegerAttr::get(intTy, maxRange)));
+  }
   // Parse the optional comdat selector.
   if (succeeded(parser.parseOptionalKeyword("comdat"))) {
     SymbolRefAttr comdat;
@@ -2398,6 +2411,11 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
 
+  // Print vscale range if present
+  if (auto vscale = getVscaleRange())
+    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
+      << vscale->getMaxRange().getInt() << ')';
+
   // Print the optional comdat selector.
   if (auto comdat = getComdat())
     p << " comdat(" << *comdat << ')';
@@ -2406,7 +2424,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
       p, *this,
       {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
        getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
-       getComdatAttrName(), getUnnamedAddrAttrName()});
+       getComdatAttrName(), getUnnamedAddrAttrName(),
+       getVscaleRangeAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index c6c30880d4f2c15..f462a6a5f94fa45 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Comdat.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/InstIterator.h"
@@ -1569,8 +1570,9 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
 
     // Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
     // explicit attribute.
+    // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body")
+        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1610,6 +1612,16 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+  llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
+  if (attr.isValid()) {
+    MLIRContext *context = funcOp.getContext();
+    auto intTy = mlir::IntegerType::get(context, 32);
+    std::optional<unsigned> maxAttr = attr.getVScaleRangeMax();
+    unsigned maxVal = maxAttr ? *maxAttr : 0;
+    funcOp.setVscaleRangeAttr(mlir::LLVM::VScaleRangeAttr::get(
+        context, mlir::IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
+        mlir::IntegerAttr::get(intTy, maxVal)));
+  }
 }
 
 DictionaryAttr
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 33c85d85a684cc8..9e2a40c0d9ed795 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -907,6 +907,12 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (auto attr = func.getVscaleRange()) {
+    llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+        getLLVMContext(), attr->getMinRange().getInt(),
+        attr->getMaxRange().getInt()));
+  }
+
   // First, create all blocks so we can jump to them.
   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
   for (auto &bb : func) {
diff --git a/mlir/test/Target/LLVMIR/Import/vscale.ll b/mlir/test/Target/LLVMIR/Import/vscale.ll
new file mode 100644
index 000000000000000..99b85d77f6f920d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/vscale.ll
@@ -0,0 +1,10 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+define void @vscale_func() #0 {
+; CHECK: llvm.func @vscale_func()
+; CHECK-SAME: vscale_range(2, 8)
+  ret void
+}
+
+attributes #0 = { vscale_range(2,8) }
+
diff --git a/mlir/test/Target/LLVMIR/vscale.mlir b/mlir/test/Target/LLVMIR/vscale.mlir
new file mode 100644
index 000000000000000..df245964d2c4bbd
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/vscale.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @vscale_func()
+    vscale_range(2,8) {
+    // CHECK-LABEL: define void @vscale_func
+    // CHECK: attributes #{{.*}} = { vscale_range(2,8) }
+    llvm.return
+}
+



More information about the Mlir-commits mailing list