[Mlir-commits] [mlir] 5bfd5c6 - Add support for MLIR to llvm vscale attribute (#67012)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 25 06:32:22 PDT 2023


Author: Mats Petersson
Date: 2023-09-25T14:32:18+01:00
New Revision: 5bfd5c60bf02590a575661c609684e5cf3dd1e6f

URL: https://github.com/llvm/llvm-project/commit/5bfd5c60bf02590a575661c609684e5cf3dd1e6f
DIFF: https://github.com/llvm/llvm-project/commit/5bfd5c60bf02590a575661c609684e5cf3dd1e6f.diff

LOG: Add support for MLIR to llvm vscale attribute (#67012)

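The vscale_range attribute is used by scalable-vector functionality, such as
the Arm Scalable Vector Extension (SVE), to specify the range of values that
the runtime vector-length multiplier (vscale) may take, which bounds the size
of vector operations (I think RISC-V has something similar).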

This patch adds the base support for the vscale_range attribute to the
LLVM::FuncOp, and the marshalling for translation to LLVM-IR and import
from LLVM-IR to LLVM dialect.

This attribute is intended to be used at higher level MLIR, specified
either by command-line options to the compiler or using compiler
directives (e.g. pragmas or function attributes in the source code) to
indicate the desired range.

Added: 
    mlir/test/Target/LLVMIR/Import/vscale.ll
    mlir/test/Target/LLVMIR/vscale.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/func.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2443b23e42e43ce..4a9bb5355d1c7ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -850,4 +850,14 @@ def LLVM_TBAATagArrayAttr
   let constBuilderCall = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// VScaleRangeAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
+  let parameters =  (ins
+    "IntegerAttr":$minRange,
+    "IntegerAttr":$maxRange);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
 #endif // LLVMIR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e4486eb36e51a1f..10bc8afeefa1777 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1393,7 +1393,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
-    OptionalAttr<I64Attr>:$alignment
+    OptionalAttr<I64Attr>:$alignment,
+    OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range
   );
 
   let regions = (region AnyRegion:$body);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2e0b2582eee9629..d5f2f47bc147e6c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2342,6 +2342,19 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
+  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
+    int64_t minRange, maxRange;
+    if (parser.parseLParen() || parser.parseInteger(minRange) ||
+        parser.parseComma() || parser.parseInteger(maxRange) ||
+        parser.parseRParen())
+      return failure();
+    auto intTy = IntegerType::get(parser.getContext(), 32);
+    result.addAttribute(
+        getVscaleRangeAttrName(result.name),
+        LLVM::VScaleRangeAttr::get(parser.getContext(),
+                                   IntegerAttr::get(intTy, minRange),
+                                   IntegerAttr::get(intTy, maxRange)));
+  }
   // Parse the optional comdat selector.
   if (succeeded(parser.parseOptionalKeyword("comdat"))) {
     SymbolRefAttr comdat;
@@ -2398,6 +2411,11 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
 
+  // Print vscale range if present
+  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
+    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
+      << vscale->getMaxRange().getInt() << ')';
+
   // Print the optional comdat selector.
   if (auto comdat = getComdat())
     p << " comdat(" << *comdat << ')';
@@ -2406,7 +2424,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
       p, *this,
       {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
        getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
-       getComdatAttrName(), getUnnamedAddrAttrName()});
+       getComdatAttrName(), getUnnamedAddrAttrName(),
+       getVscaleRangeAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 3d93332c4c56796..672d81bd3fd9129 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1569,8 +1569,9 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
 
     // Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
     // explicit attribute.
+    // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body")
+        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1610,6 +1611,14 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+  llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
+  if (attr.isValid()) {
+    MLIRContext *context = funcOp.getContext();
+    auto intTy = IntegerType::get(context, 32);
+    funcOp.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get(
+        context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
+        IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0))));
+  }
 }
 
 DictionaryAttr

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 33c85d85a684cc8..ee73b04e020fd26 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -907,6 +907,11 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (auto attr = func.getVscaleRange())
+    llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+        getLLVMContext(), attr->getMinRange().getInt(),
+        attr->getMaxRange().getInt()));
+
   // First, create all blocks so we can jump to them.
   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
   for (auto &bb : func) {

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index c0da9fc40e16faa..b45e6c4ef897b10 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -225,6 +225,12 @@ module {
   llvm.func @any() comdat(@__llvm_comdat::@any) attributes { dso_local } {
     llvm.return
   }
+
+  llvm.func @vscale_roundtrip() vscale_range(1, 2) {
+    // CHECK: @vscale_roundtrip
+    // CHECK-SAME: vscale_range(1, 2)
+    llvm.return
+  }
 }
 
 // -----

diff  --git a/mlir/test/Target/LLVMIR/Import/vscale.ll b/mlir/test/Target/LLVMIR/Import/vscale.ll
new file mode 100644
index 000000000000000..48a52e949bda36a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/vscale.ll
@@ -0,0 +1,7 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+define void @vscale_func() vscale_range(2,8) {
+  ; CHECK: llvm.func @vscale_func()
+  ; CHECK-SAME: vscale_range(2, 8)
+  ret void
+}

diff  --git a/mlir/test/Target/LLVMIR/vscale.mlir b/mlir/test/Target/LLVMIR/vscale.mlir
new file mode 100644
index 000000000000000..644cde73ca89677
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/vscale.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @vscale_func() vscale_range(2,8) {
+  // CHECK-LABEL: define void @vscale_func
+  // CHECK: attributes #{{.*}} = { vscale_range(2,8) }
+  llvm.return
+}


        


More information about the Mlir-commits mailing list