[Mlir-commits] [mlir] Add support for MLIR to llvm vscale attribute (PR #67012)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 21 06:14:43 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
<details>
<summary>Changes</summary>
The vscale_range attribute is used for scalable vector functionality in the Arm Scalable Vector Extension to select the size of vector operations (and I think RISC-V has something similar).
This patch adds base support for the vscale_range attribute to LLVM::FuncOp, along with the marshalling for translation to LLVM-IR and import from LLVM-IR into the LLVM dialect.
This attribute is intended to be used by higher-level MLIR, specified either by command-line options to the compiler or by compiler directives (e.g. pragmas or function attributes in the source code) to indicate the desired range.
---
Full diff: https://github.com/llvm/llvm-project/pull/67012.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+7)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+2-1)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+20-1)
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+13-1)
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+6)
- (added) mlir/test/Target/LLVMIR/Import/vscale.ll (+10)
- (added) mlir/test/Target/LLVMIR/vscale.mlir (+9)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2443b23e42e43ce..322451334f2c5d0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -850,4 +850,11 @@ def LLVM_TBAATagArrayAttr
let constBuilderCall = ?;
}
+//--- VScaleRange Attribute
+
+def VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
+ let parameters = (ins "IntegerAttr":$minRange,
+ "IntegerAttr":$maxRange);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
#endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c0216d1971e58d2..345e9b072737e76 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1393,7 +1393,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnitAttr>:$arm_locally_streaming,
OptionalAttr<StrAttr>:$section,
OptionalAttr<UnnamedAddr>:$unnamed_addr,
- OptionalAttr<I64Attr>:$alignment
+ OptionalAttr<I64Attr>:$alignment,
+ OptionalAttr<VScaleRangeAttr>:$vscale_range
);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1cf91bde28183ac..22d6d1ccc9d429f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2342,6 +2342,19 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
+ if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
+ int64_t minRange, maxRange;
+ if (parser.parseLParen() || parser.parseInteger(minRange) ||
+ parser.parseComma() || parser.parseInteger(maxRange) ||
+ parser.parseRParen())
+ return failure();
+ auto intTy = mlir::IntegerType::get(parser.getContext(), 32);
+ result.addAttribute(getVscaleRangeAttrName(result.name),
+ mlir::LLVM::VScaleRangeAttr::get(
+ parser.getContext(),
+ mlir::IntegerAttr::get(intTy, minRange),
+ mlir::IntegerAttr::get(intTy, maxRange)));
+ }
// Parse the optional comdat selector.
if (succeeded(parser.parseOptionalKeyword("comdat"))) {
SymbolRefAttr comdat;
@@ -2398,6 +2411,11 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
+ // Print vscale range if present
+ if (auto vscale = getVscaleRange())
+ p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
+ << vscale->getMaxRange().getInt() << ')';
+
// Print the optional comdat selector.
if (auto comdat = getComdat())
p << " comdat(" << *comdat << ')';
@@ -2406,7 +2424,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
p, *this,
{getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
- getComdatAttrName(), getUnnamedAddrAttrName()});
+ getComdatAttrName(), getUnnamedAddrAttrName(),
+ getVscaleRangeAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index c6c30880d4f2c15..f462a6a5f94fa45 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
+#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/InstIterator.h"
@@ -1569,8 +1570,9 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
// Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
// explicit attribute.
+ // Also skip the vscale_range, it is also an explicit attribute.
if (attrName == "aarch64_pstate_sm_enabled" ||
- attrName == "aarch64_pstate_sm_body")
+ attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
continue;
if (attr.isStringAttribute()) {
@@ -1610,6 +1612,16 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
funcOp.setArmStreaming(true);
else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
funcOp.setArmLocallyStreaming(true);
+ llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
+ if (attr.isValid()) {
+ MLIRContext *context = funcOp.getContext();
+ auto intTy = mlir::IntegerType::get(context, 32);
+ std::optional<unsigned> maxAttr = attr.getVScaleRangeMax();
+ unsigned maxVal = maxAttr ? *maxAttr : 0;
+ funcOp.setVscaleRangeAttr(mlir::LLVM::VScaleRangeAttr::get(
+ context, mlir::IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
+ mlir::IntegerAttr::get(intTy, maxVal)));
+ }
}
DictionaryAttr
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 33c85d85a684cc8..9e2a40c0d9ed795 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -907,6 +907,12 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
else if (func.getArmLocallyStreaming())
llvmFunc->addFnAttr("aarch64_pstate_sm_body");
+ if (auto attr = func.getVscaleRange()) {
+ llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+ getLLVMContext(), attr->getMinRange().getInt(),
+ attr->getMaxRange().getInt()));
+ }
+
// First, create all blocks so we can jump to them.
llvm::LLVMContext &llvmContext = llvmFunc->getContext();
for (auto &bb : func) {
diff --git a/mlir/test/Target/LLVMIR/Import/vscale.ll b/mlir/test/Target/LLVMIR/Import/vscale.ll
new file mode 100644
index 000000000000000..99b85d77f6f920d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/vscale.ll
@@ -0,0 +1,10 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+define void @vscale_func() #0 {
+; CHECK: llvm.func @vscale_func()
+; CHECK-SAME: vscale_range(2, 8)
+ ret void
+}
+
+attributes #0 = { vscale_range(2,8) }
+
diff --git a/mlir/test/Target/LLVMIR/vscale.mlir b/mlir/test/Target/LLVMIR/vscale.mlir
new file mode 100644
index 000000000000000..df245964d2c4bbd
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/vscale.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @vscale_func()
+ vscale_range(2,8) {
+ // CHECK-LABEL: define void @vscale_func
+ // CHECK: attributes #{{.*}} = { vscale_range(2,8) }
+ llvm.return
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/67012
More information about the Mlir-commits
mailing list