[Mlir-commits] [mlir] Add support for MLIR to llvm vscale attribute (PR #67012)
Mats Petersson
llvmlistbot at llvm.org
Mon Sep 25 03:51:53 PDT 2023
https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/67012
>From bf5dd93dfc94ace1bb8405f8e94ea9b1d42d91b8 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 21 Sep 2023 12:37:49 +0100
Subject: [PATCH] [MLIR] Add support for vscale_range to LLVM dialect
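The vscale_range attribute is used by scalable-vector functionality, such as
the Arm Scalable Vector Extension (SVE), to constrain the range of vscale,
the runtime multiplier that determines the size of scalable vector
operations (I believe RISC-V's vector extension has something similar).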
This patch adds base support for the vscale_range attribute to
LLVM::FuncOp, together with the marshalling for translation to LLVM IR
and for import from LLVM IR into the LLVM dialect.
This attribute is intended to be set by higher-level MLIR, driven
either by compiler command-line options or by directives in the source
code (e.g. pragmas or function attributes) that indicate the desired
range.
---
.../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 10 +++++++++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 21 ++++++++++++++++++-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 11 +++++++++-
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 5 +++++
mlir/test/Dialect/LLVMIR/func.mlir | 6 ++++++
mlir/test/Target/LLVMIR/Import/vscale.ll | 7 +++++++
mlir/test/Target/LLVMIR/vscale.mlir | 7 +++++++
8 files changed, 67 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/Import/vscale.ll
create mode 100644 mlir/test/Target/LLVMIR/vscale.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2443b23e42e43ce..4a9bb5355d1c7ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -850,4 +850,14 @@ def LLVM_TBAATagArrayAttr
let constBuilderCall = ?;
}
+//===----------------------------------------------------------------------===//
+// VScaleRangeAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> { // Mirrors LLVM IR's vscale_range(min, max) function attribute.
+ let parameters = (ins
+ "IntegerAttr":$minRange, // Lower bound on vscale; the parser and importer in this patch build it as an i32 IntegerAttr.
+ "IntegerAttr":$maxRange); // Upper bound on vscale; the importer stores 0 when the LLVM attribute has no maximum (unbounded).
+ let assemblyFormat = "`<` struct(params) `>`";
+} // NOTE(review): nothing visible here checks min > 0, min <= max (or max == 0), or that values fit in 32 bits, as LLVM IR requires - consider a verifier.
#endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e4486eb36e51a1f..10bc8afeefa1777 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1393,7 +1393,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$arm_locally_streaming,
OptionalAttr<StrAttr>:$section,
OptionalAttr<UnnamedAddr>:$unnamed_addr,
- OptionalAttr<I64Attr>:$alignment
+ OptionalAttr<I64Attr>:$alignment,
+ OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range
);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2e0b2582eee9629..d5f2f47bc147e6c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2342,6 +2342,19 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
+ if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
+ int64_t minRange, maxRange;
+ if (parser.parseLParen() || parser.parseInteger(minRange) ||
+ parser.parseComma() || parser.parseInteger(maxRange) ||
+ parser.parseRParen())
+ return failure();
+ auto intTy = IntegerType::get(parser.getContext(), 32);
+ result.addAttribute(
+ getVscaleRangeAttrName(result.name),
+ LLVM::VScaleRangeAttr::get(parser.getContext(),
+ IntegerAttr::get(intTy, minRange),
+ IntegerAttr::get(intTy, maxRange)));
+ }
// Parse the optional comdat selector.
if (succeeded(parser.parseOptionalKeyword("comdat"))) {
SymbolRefAttr comdat;
@@ -2398,6 +2411,11 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
+ // Print vscale range if present
+ if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
+ p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
+ << vscale->getMaxRange().getInt() << ')';
+
// Print the optional comdat selector.
if (auto comdat = getComdat())
p << " comdat(" << *comdat << ')';
@@ -2406,7 +2424,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
p, *this,
{getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
- getComdatAttrName(), getUnnamedAddrAttrName()});
+ getComdatAttrName(), getUnnamedAddrAttrName(),
+ getVscaleRangeAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 3d93332c4c56796..672d81bd3fd9129 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1569,8 +1569,9 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
// Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
// explicit attribute.
+ // Also skip the vscale_range, it is also an explicit attribute.
if (attrName == "aarch64_pstate_sm_enabled" ||
- attrName == "aarch64_pstate_sm_body")
+ attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
continue;
if (attr.isStringAttribute()) {
@@ -1610,6 +1611,14 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
funcOp.setArmStreaming(true);
else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
funcOp.setArmLocallyStreaming(true);
+ llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
+ if (attr.isValid()) {
+ MLIRContext *context = funcOp.getContext();
+ auto intTy = IntegerType::get(context, 32);
+ funcOp.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get(
+ context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
+ IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0))));
+ }
}
DictionaryAttr
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 33c85d85a684cc8..ee73b04e020fd26 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -907,6 +907,11 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
else if (func.getArmLocallyStreaming())
llvmFunc->addFnAttr("aarch64_pstate_sm_body");
+ if (auto attr = func.getVscaleRange())
+ llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+ getLLVMContext(), attr->getMinRange().getInt(),
+ attr->getMaxRange().getInt()));
+
// First, create all blocks so we can jump to them.
llvm::LLVMContext &llvmContext = llvmFunc->getContext();
for (auto &bb : func) {
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index c0da9fc40e16faa..b45e6c4ef897b10 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -225,6 +225,12 @@ module {
llvm.func @any() comdat(@__llvm_comdat::@any) attributes { dso_local } {
llvm.return
}
+
+ llvm.func @vscale_roundtrip() vscale_range(1, 2) { // Exercises the custom vscale_range(min, max) syntax printed after the signature.
+ // CHECK: @vscale_roundtrip
+ // CHECK-SAME: vscale_range(1, 2)
+ llvm.return
+ }
}
// -----
diff --git a/mlir/test/Target/LLVMIR/Import/vscale.ll b/mlir/test/Target/LLVMIR/Import/vscale.ll
new file mode 100644
index 000000000000000..48a52e949bda36a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/vscale.ll
@@ -0,0 +1,7 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+define void @vscale_func() vscale_range(2,8) { ; Expected to import as an llvm.func carrying the vscale_range attribute.
+ ; CHECK: llvm.func @vscale_func()
+ ; CHECK-SAME: vscale_range(2, 8)
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/vscale.mlir b/mlir/test/Target/LLVMIR/vscale.mlir
new file mode 100644
index 000000000000000..644cde73ca89677
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/vscale.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @vscale_func() vscale_range(2,8) { // Expected to export as an LLVM IR attribute group containing vscale_range.
+ // CHECK-LABEL: define void @vscale_func
+ // CHECK: attributes #{{.*}} = { vscale_range(2,8) }
+ llvm.return
+}
More information about the Mlir-commits
mailing list