[Mlir-commits] [mlir] aece4e2 - [mlir][ArmSVE][RFC] Add an ArmSVE dialect
Mehdi Amini
llvmlistbot at llvm.org
Mon Dec 14 13:35:09 PST 2020
Author: Javier Setoain
Date: 2020-12-14T21:35:01Z
New Revision: aece4e2793ccf0d63d5e677a0ace83752b30979a
URL: https://github.com/llvm/llvm-project/commit/aece4e2793ccf0d63d5e677a0ace83752b30979a
DIFF: https://github.com/llvm/llvm-project/commit/aece4e2793ccf0d63d5e677a0ace83752b30979a.diff
LOG: [mlir][ArmSVE][RFC] Add an ArmSVE dialect
This revision starts an Arm-specific ArmSVE dialect discussed in the discourse RFC thread:
https://llvm.discourse.group/t/rfc-vector-dialects-neon-and-sve/2284
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D92172
Added:
mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h
mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt
mlir/lib/Dialect/ArmSVE/CMakeLists.txt
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp
mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/ArmSVE/roundtrip.mlir
mlir/test/Target/arm-sve.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllTranslations.h
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Target/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
new file mode 100644
index 000000000000..8cba4e9be5d5
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
@@ -0,0 +1,23 @@
+//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
+#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Add ArmSVE-to-LLVM patterns and the scalable vector type conversion.
+void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 56169d90c849..b364700bd849 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -396,8 +396,8 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
- (AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral
- vector dialect lowering.
+ (AVX512, ArmNeon, ArmSVE, etc.) in combination with the
+ architectural-neutral vector dialect lowering.
}];
let constructor = "mlir::createConvertVectorToLLVMPass()";
@@ -418,7 +418,11 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
Option<"enableArmNeon", "enable-arm-neon",
"bool", /*default=*/"false",
"Enables the use of ArmNeon dialect while lowering the vector "
- "dialect.">
+ "dialect.">,
+ Option<"enableArmSVE", "enable-arm-sve",
+ "bool", /*default=*/"false",
+ "Enables the use of ArmSVE dialect while lowering the vector "
+ "dialect.">
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 7ff061cb9d09..8d24803eeb1c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,7 +23,7 @@ class OperationPass;
struct LowerVectorToLLVMOptions {
LowerVectorToLLVMOptions()
: reassociateFPReductions(false), enableIndexOptimizations(true),
- enableArmNeon(false), enableAVX512(false) {}
+ enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {}
LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
reassociateFPReductions = b;
@@ -33,18 +33,23 @@ struct LowerVectorToLLVMOptions {
enableIndexOptimizations = b;
return *this;
}
- LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
- enableAVX512 = b;
- return *this;
- }
LowerVectorToLLVMOptions &setEnableArmNeon(bool b) {
enableArmNeon = b;
return *this;
}
+ LowerVectorToLLVMOptions &setEnableArmSVE(bool b) {
+ enableArmSVE = b;
+ return *this;
+ }
+ LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
+ enableAVX512 = b;
+ return *this;
+ }
bool reassociateFPReductions;
bool enableIndexOptimizations;
bool enableArmNeon;
+ bool enableArmSVE;
bool enableAVX512;
};
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
new file mode 100644
index 000000000000..b18da2358243
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -0,0 +1,276 @@
+//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the ArmSVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSVE_OPS
+#define ARMSVE_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def ArmSVE_Dialect : Dialect {
+ let name = "arm_sve";
+ let cppNamespace = "::mlir::arm_sve";
+ let summary = "Basic dialect to target Arm SVE architectures";
+ let description = [{
+ This dialect contains the definitions necessary to target Arm SVE scalable
+ vector operations, including a scalable vector type and intrinsics for
+ some Arm SVE instructions.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def ArmSVE_ScalableVectorType : DialectType<ArmSVE_Dialect,
+ CPred<"$_self.isa<ScalableVectorType>()">,
+ "scalable vector type">,
+ BuildableType<"$_builder.getType<ScalableVectorType>()"> {
+ let typeDescription = [{
+ `arm_sve.vector` represents vectors that will be processed by a scalable
+ vector architecture.
+ }];
+}
+
+class ArmSVE_Type<string name> : TypeDef<ArmSVE_Dialect, name> { }
+
+def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
+ let mnemonic = "vector";
+
+ let summary = "Scalable vector type";
+
+ let description = [{
+ A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
+ vectors, whose size is constant and known at compile time, scalable
+ vectors' length is constant but determined by the specific hardware at
+ run time.
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t", "Vector shape">:$shape,
+ "Type":$elementType
+ );
+
+ let printer = [{
+ $_printer << "vector<";
+ for (int64_t dim : getShape())
+ $_printer << dim << 'x';
+ $_printer << getElementType() << '>';
+ }];
+
+ let parser = [{
+ VectorType vector;
+ if ($_parser.parseType(vector))
+ return Type();
+ return get(ctxt, vector.getShape(), vector.getElementType());
+ }];
+
+ let extraClassDeclaration = [{
+ bool hasStaticShape() const {
+ return llvm::none_of(getShape(), ShapedType::isDynamic);
+ }
+ int64_t getNumElements() const {
+ assert(hasStaticShape() &&
+ "cannot get element count of dynamic shaped type");
+ ArrayRef<int64_t> shape = getShape();
+ int64_t num = 1;
+ for (auto dim : shape)
+ num *= dim;
+ return num;
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSVE type traits
+//===----------------------------------------------------------------------===//
+
+def IsScalableVectorTypePred :
+ CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
+
+class ScalableVectorOf<list<Type> allowedTypes> :
+ ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
+ "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
+ "scalable vector">;
+
+class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
+ And<[IsScalableVectorTypePred,
+ Or<!foreach(allowedlength, allowedLengths, CPred<
+ [{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
+ # allowedlength>)>]>;
+
+class ScalableVectorOfLength<list<int> allowedLengths> : Type<
+ IsScalableVectorOfLengthPred<allowedLengths>,
+ " of length " # StrJoinInt<allowedLengths, "/">.result>;
+
+class ScalableVectorOfLengthAndType<list<int> allowedLengths,
+ list<Type> allowedTypes> : Type<
+ And<[ScalableVectorOf<allowedTypes>.predicate,
+ ScalableVectorOfLength<allowedLengths>.predicate]>,
+ ScalableVectorOf<allowedTypes>.description #
+ ScalableVectorOfLength<allowedLengths>.description>;
+
+//===----------------------------------------------------------------------===//
+// ArmSVE op definitions
+//===----------------------------------------------------------------------===//
+
+class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<ArmSVE_Dialect, mnemonic, traits> {}
+
+def SdotOp : ArmSVE_Op<"sdot",
+ [NoSideEffect,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>,
+ ]> {
+ let summary = "Vector-vector dot product and accumulate op";
+ let description = [{
+ SDOT: Signed integer addition of dot product.
+
+ This function maps to the SDOT instruction, and it takes signless integer
+ operands that the operation interprets as signed. It partitions the second
+ and third vector inputs into groups of four elements. They calculate the dot
+ product of each group (without loss of precision) and then add each result
+ to the overlapping element of the first vector input.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports either:
+ // (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ // (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
+ ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
+ ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def SmmlaOp : ArmSVE_Op<"smmla",
+                [NoSideEffect,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "dst"]>,
+                ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    SMMLA: Signed integer matrix multiply-accumulate.
+
+    This function maps to the SMMLA instruction, and it takes signless integer
+    operands that the operation interprets as signed. It partitions the inputs
+    into 128-bit quadwords, with the first input containing a row-by-row 2×2
+    matrix of 32-bit integers, the second input containing a row-by-row 2×8
+    matrix of 8-bit integers, and the third input containing a column-by-column
+    8×2 matrix of 8-bit integers. For each quadword, they multiply the second
+    input matrix by the third input matrix using natural arithmetic and then add
+    the result to the first input using modular arithmetic.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def UdotOp : ArmSVE_Op<"udot",
+ [NoSideEffect,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>,
+ ]> {
+ let summary = "Vector-vector dot product and accumulate op";
+ let description = [{
+ UDOT: Unsigned integer addition of dot product.
+
+ This function maps to the UDOT instruction, and it takes signless integer
+ operands that the operation interprets as unsigned. It partitions the second
+ and third vector inputs into groups of four elements. They calculate the dot
+ product of each group (without loss of precision) and then add each result
+ to the overlapping element of the first vector input.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports either:
+ // (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ // (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
+ ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
+ ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def UmmlaOp : ArmSVE_Op<"ummla",
+                [NoSideEffect,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "dst"]>,
+                ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    UMMLA: Unsigned integer matrix multiply-accumulate.
+
+    This function maps to the UMMLA instruction, and it takes signless integer
+    operands that the operation interprets as unsigned. It partitions the inputs
+    into 128-bit quadwords, with the first input containing a row-by-row 2×2
+    matrix of 32-bit integers, the second input containing a row-by-row 2×8
+    matrix of 8-bit integers, and the third input containing a column-by-column
+    8×2 matrix of 8-bit integers. For each quadword, they multiply the second
+    input matrix by the third input matrix using natural arithmetic and then add
+    the result to the first input using modular arithmetic.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def VectorScaleOp : ArmSVE_Op<"vector_scale",
+ [NoSideEffect]> {
+ let summary = "Load vector scale size";
+ let description = [{
+ The vector_scale op returns the scale of the scalable vectors, a positive
+ integer value that is constant at runtime but unknown at compile time.
+ The scale of the vector indicates the multiplicity of the vectors and
+ vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to
+ vector_scale consecutive vector<4xi32>; and an operation on an
+ !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale
+ times, once on each <4xi32> segment of the scalable vector. The vector_scale
+ op can be used to calculate the step in vector-length agnostic (VLA) loops.
+ }];
+ let results = (outs Index:$res);
+ let assemblyFormat =
+ "attr-dict `:` type($res)";
+}
+
+#endif // ARMSVE_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
new file mode 100644
index 000000000000..15ca3403bef0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
@@ -0,0 +1,29 @@
+//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSVE in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
+#define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"
+
+#endif // MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
new file mode 100644
index 000000000000..fb50fac68f33
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(ArmSVE arm_sve ArmSVE)
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 034b611d6288..51b423ee3b98 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,6 +1,7 @@
add_subdirectory(Affine)
add_subdirectory(Async)
add_subdirectory(ArmNeon)
+add_subdirectory(ArmSVE)
add_subdirectory(AVX512)
add_subdirectory(GPU)
add_subdirectory(Linalg)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 809e4abe7e84..6bd289097017 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -37,3 +37,9 @@ add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/)
set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td)
mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen)
+
+add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE)
+add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/)
+set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td)
+mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
new file mode 100644
index 000000000000..05237c86cb05
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
@@ -0,0 +1,70 @@
+//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMArmSVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_ARMSVE_OPS
+#define LLVMIR_ARMSVE_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMArmSVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMArmSVE_Dialect : Dialect {
+ let name = "llvm_arm_sve";
+ let cppNamespace = "::mlir::LLVM";
+}
+
+//----------------------------------------------------------------------------//
+// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system
+//----------------------------------------------------------------------------//
+
+class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
+ list<OpTrait> traits =[]> :
+ LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
+ /*string opName=*/mnemonic,
+ /*string enumName=*/mnemonic,
+ /*list<int> overloadedResults=*/[0],
+ /*list<int> overloadedOperands=*/[], // defined by result overload
+ /*list<OpTrait> traits=*/traits,
+ /*int numResults=*/1>;
+
+class LLVMArmSVE_IntrBinaryOverloadedOp<string mnemonic,
+ list<OpTrait> traits = []> :
+ LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
+ /*string opName=*/mnemonic,
+ /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
+ /*list<int> overloadedResults=*/[0],
+ /*list<int> overloadedOperands=*/[], // defined by result overload
+ /*list<OpTrait> traits=*/traits,
+ /*int numResults=*/1>;
+
+def LLVM_aarch64_arm_sve_ummla :
+ LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_arm_sve_smmla :
+ LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_arm_sve_sdot :
+ LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_arm_sve_udot :
+ LLVMArmSVE_IntrBinaryOverloadedOp<"udot">,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_vector_scale :
+ LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
+
+#endif // LLVMIR_ARMSVE_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h
new file mode 100644
index 000000000000..f9758c1fa80d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h
@@ -0,0 +1,24 @@
+//===- LLVMArmSVEDialect.h - MLIR Dialect for LLVMArmSVE --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for LLVMArmSVE in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc"
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 6d34449e65d4..a541fa4742ba 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -17,10 +17,12 @@
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -54,6 +56,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
LLVM::LLVMAVX512Dialect,
LLVM::LLVMDialect,
LLVM::LLVMArmNeonDialect,
+ LLVM::LLVMArmSVEDialect,
linalg::LinalgDialect,
scf::SCFDialect,
omp::OpenMPDialect,
@@ -62,6 +65,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
quant::QuantizationDialect,
spirv::SPIRVDialect,
StandardOpsDialect,
+ arm_sve::ArmSVEDialect,
vector::VectorDialect,
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index cafc931c2d9f..16dd113d14cd 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -24,6 +24,7 @@ void registerToNVVMIRTranslation();
void registerToROCDLIRTranslation();
void registerArmNeonToLLVMIRTranslation();
void registerAVX512ToLLVMIRTranslation();
+void registerArmSVEToLLVMIRTranslation();
// This function should be called before creating any MLIRContext if one
// expects all the possible translations to be made available to the context
@@ -38,6 +39,7 @@ inline void registerAllTranslations() {
registerToROCDLIRTranslation();
registerArmNeonToLLVMIRTranslation();
registerAVX512ToLLVMIRTranslation();
+ registerArmSVEToLLVMIRTranslation();
return true;
}();
(void)initOnce;
diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
new file mode 100644
index 000000000000..5742cd790e77
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
@@ -0,0 +1,75 @@
+//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+using namespace mlir::vector;
+
+using SdotOpLowering =
+ OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
+
+using SmmlaOpLowering =
+ OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
+
+using UdotOpLowering =
+ OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
+
+using UmmlaOpLowering =
+ OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
+
+using VectorScaleOpLowering =
+ OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
+
+/// Returns `type` as an LLVM dialect type. A null input yields null; a
+/// non-LLVM input emits an error at an unknown location and yields null.
+static LLVM::LLVMType unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
+    return llvmType;
+  emitError(UnknownLoc::get(type.getContext()),
+            "conversion resulted in a non-LLVM type");
+  return nullptr;
+}
+
+/// Converts `svType` to an LLVM scalable vector sized by its last dimension.
+/// Yields no type for an empty shape or an element type that is not LLVM.
+static Optional<Type>
+convertScalableVectorTypeToLLVM(ScalableVectorType svType,
+                                LLVMTypeConverter &converter) {
+  auto elementType = unwrap(converter.convertType(svType.getElementType()));
+  if (!elementType || svType.getShape().empty())
+    return {};
+  return LLVM::LLVMScalableVectorType::get(elementType,
+                                           svType.getShape().back());
+}
+
+/// Registers the scalable vector type conversion on `converter` and appends
+/// the ArmSVE-to-LLVM op lowering patterns to `patterns`.
+void mlir::populateArmSVEToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  converter.addConversion([&converter](ScalableVectorType svType) {
+    return convertScalableVectorTypeToLLVM(svType, converter);
+  });
+  // Insert the one-to-one op lowerings, one call each, in a fixed order.
+  patterns.insert<SdotOpLowering>(converter);
+  patterns.insert<SmmlaOpLowering>(converter);
+  patterns.insert<UdotOpLowering>(converter);
+  patterns.insert<UmmlaOpLowering>(converter);
+  patterns.insert<VectorScaleOpLowering>(converter);
+}
diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..6179d0a10202
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmSVEToLLVM
+ ArmSVEToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArmSVE
+ MLIRLLVMArmSVE
+ MLIRLLVMIR
+ MLIRStandardToLLVM
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index a0195486cfd6..421523267df2 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -20,6 +20,7 @@ add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
+add_subdirectory(ArmSVEToLLVM)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index dd69924166a5..ecd932f99c78 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -26,6 +26,7 @@ class GPUModuleOp;
namespace LLVM {
class LLVMArmNeonDialect;
+class LLVMArmSVEDialect;
class LLVMAVX512Dialect;
class LLVMDialect;
} // end namespace LLVM
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 6d7f7aa04d52..f1d662b630b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -19,6 +19,9 @@ add_mlir_conversion_library(MLIRVectorToLLVM
MLIRAVX512ToLLVM
MLIRLLVMArmNeon
MLIRLLVMAVX512
+ MLIRArmSVE
+ MLIRArmSVEToLLVM
+ MLIRLLVMArmSVE
MLIRLLVMIR
MLIRStandardToLLVM
MLIRTargetLLVMIRModuleTranslation
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 99f0bae05406..af6ce6a0a68c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -12,12 +12,15 @@
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
+#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -32,6 +35,7 @@ struct LowerVectorToLLVMPass
this->reassociateFPReductions = options.reassociateFPReductions;
this->enableIndexOptimizations = options.enableIndexOptimizations;
this->enableArmNeon = options.enableArmNeon;
+ this->enableArmSVE = options.enableArmSVE;
this->enableAVX512 = options.enableAVX512;
}
// Override explicitly to allow conditional dialect dependence.
@@ -39,6 +43,8 @@ struct LowerVectorToLLVMPass
registry.insert<LLVM::LLVMDialect>();
if (enableArmNeon)
registry.insert<LLVM::LLVMArmNeonDialect>();
+ if (enableArmSVE)
+ registry.insert<LLVM::LLVMArmSVEDialect>();
if (enableAVX512)
registry.insert<LLVM::LLVMAVX512Dialect>();
}
@@ -73,6 +79,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addIllegalDialect<arm_neon::ArmNeonDialect>();
populateArmNeonToLLVMConversionPatterns(converter, patterns);
}
+ if (enableArmSVE) {
+ target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
+ target.addIllegalDialect<arm_sve::ArmSVEDialect>();
+ populateArmSVEToLLVMConversionPatterns(converter, patterns);
+ }
if (enableAVX512) {
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
target.addIllegalDialect<avx512::AVX512Dialect>();
diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
new file mode 100644
index 000000000000..614ab4144574
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmSVE
+ IR/ArmSVEDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE
+
+ DEPENDS
+ MLIRArmSVEIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+ )
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
new file mode 100644
index 000000000000..2c76a64fac5c
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -0,0 +1,57 @@
+//===- ArmSVEDialect.cpp - MLIR ArmSVE dialect implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmSVE dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+void arm_sve::ArmSVEDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+ >();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+ >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ScalableVectorType
+//===----------------------------------------------------------------------===//
+
+Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
+ llvm::SMLoc typeLoc = parser.getCurrentLocation();
+ auto genType = generatedTypeParser(getContext(), parser, "vector");
+ if (genType != Type())
+ return genType;
+ parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
+ return Type();
+}
+
+void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
+ if (failed(generatedTypePrinter(type, os)))
+ llvm_unreachable("unexpected 'arm_sve' type kind");
+}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 9fd38aa92df6..ae9afdc70552 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
add_subdirectory(Affine)
add_subdirectory(ArmNeon)
+add_subdirectory(ArmSVE)
add_subdirectory(Async)
add_subdirectory(AVX512)
add_subdirectory(GPU)
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 87ad7e965d2e..cd73e7dcfc69 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -70,6 +70,27 @@ add_mlir_dialect_library(MLIRLLVMArmNeon
MLIRSideEffectInterfaces
)
+add_mlir_dialect_library(MLIRLLVMArmSVE
+ IR/LLVMArmSVEDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+ DEPENDS
+ MLIRLLVMArmSVEIncGen
+ MLIRLLVMArmSVEConversionsIncGen
+ intrinsics_gen
+
+ LINK_COMPONENTS
+ AsmParser
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMIR
+ MLIRSideEffectInterfaces
+ )
+
add_mlir_dialect_library(MLIRNVVMIR
IR/NVVMDialect.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp
new file mode 100644
index 000000000000..60ef8e6799a3
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp
@@ -0,0 +1,31 @@
+//===- LLVMArmSVEDialect.cpp - MLIR LLVMSVE ops implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMArmSVE dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+void LLVM::LLVMArmSVEDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
+ >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 275c1a0f78fc..1b1a02db5511 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -74,6 +74,25 @@ add_mlir_translation_library(MLIRTargetArmNeon
MLIRTargetLLVMIRModuleTranslation
)
+add_mlir_translation_library(MLIRTargetArmSVE
+ LLVMIR/LLVMArmSVEIntr.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+
+ DEPENDS
+ MLIRLLVMArmSVEConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMArmSVE
+ MLIRLLVMIR
+ MLIRTargetLLVMIRModuleTranslation
+ )
+
add_mlir_translation_library(MLIRTargetNVVMIR
LLVMIR/ConvertToNVVMIR.cpp
diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
new file mode 100644
index 000000000000..717583a2d8d7
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
@@ -0,0 +1,63 @@
+//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR LLVM and ArmSVE dialects
+// and LLVM IR with Arm SVE intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+
+namespace {
+class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation {
+ friend LLVM::ModuleTranslation;
+
+public:
+ using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+ LogicalResult convertOperation(Operation &opInst,
+ llvm::IRBuilder<> &builder) override {
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc"
+
+ return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+ }
+};
+} // end namespace
+
+static std::unique_ptr<llvm::Module>
+translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
+ StringRef name) {
+ return LLVM::ModuleTranslation::translateModule<LLVMArmSVEModuleTranslation>(
+ m, llvmContext, name);
+}
+
+namespace mlir {
+void registerArmSVEToLLVMIRTranslation() {
+ TranslateFromMLIRRegistration reg(
+ "arm-sve-mlir-to-llvmir",
+ [](ModuleOp module, raw_ostream &output) {
+ llvm::LLVMContext llvmContext;
+ auto llvmModule = translateLLVMArmSVEModuleToLLVMIR(
+ module, llvmContext, "LLVMDialectModule");
+ if (!llvmModule)
+ return failure();
+
+ llvmModule->print(output, nullptr);
+ return success();
+ },
+ [](DialectRegistry &registry) {
+ registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
+ });
+}
+} // namespace mlir
diff --git a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
new file mode 100644
index 000000000000..5f218c9f421a
--- /dev/null
+++ b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
+
+func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi32> {
+ // CHECK: llvm_arm_sve.sdot
+ %0 = arm_sve.sdot %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi32> {
+ // CHECK: llvm_arm_sve.smmla
+ %0 = arm_sve.smmla %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi32> {
+ // CHECK: llvm_arm_sve.udot
+ %0 = arm_sve.udot %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi32> {
+ // CHECK: llvm_arm_sve.ummla
+ %0 = arm_sve.ummla %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @get_vector_scale() -> index {
+ // CHECK: llvm_arm_sve.vscale
+ %0 = arm_sve.vector_scale : index
+ return %0 : index
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
new file mode 100644
index 000000000000..8834ef87207b
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+ // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+ %0 = arm_sve.sdot %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+ // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+ %0 = arm_sve.smmla %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+ // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+ %0 = arm_sve.udot %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
+ %b: !arm_sve.vector<16xi8>,
+ %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+ // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+ %0 = arm_sve.ummla %c, %a, %b :
+ !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi32>
+}
+
+func @get_vector_scale() -> index {
+ // CHECK: arm_sve.vector_scale : index
+ %0 = arm_sve.vector_scale : index
+ return %0 : index
+}
diff --git a/mlir/test/Target/arm-sve.mlir b/mlir/test/Target/arm-sve.mlir
new file mode 100644
index 000000000000..7340fea34be6
--- /dev/null
+++ b/mlir/test/Target/arm-sve.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
+llvm.func @arm_sve_sdot(%arg0: !llvm.vec<? x 16 x i8>,
+ %arg1: !llvm.vec<? x 16 x i8>,
+ %arg2: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
+ %0 = "llvm_arm_sve.sdot"(%arg2, %arg0, %arg1) :
+ (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+ -> !llvm.vec<? x 4 x i32>
+ llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
+llvm.func @arm_sve_smmla(%arg0: !llvm.vec<? x 16 x i8>,
+ %arg1: !llvm.vec<? x 16 x i8>,
+ %arg2: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
+ %0 = "llvm_arm_sve.smmla"(%arg2, %arg0, %arg1) :
+ (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+ -> !llvm.vec<? x 4 x i32>
+ llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
+llvm.func @arm_sve_udot(%arg0: !llvm.vec<? x 16 x i8>,
+ %arg1: !llvm.vec<? x 16 x i8>,
+ %arg2: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
+ %0 = "llvm_arm_sve.udot"(%arg2, %arg0, %arg1) :
+ (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+ -> !llvm.vec<? x 4 x i32>
+ llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
+llvm.func @arm_sve_ummla(%arg0: !llvm.vec<? x 16 x i8>,
+ %arg1: !llvm.vec<? x 16 x i8>,
+ %arg2: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32> {
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
+ %0 = "llvm_arm_sve.ummla"(%arg2, %arg0, %arg1) :
+ (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+ -> !llvm.vec<? x 4 x i32>
+ llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define i64 @get_vector_scale()
+llvm.func @get_vector_scale() -> !llvm.i64 {
+ // CHECK: call i64 @llvm.vscale.i64()
+ %0 = "llvm_arm_sve.vscale"() : () -> !llvm.i64
+ llvm.return %0 : !llvm.i64
+}
More information about the Mlir-commits
mailing list