[Mlir-commits] [mlir] aece4e2 - [mlir][ArmSVE][RFC] Add an ArmSVE dialect

Mehdi Amini llvmlistbot at llvm.org
Mon Dec 14 13:35:09 PST 2020


Author: Javier Setoain
Date: 2020-12-14T21:35:01Z
New Revision: aece4e2793ccf0d63d5e677a0ace83752b30979a

URL: https://github.com/llvm/llvm-project/commit/aece4e2793ccf0d63d5e677a0ace83752b30979a
DIFF: https://github.com/llvm/llvm-project/commit/aece4e2793ccf0d63d5e677a0ace83752b30979a.diff

LOG: [mlir][ArmSVE][RFC] Add an ArmSVE dialect

This revision starts an Arm-specific ArmSVE dialect discussed in the discourse RFC thread:

https://llvm.discourse.group/t/rfc-vector-dialects-neon-and-sve/2284

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D92172

Added: 
    mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
    mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h
    mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
    mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt
    mlir/lib/Dialect/ArmSVE/CMakeLists.txt
    mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp
    mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
    mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/ArmSVE/roundtrip.mlir
    mlir/test/Target/arm-sve.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/InitAllTranslations.h
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h
    mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/CMakeLists.txt
    mlir/lib/Dialect/LLVMIR/CMakeLists.txt
    mlir/lib/Target/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
new file mode 100644
index 000000000000..8cba4e9be5d5
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h
@@ -0,0 +1,23 @@
+//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
+#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Collect ArmSVE->LLVM patterns and register ArmSVE type conversions.
+void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                            OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 56169d90c849..b364700bd849 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -396,8 +396,8 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral
-    vector dialect lowering.
+    (AVX512, ArmNeon, ArmSVE, etc.) in combination with the
+    architectural-neutral vector dialect lowering.
 
   }];
   let constructor = "mlir::createConvertVectorToLLVMPass()";
@@ -418,7 +418,11 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
     Option<"enableArmNeon", "enable-arm-neon",
            "bool", /*default=*/"false",
            "Enables the use of ArmNeon dialect while lowering the vector "
-	   "dialect.">
+	   "dialect.">,
+    Option<"enableArmSVE", "enable-arm-sve",
+           "bool", /*default=*/"false",
+           "Enables the use of ArmSVE dialect while lowering the vector "
+       "dialect.">
   ];
 }
 

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 7ff061cb9d09..8d24803eeb1c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,7 +23,7 @@ class OperationPass;
 struct LowerVectorToLLVMOptions {
   LowerVectorToLLVMOptions()
       : reassociateFPReductions(false), enableIndexOptimizations(true),
-        enableArmNeon(false), enableAVX512(false) {}
+        enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {}
 
   LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
     reassociateFPReductions = b;
@@ -33,18 +33,23 @@ struct LowerVectorToLLVMOptions {
     enableIndexOptimizations = b;
     return *this;
   }
-  LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
-    enableAVX512 = b;
-    return *this;
-  }
   LowerVectorToLLVMOptions &setEnableArmNeon(bool b) {
     enableArmNeon = b;
     return *this;
   }
+  LowerVectorToLLVMOptions &setEnableArmSVE(bool b) {
+    enableArmSVE = b;
+    return *this;
+  }
+  LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
+    enableAVX512 = b;
+    return *this;
+  }
 
   bool reassociateFPReductions;
   bool enableIndexOptimizations;
   bool enableArmNeon;
+  bool enableArmSVE;
   bool enableAVX512;
 };
 

diff  --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
new file mode 100644
index 000000000000..b18da2358243
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -0,0 +1,276 @@
+//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the ArmSVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSVE_OPS
+#define ARMSVE_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def ArmSVE_Dialect : Dialect {
+  let name = "arm_sve";
+  let cppNamespace = "::mlir::arm_sve";
+  let summary = "Basic dialect to target Arm SVE architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target Arm SVE scalable
+    vector operations, including a scalable vector type and intrinsics for
+    some Arm SVE instructions.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def ArmSVE_ScalableVectorType : DialectType<ArmSVE_Dialect,
+    CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">,
+              "scalable vector type"> {
+  // Not buildable: ScalableVectorType::get needs a shape and element type.
+  let typeDescription = [{
+    `arm_sve.vector` represents vectors that will be processed by a scalable
+    vector architecture.
+  }];
+}
+
+class ArmSVE_Type<string name> : TypeDef<ArmSVE_Dialect, name> { }
+
+def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
+  let mnemonic = "vector";
+
+  let summary = "Scalable vector type";
+
+  let description = [{
+    A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
+    vectors, whose size is constant and known at compile time, scalable
+    vectors' length is constant but determined by the specific hardware at
+    run time.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t", "Vector shape">:$shape,
+    "Type":$elementType
+  );
+
+  let printer = [{
+    $_printer << "vector<"; // Same textual form as the builtin vector type.
+    for (int64_t dim : getShape())
+      $_printer << dim << 'x';
+    $_printer << getElementType() << '>';
+  }];
+
+  let parser = [{
+    VectorType vector; // Parsed with the builtin vector type syntax.
+    if ($_parser.parseType(vector))
+      return Type();
+    return get(ctxt, vector.getShape(), vector.getElementType());
+  }];
+
+  let extraClassDeclaration = [{
+    bool hasStaticShape() const {
+      return llvm::none_of(getShape(), ShapedType::isDynamic);
+    }
+    int64_t getNumElements() const { // Product of all dimensions.
+      assert(hasStaticShape() &&
+             "cannot get element count of dynamic shaped type");
+      ArrayRef<int64_t> shape = getShape();
+      int64_t num = 1;
+      for (auto dim : shape)
+        num *= dim;
+      return num;
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSVE type constraints
+//===----------------------------------------------------------------------===//
+
+def IsScalableVectorTypePred :
+    CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
+
+class ScalableVectorOf<list<Type> allowedTypes> :
+    ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
+          "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
+          "scalable vector">;
+
+class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
+  And<[IsScalableVectorTypePred,
+       Or<!foreach(allowedlength, allowedLengths, CPred<
+          [{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
+          # allowedlength>)>]>;
+
+class ScalableVectorOfLength<list<int> allowedLengths> : Type<
+  IsScalableVectorOfLengthPred<allowedLengths>,
+  " of length " # StrJoinInt<allowedLengths, "/">.result>;
+
+class ScalableVectorOfLengthAndType<list<int> allowedLengths,
+                                    list<Type> allowedTypes> : Type<
+  And<[ScalableVectorOf<allowedTypes>.predicate,
+       ScalableVectorOfLength<allowedLengths>.predicate]>,
+  ScalableVectorOf<allowedTypes>.description #
+  ScalableVectorOfLength<allowedLengths>.description>;
+
+//===----------------------------------------------------------------------===//
+// ArmSVE op definitions
+//===----------------------------------------------------------------------===//
+
+class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<ArmSVE_Dialect, mnemonic, traits> {}
+
+def SdotOp : ArmSVE_Op<"sdot",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Vector-vector dot product and accumulate op";
+  let description = [{
+    SDOT: Signed integer addition of dot product.
+
+    This function maps to the SDOT instruction, and it takes signless integer
+    operands that the operation interprets as signed. It partitions the second
+    and third vector inputs into groups of four elements. They calculate the dot
+    product of each group (without loss of precision) and then add each result
+    to the overlapping element of the first vector input.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Intended signatures (acc/src pairing is not verified by these types):
+  //   (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  //   (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
+          ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
+          ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def SmmlaOp : ArmSVE_Op<"smmla",
+                [NoSideEffect,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "dst"]>,
+              ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    SMMLA: Signed integer matrix multiply-accumulate.
+
+    This function maps to the SMMLA instruction, and it takes signless integer
+    operands that the operation interprets as signed. It partitions the inputs
+    into 128-bit quadwords, with the first input containing a row-by-row 2×2
+    matrix of 32-bit integers, the second input containing a row-by-row 2×8
+    matrix of 8-bit integers, and the third input containing a column-by-column
+    8×2 matrix of 8-bit integers. For each quadword, they multiply the second
+    input matrix by the third input matrix using natural arithmetic and then add
+    the result to the first input using modular arithmetic.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def UdotOp : ArmSVE_Op<"udot",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Vector-vector dot product and accumulate op";
+  let description = [{
+    UDOT: Unsigned integer addition of dot product.
+
+    This function maps to the UDOT instruction, and it takes signless integer
+    operands that the operation interprets as unsigned. It partitions the second
+    and third vector inputs into groups of four elements. They calculate the dot
+    product of each group (without loss of precision) and then add each result
+    to the overlapping element of the first vector input.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Intended signatures (acc/src pairing is not verified by these types):
+  //   (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  //   (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
+          ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
+          ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def UmmlaOp : ArmSVE_Op<"ummla",
+                [NoSideEffect,
+                AllTypesMatch<["src1", "src2"]>,
+                AllTypesMatch<["acc", "dst"]>,
+              ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    UMMLA: Unsigned integer matrix multiply-accumulate.
+
+    This function maps to the UMMLA instruction, and it takes signless integer
+    operands that the operation interprets as unsigned. It partitions the inputs
+    into 128-bit quadwords, with the first input containing a row-by-row 2×2
+    matrix of 32-bit integers, the second input containing a row-by-row 2×8
+    matrix of 8-bit integers, and the third input containing a column-by-column
+    8×2 matrix of 8-bit integers. For each quadword, they multiply the second
+    input matrix by the third input matrix using natural arithmetic and then add
+    the result to the first input using modular arithmetic.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
+def VectorScaleOp : ArmSVE_Op<"vector_scale",
+                 [NoSideEffect]> { // Lowers to llvm_arm_sve.vscale.
+  let summary = "Load vector scale size";
+  let description = [{
+    The vector_scale op returns the scale of the scalable vectors, a positive
+    integer value that is constant at runtime but unknown at compile time.
+    The scale of the vector indicates the multiplicity of the vectors and
+    vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to
+    vector_scale consecutive vector<4xi32>; and an operation on an
+    !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale
+    times, once on each <4xi32> segment of the scalable vector. The vector_scale
+    op can be used to calculate the step in vector-length agnostic (VLA) loops.
+  }];
+  let results = (outs Index:$res);
+  let assemblyFormat =
+    "attr-dict `:` type($res)";
+}
+
+#endif // ARMSVE_OPS

diff  --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
new file mode 100644
index 000000000000..15ca3403bef0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
@@ -0,0 +1,29 @@
+//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the ArmSVE dialect, its types and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
+#define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"
+
+#endif // MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H

diff  --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
new file mode 100644
index 000000000000..fb50fac68f33
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(ArmSVE arm_sve ArmSVE)  # presumably generates ArmSVE*.h.inc

diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 034b611d6288..51b423ee3b98 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Affine)
 add_subdirectory(Async)
 add_subdirectory(ArmNeon)
+add_subdirectory(ArmSVE)
 add_subdirectory(AVX512)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 809e4abe7e84..6bd289097017 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -37,3 +37,9 @@ add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/)
 set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td)
 mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen)
+
+add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE)
+add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSVE Dialects/)
+set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td)
+mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
new file mode 100644
index 000000000000..05237c86cb05
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
@@ -0,0 +1,70 @@
+//===-- LLVMArmSVE.td - LLVMArmSVE dialect op definitions --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMArmSVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_ARMSVE_OPS
+#define LLVMIR_ARMSVE_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMArmSVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMArmSVE_Dialect : Dialect {
+  let name = "llvm_arm_sve";
+  let cppNamespace = "::mlir::LLVM";
+}
+
+//===----------------------------------------------------------------------===//
+// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system
+//===----------------------------------------------------------------------===//
+
+class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
+                                             list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/mnemonic, // no aarch64_sve_ prefix
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+class LLVMArmSVE_IntrBinaryOverloadedOp<string mnemonic,
+                                        list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+def LLVM_aarch64_arm_sve_ummla :
+  LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_arm_sve_smmla :
+  LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_arm_sve_sdot :
+  LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_arm_sve_udot :
+  LLVMArmSVE_IntrBinaryOverloadedOp<"udot">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_vector_scale :
+  LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
+
+#endif // LLVMIR_ARMSVE_OPS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h
new file mode 100644
index 000000000000..f9758c1fa80d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h
@@ -0,0 +1,24 @@
+//===- LLVMArmSVEDialect.h - MLIR Dialect for LLVMArmSVE --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the LLVMArmSVE dialect and its intrinsic operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc"
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 6d34449e65d4..a541fa4742ba 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -17,10 +17,12 @@
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -54,6 +56,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,
                   LLVM::LLVMArmNeonDialect,
+                  LLVM::LLVMArmSVEDialect,
                   linalg::LinalgDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,
@@ -62,6 +65,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
                   StandardOpsDialect,
+                  arm_sve::ArmSVEDialect,
                   vector::VectorDialect,
                   NVVM::NVVMDialect,
                   ROCDL::ROCDLDialect,

diff  --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index cafc931c2d9f..16dd113d14cd 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -24,6 +24,7 @@ void registerToNVVMIRTranslation();
 void registerToROCDLIRTranslation();
 void registerArmNeonToLLVMIRTranslation();
 void registerAVX512ToLLVMIRTranslation();
+void registerArmSVEToLLVMIRTranslation();
 
 // This function should be called before creating any MLIRContext if one
 // expects all the possible translations to be made available to the context
@@ -38,6 +39,7 @@ inline void registerAllTranslations() {
     registerToROCDLIRTranslation();
     registerArmNeonToLLVMIRTranslation();
     registerAVX512ToLLVMIRTranslation();
+    registerArmSVEToLLVMIRTranslation();
     return true;
   }();
   (void)initOnce;

diff  --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
new file mode 100644
index 000000000000..5742cd790e77
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
@@ -0,0 +1,75 @@
+//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+using namespace mlir::vector;
+// Each ArmSVE op is lowered one-to-one onto its LLVMArmSVE intrinsic op.
+using SdotOpLowering =
+    OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
+
+using SmmlaOpLowering =
+    OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
+
+using UdotOpLowering =
+    OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
+
+using UmmlaOpLowering =
+    OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
+
+using VectorScaleOpLowering =
+    OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
+
+// Casts to LLVM::LLVMType; a null input or a non-LLVM type yields null.
+static LLVM::LLVMType unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
+    return llvmType;
+  // The converter produced a type outside the LLVM dialect: diagnose it.
+  emitError(UnknownLoc::get(type.getContext()),
+            "conversion resulted in a non-LLVM type");
+  return nullptr;
+}
+
+/// Converts a 1-D !arm_sve.vector to an LLVM scalable vector, else None.
+static Optional<Type>
+convertScalableVectorTypeToLLVM(ScalableVectorType svType,
+                                LLVMTypeConverter &converter) {
+  auto elementType = unwrap(converter.convertType(svType.getElementType()));
+  // LLVM scalable vectors are 1-D: never silently drop leading dimensions.
+  if (!elementType || svType.getShape().size() != 1)
+    return {};
+  return LLVM::LLVMScalableVectorType::get(elementType,
+                                           svType.getShape().front());
+}
+
+/// Registers the !arm_sve.vector type conversion on `converter` and adds
+/// the one-to-one ArmSVE op lowering patterns to `patterns`.
+void mlir::populateArmSVEToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // Scalable vector types lower to LLVM scalable vectors.
+  converter.addConversion([&converter](ScalableVectorType type) {
+    return convertScalableVectorTypeToLLVM(type, converter);
+  });
+  // clang-format off
+  patterns.insert<SdotOpLowering, SmmlaOpLowering,
+                  UdotOpLowering, UmmlaOpLowering,
+                  VectorScaleOpLowering>(converter);
+  // clang-format on
+}

diff  --git a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..6179d0a10202
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmSVEToLLVM
+  ArmSVEToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmSVE      # arm_sve dialect (ArmSVEDialect.h)
+  MLIRLLVMArmSVE  # llvm_arm_sve dialect (LLVMArmSVEDialect.h)
+  MLIRLLVMIR
+  MLIRStandardToLLVM
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index a0195486cfd6..421523267df2 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -20,6 +20,7 @@ add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
+add_subdirectory(ArmSVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index dd69924166a5..ecd932f99c78 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -26,6 +26,7 @@ class GPUModuleOp;
 
 namespace LLVM {
 class LLVMArmNeonDialect;
+class LLVMArmSVEDialect;
 class LLVMAVX512Dialect;
 class LLVMDialect;
 } // end namespace LLVM

diff  --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 6d7f7aa04d52..f1d662b630b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -19,6 +19,9 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   MLIRAVX512ToLLVM
   MLIRLLVMArmNeon
   MLIRLLVMAVX512
+  MLIRArmSVE
+  MLIRArmSVEToLLVM
+  MLIRLLVMArmSVE
   MLIRLLVMIR
   MLIRStandardToLLVM
   MLIRTargetLLVMIRModuleTranslation

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 99f0bae05406..af6ce6a0a68c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -12,12 +12,15 @@
 
 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
 #include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
+#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -32,6 +35,7 @@ struct LowerVectorToLLVMPass
     this->reassociateFPReductions = options.reassociateFPReductions;
     this->enableIndexOptimizations = options.enableIndexOptimizations;
     this->enableArmNeon = options.enableArmNeon;
+    this->enableArmSVE = options.enableArmSVE;
     this->enableAVX512 = options.enableAVX512;
   }
   // Override explicitly to allow conditional dialect dependence.
@@ -39,6 +43,8 @@ struct LowerVectorToLLVMPass
     registry.insert<LLVM::LLVMDialect>();
     if (enableArmNeon)
       registry.insert<LLVM::LLVMArmNeonDialect>();
+    if (enableArmSVE)
+      registry.insert<LLVM::LLVMArmSVEDialect>();
     if (enableAVX512)
       registry.insert<LLVM::LLVMAVX512Dialect>();
   }
@@ -73,6 +79,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
     target.addIllegalDialect<arm_neon::ArmNeonDialect>();
     populateArmNeonToLLVMConversionPatterns(converter, patterns);
   }
+  if (enableArmSVE) {
+    target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
+    target.addIllegalDialect<arm_sve::ArmSVEDialect>();
+    populateArmSVEToLLVMConversionPatterns(converter, patterns);
+  }
   if (enableAVX512) {
     target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
     target.addIllegalDialect<avx512::AVX512Dialect>();

diff  --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
new file mode 100644
index 000000000000..614ab4144574
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
@@ -0,0 +1,13 @@
+# The ArmSVE dialect: its operations and scalable vector type are implemented
+# in IR/ArmSVEDialect.cpp on top of TableGen-generated code (MLIRArmSVEIncGen).
+add_mlir_dialect_library(MLIRArmSVE
+  IR/ArmSVEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE
+
+  DEPENDS
+  MLIRArmSVEIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )

diff  --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
new file mode 100644
index 000000000000..2c76a64fac5c
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -0,0 +1,57 @@
+//===- ArmSVEDialect.cpp - MLIR ArmSVE dialect implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmSVE dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+/// Registers with this dialect every operation and type declared in the
+/// ArmSVE TableGen definitions. The GET_*_LIST macros select the
+/// comma-separated class lists from the generated .cpp.inc files.
+void arm_sve::ArmSVEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+      >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+      >();
+}
+
+// Out-of-line definitions of the generated op and type classes.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ScalableVectorType
+//===----------------------------------------------------------------------===//
+
+/// Parses an `!arm_sve.<...>` type body by handing it to the TableGen-generated
+/// type parser with "vector" as the mnemonic. Emits an error at the start of
+/// the body and returns a null Type if that parser yields nothing.
+Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
+  // Capture the location first so the diagnostic points at the type body.
+  const llvm::SMLoc bodyLoc = parser.getCurrentLocation();
+  if (Type parsed = generatedTypeParser(getContext(), parser, "vector"))
+    return parsed;
+  parser.emitError(bodyLoc, "unknown type in ArmSVE dialect");
+  return {};
+}
+
+/// Prints an ArmSVE dialect type through the TableGen-generated printer.
+/// A type the generated printer does not recognize is treated as unreachable.
+void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
+  if (succeeded(generatedTypePrinter(type, os)))
+    return;
+  llvm_unreachable("unexpected 'arm_sve' type kind");
+}

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 9fd38aa92df6..ae9afdc70552 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(Affine)
 add_subdirectory(ArmNeon)
+add_subdirectory(ArmSVE)
 add_subdirectory(Async)
 add_subdirectory(AVX512)
 add_subdirectory(GPU)

diff  --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 87ad7e965d2e..cd73e7dcfc69 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -70,6 +70,27 @@ add_mlir_dialect_library(MLIRLLVMArmNeon
   MLIRSideEffectInterfaces
   )
 
+# LLVMArmSVE dialect: llvm_arm_sve.* ops that model Arm SVE LLVM intrinsics,
+# implemented in IR/LLVMArmSVEDialect.cpp.
+add_mlir_dialect_library(MLIRLLVMArmSVE
+  IR/LLVMArmSVEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+  # NOTE(review): intrinsics_gen presumably ensures the generated LLVM
+  # intrinsic headers (e.g. llvm/IR/IntrinsicsAArch64.h, which
+  # LLVMArmSVEDialect.cpp includes) exist before compilation.
+  DEPENDS
+  MLIRLLVMArmSVEIncGen
+  MLIRLLVMArmSVEConversionsIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  AsmParser
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )
+
 add_mlir_dialect_library(MLIRNVVMIR
   IR/NVVMDialect.cpp
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp
new file mode 100644
index 000000000000..60ef8e6799a3
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp
@@ -0,0 +1,31 @@
+//===- LLVMArmSVEDialect.cpp - MLIR LLVMArmSVE ops implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMArmSVE dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+/// Registers every TableGen-generated LLVMArmSVE (llvm_arm_sve.*) operation
+/// with this dialect; GET_OP_LIST selects the op class list from the .cpp.inc.
+void LLVM::LLVMArmSVEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
+      >();
+}
+
+// Out-of-line definitions of the generated op classes.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"

diff  --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 275c1a0f78fc..1b1a02db5511 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -74,6 +74,25 @@ add_mlir_translation_library(MLIRTargetArmNeon
   MLIRTargetLLVMIRModuleTranslation
   )
 
+# Translation library registering the `arm-sve-mlir-to-llvmir` translation,
+# implemented in LLVMIR/LLVMArmSVEIntr.cpp.
+add_mlir_translation_library(MLIRTargetArmSVE
+  LLVMIR/LLVMArmSVEIntr.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+
+  # Presumably generates LLVMArmSVEConversions.inc, which LLVMArmSVEIntr.cpp
+  # includes inside convertOperation.
+  DEPENDS
+  MLIRLLVMArmSVEConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMArmSVE
+  MLIRLLVMIR
+  MLIRTargetLLVMIRModuleTranslation
+  )
+
 add_mlir_translation_library(MLIRTargetNVVMIR
   LLVMIR/ConvertToNVVMIR.cpp
 

diff  --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
new file mode 100644
index 000000000000..717583a2d8d7
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp
@@ -0,0 +1,63 @@
+//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR LLVM and LLVMArmSVE
+// dialects and LLVM IR with Arm SVE intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+
+namespace {
+/// Module translation that, in addition to the generic LLVM dialect handling,
+/// converts LLVMArmSVE (llvm_arm_sve.*) dialect ops to LLVM IR.
+class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation {
+  // Grants the base class access to this subclass; presumably needed so the
+  // ModuleTranslation::translateModule<T> factory can construct it (confirm).
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  /// Converts `opInst` to LLVM IR using `builder`. The generated code included
+  /// below handles LLVMArmSVE ops (NOTE(review): presumably returning early on
+  /// a match); anything else is delegated to the base-class conversion.
+  LogicalResult convertOperation(Operation &opInst,
+                                 llvm::IRBuilder<> &builder) override {
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc"
+
+    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+  }
+};
+} // end namespace
+
+/// Translates the MLIR module `m` into a new llvm::Module named `name` that
+/// lives in `llvmContext`, handling llvm_arm_sve ops in addition to the LLVM
+/// dialect. The result is whatever the generic driver returns.
+static std::unique_ptr<llvm::Module>
+translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
+                                  StringRef name) {
+  // Instantiate the generic driver with the ArmSVE-aware subclass.
+  using ArmSVETranslation = LLVMArmSVEModuleTranslation;
+  std::unique_ptr<llvm::Module> llvmModule =
+      LLVM::ModuleTranslation::translateModule<ArmSVETranslation>(
+          m, llvmContext, name);
+  return llvmModule;
+}
+
+namespace mlir {
+/// Registers the `arm-sve-mlir-to-llvmir` translation, which converts a module
+/// of LLVM and LLVMArmSVE dialect ops into textual LLVM IR.
+void registerArmSVEToLLVMIRTranslation() {
+  // Translate `module` and print the resulting LLVM IR to `output`; fails if
+  // the translation produced no module.
+  auto translate = [](ModuleOp module, raw_ostream &output) -> LogicalResult {
+    llvm::LLVMContext llvmContext;
+    std::unique_ptr<llvm::Module> llvmModule =
+        translateLLVMArmSVEModuleToLLVMIR(module, llvmContext,
+                                          "LLVMDialectModule");
+    if (llvmModule == nullptr)
+      return failure();
+    llvmModule->print(output, /*AAW=*/nullptr);
+    return success();
+  };
+  // Dialects registered for this translation's input.
+  auto registerDialects = [](DialectRegistry &registry) {
+    registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
+  };
+  TranslateFromMLIRRegistration reg("arm-sve-mlir-to-llvmir", translate,
+                                    registerDialects);
+}
+} // namespace mlir

diff  --git a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
new file mode 100644
index 000000000000..5f218c9f421a
--- /dev/null
+++ b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
+
+// Each function below uses a single ArmSVE op. With the enable-arm-sve option,
+// -convert-vector-to-llvm is expected to rewrite it into the llvm_arm_sve op of
+// the same name (arm_sve.vector_scale becomes llvm_arm_sve.vscale). Piping the
+// result through a second mlir-opt re-parses it as a sanity step.
+
+func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.sdot
+  %0 = arm_sve.sdot %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.smmla
+  %0 = arm_sve.smmla %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.udot
+  %0 = arm_sve.udot %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.ummla
+  %0 = arm_sve.ummla %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @get_vector_scale() -> index {
+  // CHECK: llvm_arm_sve.vscale
+  %0 = arm_sve.vector_scale : index
+  return %0 : index
+}

diff  --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
new file mode 100644
index 000000000000..8834ef87207b
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// Round-trip test: the first mlir-opt parses and prints each ArmSVE op, the
+// second re-parses that output, and FileCheck verifies the printed form.
+// Some expected type strings are cut short to keep lines narrow; FileCheck
+// patterns match substrings, so the truncated forms still match.
+
+func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+  %0 = arm_sve.sdot %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+  %0 = arm_sve.smmla %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+  %0 = arm_sve.udot %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+  %0 = arm_sve.ummla %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @get_vector_scale() -> index {
+  // CHECK: arm_sve.vector_scale : index
+  %0 = arm_sve.vector_scale : index
+  return %0 : index
+}

diff  --git a/mlir/test/Target/arm-sve.mlir b/mlir/test/Target/arm-sve.mlir
new file mode 100644
index 000000000000..7340fea34be6
--- /dev/null
+++ b/mlir/test/Target/arm-sve.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s
+
+// Each llvm.func wraps a single llvm_arm_sve op on scalable vectors (the `?`
+// in !llvm.vec marks the scalable dimension, printed as `vscale x` in LLVM IR).
+// The translation is expected to emit a call to the matching
+// llvm.aarch64.sve.* intrinsic, or to llvm.vscale for llvm_arm_sve.vscale.
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
+llvm.func @arm_sve_sdot(%arg0: !llvm.vec<? x 16 x i8>,
+                        %arg1: !llvm.vec<? x 16 x i8>,
+                        %arg2: !llvm.vec<? x 4 x i32>)
+                        -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.sdot"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
+llvm.func @arm_sve_smmla(%arg0: !llvm.vec<? x 16 x i8>,
+                         %arg1: !llvm.vec<? x 16 x i8>,
+                         %arg2: !llvm.vec<? x 4 x i32>)
+                         -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.smmla"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
+llvm.func @arm_sve_udot(%arg0: !llvm.vec<? x 16 x i8>,
+                        %arg1: !llvm.vec<? x 16 x i8>,
+                        %arg2: !llvm.vec<? x 4 x i32>)
+                        -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.udot"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
+llvm.func @arm_sve_ummla(%arg0: !llvm.vec<? x 16 x i8>,
+                         %arg1: !llvm.vec<? x 16 x i8>,
+                         %arg2: !llvm.vec<? x 4 x i32>)
+                         -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.ummla"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define i64 @get_vector_scale()
+llvm.func @get_vector_scale() -> !llvm.i64 {
+  // CHECK: call i64 @llvm.vscale.i64()
+  %0 = "llvm_arm_sve.vscale"() : () -> !llvm.i64
+  llvm.return %0 : !llvm.i64
+}


        


More information about the Mlir-commits mailing list