[llvm-branch-commits] [mlir] 7310501 - [mlir][ArmNeon][RFC] Add a Neon dialect

Nicolas Vasilache via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 11 05:55:19 PST 2020


Author: Nicolas Vasilache
Date: 2020-12-11T13:49:40Z
New Revision: 7310501f74037e2845529da7affd8710d058bd04

URL: https://github.com/llvm/llvm-project/commit/7310501f74037e2845529da7affd8710d058bd04
DIFF: https://github.com/llvm/llvm-project/commit/7310501f74037e2845529da7affd8710d058bd04.diff

LOG: [mlir][ArmNeon][RFC] Add a Neon dialect

This revision starts an Arm-specific ArmNeon dialect discussed in the [discourse RFC thread](https://llvm.discourse.group/t/rfc-vector-dialects-neon-and-sve/2284).

Differential Revision: https://reviews.llvm.org/D92171

Added: 
    mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h
    mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
    mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
    mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h
    mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp
    mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt
    mlir/lib/Dialect/ArmNeon/CMakeLists.txt
    mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp
    mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
    mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/ArmNeon/roundtrip.mlir
    mlir/test/Target/arm-neon.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/InitAllDialects.h
    mlir/include/mlir/InitAllTranslations.h
    mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h
    mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/CMakeLists.txt
    mlir/lib/Dialect/LLVMIR/CMakeLists.txt
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/Target/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h b/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h
new file mode 100644
index 000000000000..41342c50d5ea
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h
@@ -0,0 +1,23 @@
+//===- ArmNeonToLLVM.h - Conversion Patterns from ArmNeon to LLVM ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_
+#define MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to convert from the ArmNeon dialect to LLVM.
+void populateArmNeonToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 53158afa0530..56169d90c849 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -396,12 +396,13 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AVX512, Neon, SVE, etc.) in combination with the architectural-neutral
+    (AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral
     vector dialect lowering.
 
   }];
   let constructor = "mlir::createConvertVectorToLLVMPass()";
-  let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
+  // Override explicitly in C++ to allow conditional dialect dependence.
+  // let dependentDialects;
   let options = [
     Option<"reassociateFPReductions", "reassociate-fp-reductions",
            "bool", /*default=*/"false",
@@ -413,6 +414,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
     Option<"enableAVX512", "enable-avx512",
            "bool", /*default=*/"false",
            "Enables the use of AVX512 dialect while lowering the vector "
+	   "dialect.">,
+    Option<"enableArmNeon", "enable-arm-neon",
+           "bool", /*default=*/"false",
+           "Enables the use of ArmNeon dialect while lowering the vector "
 	   "dialect.">
   ];
 }

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index e6a515aa4564..7ff061cb9d09 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,7 +23,7 @@ class OperationPass;
 struct LowerVectorToLLVMOptions {
   LowerVectorToLLVMOptions()
       : reassociateFPReductions(false), enableIndexOptimizations(true),
-        enableAVX512(false) {}
+        enableArmNeon(false), enableAVX512(false) {}
 
   LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
     reassociateFPReductions = b;
@@ -37,9 +37,14 @@ struct LowerVectorToLLVMOptions {
     enableAVX512 = b;
     return *this;
   }
+  LowerVectorToLLVMOptions &setEnableArmNeon(bool b) {
+    enableArmNeon = b;
+    return *this;
+  }
 
   bool reassociateFPReductions;
   bool enableIndexOptimizations;
+  bool enableArmNeon;
   bool enableAVX512;
 };
 

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
new file mode 100644
index 000000000000..f38049be949f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -0,0 +1,60 @@
+//===-- ArmNeon.td - ArmNeon dialect op definitions --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the ArmNeon dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMNEON_OPS
+#define ARMNEON_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// ArmNeon dialect definition
+//===----------------------------------------------------------------------===//
+
+def ArmNeon_Dialect : Dialect {
+  let name = "arm_neon";
+  let cppNamespace = "::mlir::arm_neon";
+}
+
+//===----------------------------------------------------------------------===//
+// ArmNeon op definitions
+//===----------------------------------------------------------------------===//
+
+class ArmNeon_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<ArmNeon_Dialect, mnemonic, traits> {}
+
+def SMullOp : ArmNeon_Op<"smull", [NoSideEffect,
+  AllTypesMatch<["a", "b"]>,
+  TypesMatchWith<
+    "res has same vector shape and element bitwidth scaled by 2 as a",
+    "a", "res", "$_self.cast<VectorType>().scaleElementBitwidth(2)">]> {
+  let summary = "smull op";
+  let description = [{
+    Signed Multiply Long (vector). This instruction multiplies corresponding
+    signed integer values in the lower or upper half of the vectors of the two
+    source SIMD&FP registers, places the results in a vector, and writes the
+    vector to the destination SIMD&FP register.
+
+    Source:
+    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
+  }];
+  // Supports one of:
+  //   (vector<8xi8>, vector<8xi8>) -> (vector<8xi16>)
+  //   (vector<4xi16>, vector<4xi16>) -> (vector<4xi32>)
+  //   (vector<2xi32>, vector<2xi32>) -> (vector<2xi64>)
+  let arguments = (ins VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$a,
+                       VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$b);
+  let results = (outs VectorOfLengthAndType<[8, 4, 2], [I16, I32, I64]>:$res);
+  let assemblyFormat =
+    "$a `,` $b attr-dict `:` type($a) `to` type($res)";
+}
+
+#endif // ARMNEON_OPS

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
new file mode 100644
index 000000000000..76153af97689
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
@@ -0,0 +1,25 @@
+//===- ArmNeonDialect.h - MLIR Dialect for ArmNeon --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmNeon in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
+#define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/ArmNeon.h.inc"
+
+#endif // MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 000000000000..46c79d373743
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(ArmNeon arm_neon)
+add_mlir_doc(ArmNeon -gen-dialect-doc ArmNeon Dialects/)

diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 09c6ae569c18..0df95ea4e937 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(Affine)
 add_subdirectory(Async)
+add_subdirectory(ArmNeon)
 add_subdirectory(AVX512)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index cc4fd1bafc72..809e4abe7e84 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -8,25 +8,32 @@ mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRLLVMOpsIncGen)
 
-add_mlir_dialect(NVVMOps nvvm)
-add_mlir_doc(NVVMOps -gen-dialect-doc NVVMDialect Dialects/)
-add_mlir_dialect(ROCDLOps rocdl)
-add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/)
-
 set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
 mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
 mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMConversionsIncGen)
+
+add_mlir_dialect(NVVMOps nvvm)
+add_mlir_doc(NVVMOps -gen-dialect-doc NVVMDialect Dialects/)
 set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRNVVMConversionsIncGen)
+
+add_mlir_dialect(ROCDLOps rocdl)
+add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/)
 set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
 mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRROCDLConversionsIncGen)
 
 add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512)
-
+add_mlir_doc(LLVMAVX512 -gen-dialect-doc LLVMAVX512 Dialects/)
 set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td)
 mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen)
+
+add_mlir_dialect(LLVMArmNeon llvm_arm_neon LLVMArmNeon)
+add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/)
+set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td)
+mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td
new file mode 100644
index 000000000000..f15c77451cbe
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td
@@ -0,0 +1,43 @@
+//===-- LLVMArmNeon.td - LLVMArmNeon dialect op definitions *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMArmNeon dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_ARMNEON_OPS
+#define LLVMIR_ARMNEON_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMArmNeon dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMArmNeon_Dialect : Dialect {
+  let name = "llvm_arm_neon";
+  let cppNamespace = "::mlir::LLVM";
+}
+
+//----------------------------------------------------------------------------//
+// MLIR LLVMArmNeon intrinsics using the MLIR LLVM Dialect type system
+//----------------------------------------------------------------------------//
+
+class LLVMArmNeon_IntrBinaryOverloadedOp<string mnemonic, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmNeon_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/"aarch64_neon_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+def LLVM_aarch64_arm_neon_smull :
+  LLVMArmNeon_IntrBinaryOverloadedOp<"smull">, Arguments<(ins LLVM_Type, LLVM_Type)>;
+
+#endif // LLVMIR_ARMNEON_OPS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h
new file mode 100644
index 000000000000..4fa8c9796f89
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h
@@ -0,0 +1,24 @@
+//===- LLVMArmNeonDialect.h - MLIR Dialect for LLVMArmNeon ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for LLVMArmNeon in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmNeon.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h.inc"
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index ea231fb1d27d..8ce5e4045a3a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -150,6 +150,11 @@ class IntegerType
   /// Return true if this is an unsigned integer type.
   bool isUnsigned() const { return getSignedness() == Unsigned; }
 
+  /// Get or create a new IntegerType with the same signedness as `this` and a
+  /// bitwidth scaled by `scale`.
+  /// Return null if the scaled element type cannot be represented.
+  IntegerType scaleElementBitwidth(unsigned scale);
+
   /// Integer representation maximal bitwidth.
   static constexpr unsigned kMaxWidth = 4096;
 };
@@ -174,6 +179,10 @@ class FloatType : public Type {
   /// Return the bitwidth of this float type.
   unsigned getWidth();
 
+  /// Get or create a new FloatType with bitwidth scaled by `scale`.
+  /// Return null if the scaled element type cannot be represented.
+  FloatType scaleElementBitwidth(unsigned scale);
+
   /// Return the floating semantics of this float type.
   const llvm::fltSemantics &getFloatSemantics();
 };
@@ -433,6 +442,11 @@ class VectorType
   }
 
   ArrayRef<int64_t> getShape() const;
+
+  /// Get or create a new VectorType with the same shape as `this` and an
+  /// element type of bitwidth scaled by `scale`.
+  /// Return null if the scaled element type cannot be represented.
+  VectorType scaleElementBitwidth(unsigned scale);
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index f0adcec2e664..3eb9fdd69c6c 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -16,9 +16,11 @@
 
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -44,11 +46,13 @@ inline void registerAllDialects(DialectRegistry &registry) {
   // clang-format off
   registry.insert<acc::OpenACCDialect,
                   AffineDialect,
+                  arm_neon::ArmNeonDialect,
                   async::AsyncDialect,
                   avx512::AVX512Dialect,
                   gpu::GPUDialect,
                   LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,
+                  LLVM::LLVMArmNeonDialect,
                   linalg::LinalgDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,

diff  --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index a1771dab144c..cafc931c2d9f 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -22,6 +22,7 @@ void registerToLLVMIRTranslation();
 void registerToSPIRVTranslation();
 void registerToNVVMIRTranslation();
 void registerToROCDLIRTranslation();
+void registerArmNeonToLLVMIRTranslation();
 void registerAVX512ToLLVMIRTranslation();
 
 // This function should be called before creating any MLIRContext if one
@@ -35,6 +36,7 @@ inline void registerAllTranslations() {
     registerToSPIRVTranslation();
     registerToNVVMIRTranslation();
     registerToROCDLIRTranslation();
+    registerArmNeonToLLVMIRTranslation();
     registerAVX512ToLLVMIRTranslation();
     return true;
   }();

diff  --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 06a19b057f71..9e1603f083ee 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -22,111 +22,59 @@ using namespace mlir::vector;
 using namespace mlir::avx512;
 
 template <typename OpTy>
-static Type getSrcVectorElementType(OpTy op) {
-  return op.src().getType().template cast<VectorType>().getElementType();
-}
-
-// TODO: Code is currently copy-pasted and adapted from the code
-// 1-1 LLVM conversion. It would better if it were properly exposed in core and
-// reusable.
-/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
-/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
-/// operands as is, preserve attributes.
-template <typename SourceOp, typename TargetOp>
-static LogicalResult
-matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op,
-                        ArrayRef<Value> operands,
-                        ConversionPatternRewriter &rewriter) {
-  unsigned numResults = op->getNumResults();
-
-  Type packedType;
-  if (numResults != 0) {
-    packedType = typeConverter.packFunctionResults(op->getResultTypes());
-    if (!packedType)
-      return failure();
-  }
-
-  auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
-                                         op->getAttrs());
-
-  // If the operation produced 0 or 1 result, return them immediately.
-  if (numResults == 0)
-    return rewriter.eraseOp(op), success();
-  if (numResults == 1)
-    return rewriter.replaceOp(op, newOp->getResult(0)), success();
-
-  // Otherwise, it had been converted to an operation producing a structure.
-  // Extract individual results from the structure and return them as list.
-  SmallVector<Value, 4> results;
-  results.reserve(numResults);
-  for (unsigned i = 0; i < numResults; ++i) {
-    auto type = typeConverter.convertType(op->getResult(i).getType());
-    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
-        op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
-  }
-  rewriter.replaceOp(op, results);
-  return success();
+static Type getSrcVectorElementType(Operation *op) {
+  return cast<OpTy>(op)
+      .src()
+      .getType()
+      .template cast<VectorType>()
+      .getElementType();
 }
 
 namespace {
-// TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
-// MaskRndScaleOp) and different possible target ops. It would be better to take
-// a Functor so that all these conversions become 1-liners.
-struct MaskRndScaleOpPS512Conversion
-    : public ConvertOpToLLVMPattern<MaskRndScaleOp> {
-  using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!getSrcVectorElementType(op).isF32())
-      return failure();
-    return matchAndRewriteOneToOne<MaskRndScaleOp,
-                                   LLVM::x86_avx512_mask_rndscale_ps_512>(
-        *getTypeConverter(), op, operands, rewriter);
-  }
-};
-
-struct MaskRndScaleOpPD512Conversion
-    : public ConvertOpToLLVMPattern<MaskRndScaleOp> {
-  using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!getSrcVectorElementType(op).isF64())
-      return failure();
-    return matchAndRewriteOneToOne<MaskRndScaleOp,
-                                   LLVM::x86_avx512_mask_rndscale_pd_512>(
-        *getTypeConverter(), op, operands, rewriter);
-  }
-};
 
-struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
-  using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
+// TODO: turn these into simpler declarative templated patterns when we've had
+// enough.
+struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern {
+  explicit MaskRndScaleOp512Conversion(MLIRContext *context,
+                                       LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
+                             typeConverter) {}
 
   LogicalResult
-  matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!getSrcVectorElementType(op).isF32())
-      return failure();
-    return matchAndRewriteOneToOne<MaskScaleFOp,
-                                   LLVM::x86_avx512_mask_scalef_ps_512>(
-        *getTypeConverter(), op, operands, rewriter);
+    Type elementType = getSrcVectorElementType<MaskRndScaleOp>(op);
+    if (elementType.isF32())
+      return LLVM::detail::oneToOneRewrite(
+          op, LLVM::x86_avx512_mask_rndscale_ps_512::getOperationName(),
+          operands, *getTypeConverter(), rewriter);
+    if (elementType.isF64())
+      return LLVM::detail::oneToOneRewrite(
+          op, LLVM::x86_avx512_mask_rndscale_pd_512::getOperationName(),
+          operands, *getTypeConverter(), rewriter);
+    return failure();
   }
 };
 
-struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
-  using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
+struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
+  explicit ScaleFOp512Conversion(MLIRContext *context,
+                                 LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
+                             typeConverter) {}
 
   LogicalResult
-  matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!getSrcVectorElementType(op).isF64())
-      return failure();
-    return matchAndRewriteOneToOne<MaskScaleFOp,
-                                   LLVM::x86_avx512_mask_scalef_pd_512>(
-        *getTypeConverter(), op, operands, rewriter);
+    Type elementType = getSrcVectorElementType<MaskScaleFOp>(op);
+    if (elementType.isF32())
+      return LLVM::detail::oneToOneRewrite(
+          op, LLVM::x86_avx512_mask_scalef_ps_512::getOperationName(), operands,
+          *getTypeConverter(), rewriter);
+    if (elementType.isF64())
+      return LLVM::detail::oneToOneRewrite(
+          op, LLVM::x86_avx512_mask_scalef_pd_512::getOperationName(), operands,
+          *getTypeConverter(), rewriter);
+    return failure();
   }
 };
 } // namespace
@@ -135,9 +83,7 @@ struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
 void mlir::populateAVX512ToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // clang-format off
-  patterns.insert<MaskRndScaleOpPS512Conversion,
-                  MaskRndScaleOpPD512Conversion,
-                  ScaleFOpPS512Conversion,
-                  ScaleFOpPD512Conversion>(converter);
+  patterns.insert<MaskRndScaleOp512Conversion,
+                  ScaleFOp512Conversion>(&converter.getContext(), converter);
   // clang-format on
 }

diff  --git a/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp b/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp
new file mode 100644
index 000000000000..c3c815591df7
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp
@@ -0,0 +1,31 @@
+//===- ArmNeonToLLVM.cpp - ArmNeon to the LLVM dialect --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::arm_neon;
+
+using SMullOpLowering =
+    OneToOneConvertToLLVMPattern<SMullOp, LLVM::aarch64_arm_neon_smull>;
+
+/// Populate the given list with patterns that convert from ArmNeon to LLVM.
+void mlir::populateArmNeonToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  patterns.insert<SMullOpLowering>(converter);
+}

diff  --git a/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..de028f6322a6
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmNeonToLLVM
+  ArmNeonToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeonToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeon
+  MLIRLLVMArmNeon
+  MLIRLLVMIR
+  MLIRStandardToLLVM
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bf1789505882..a0195486cfd6 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(AffineToStandard)
+add_subdirectory(ArmNeonToLLVM)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(AVX512ToLLVM)
 add_subdirectory(GPUCommon)

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 6314a5c91d0b..dd69924166a5 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -25,8 +25,9 @@ class GPUModuleOp;
 } // end namespace gpu
 
 namespace LLVM {
-class LLVMDialect;
+class LLVMArmNeonDialect;
 class LLVMAVX512Dialect;
+class LLVMDialect;
 } // end namespace LLVM
 
 namespace NVVM {

diff  --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index c53839714af7..6d7f7aa04d52 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -13,8 +13,11 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRArmNeon
+  MLIRArmNeonToLLVM
   MLIRAVX512
   MLIRAVX512ToLLVM
+  MLIRLLVMArmNeon
   MLIRLLVMAVX512
   MLIRLLVMIR
   MLIRStandardToLLVM

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4d5576ec9d9f..99f0bae05406 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -11,10 +11,13 @@
 #include "../PassDetail.h"
 
 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
+#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,8 +31,17 @@ struct LowerVectorToLLVMPass
   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
     this->reassociateFPReductions = options.reassociateFPReductions;
     this->enableIndexOptimizations = options.enableIndexOptimizations;
+    this->enableArmNeon = options.enableArmNeon;
     this->enableAVX512 = options.enableAVX512;
   }
+  // Overridden explicitly so the dependent dialects follow the pass options:
+  // the LLVM dialect is always registered, while the LLVMArmNeon and
+  // LLVMAVX512 dialects are registered only when their option is enabled.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+    if (enableArmNeon)
+      registry.insert<LLVM::LLVMArmNeonDialect>();
+    if (enableAVX512)
+      registry.insert<LLVM::LLVMAVX512Dialect>();
+  }
   void runOnOperation() override;
 };
 } // namespace
@@ -56,6 +68,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
 
   // Architecture specific augmentations.
   LLVMConversionTarget target(getContext());
+  if (enableArmNeon) {
+    target.addLegalDialect<LLVM::LLVMArmNeonDialect>();
+    target.addIllegalDialect<arm_neon::ArmNeonDialect>();
+    populateArmNeonToLLVMConversionPatterns(converter, patterns);
+  }
   if (enableAVX512) {
     target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
     target.addIllegalDialect<avx512::AVX512Dialect>();

diff  --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 000000000000..12dda1aa5fe5
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeon
+  IR/ArmNeonDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+  DEPENDS
+  MLIRArmNeonIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )

diff  --git a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
new file mode 100644
index 000000000000..b8b8ebd35e68
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
@@ -0,0 +1,29 @@
+//===- ArmNeonDialect.cpp - MLIR ArmNeon ops implementation ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmNeon dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+/// Registers all ArmNeon operations with the dialect. The generated
+/// ArmNeon.cpp.inc (see the MLIRArmNeonIncGen dependency) is expanded with
+/// GET_OP_LIST defined directly inside the template argument list of
+/// addOperations<...>(), so in that mode it must yield the op class list.
+/// Do not reorder the #define / #include lines.
+void arm_neon::ArmNeonDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc"

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index bc44049e2ef6..252b05cf2664 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Affine)
+add_subdirectory(ArmNeon)
 add_subdirectory(Async)
 add_subdirectory(AVX512)
 add_subdirectory(GPU)

diff  --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index db1a5c4c80aa..87ad7e965d2e 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -49,6 +49,27 @@ add_mlir_dialect_library(MLIRLLVMAVX512
   MLIRSideEffectInterfaces
   )
 
+add_mlir_dialect_library(MLIRLLVMArmNeon
+  IR/LLVMArmNeonDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+  DEPENDS
+  MLIRLLVMArmNeonIncGen
+  MLIRLLVMArmNeonConversionsIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  AsmParser
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )
+
 add_mlir_dialect_library(MLIRNVVMIR
   IR/NVVMDialect.cpp
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp
new file mode 100644
index 000000000000..d05b6584c39f
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp
@@ -0,0 +1,31 @@
+//===- LLVMArmNeonDialect.cpp - MLIR LLVMArmNeon ops implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMArmNeon dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+/// Registers all LLVMArmNeon operations with the dialect. The generated
+/// LLVMArmNeon.cpp.inc (see the MLIRLLVMArmNeonIncGen dependency) is expanded
+/// with GET_OP_LIST defined directly inside the template argument list of
+/// addOperations<...>(), so in that mode it must yield the op class list.
+/// Do not reorder the #define / #include lines.
+void LLVM::LLVMArmNeonDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc"

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 50a5a64da69e..aa4eba07cf53 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -66,6 +66,12 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
   return getImpl()->signedness;
 }
 
+/// Returns an IntegerType with the same signedness whose width is this type's
+/// width multiplied by `scale`. Returns a null type when `scale` is zero or
+/// when the scaled width would exceed IntegerType::kMaxWidth; the latter check
+/// also rules out unsigned wrap-around of `scale * width`.
+IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
+  if (!scale)
+    return IntegerType();
+  unsigned width = getWidth();
+  // Compare through a division so the bound check itself cannot overflow.
+  if (width != 0 && scale > IntegerType::kMaxWidth / width)
+    return IntegerType();
+  return IntegerType::get(scale * width, getSignedness(), getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Float Type
 //===----------------------------------------------------------------------===//
@@ -93,6 +99,22 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
   llvm_unreachable("non-floating point type used");
 }
 
+/// Returns the float type whose bitwidth is this type's bitwidth multiplied by
+/// `scale`, or a null type when no such type exists. Supported results:
+///   - any type with `scale == 1` (identity),
+///   - f16/bf16 scaled by 2 -> f32 and by 4 -> f64,
+///   - f32 scaled by 2 -> f64.
+/// `scale == 0` always yields a null type.
+FloatType FloatType::scaleElementBitwidth(unsigned scale) {
+  if (!scale)
+    return FloatType();
+  // Scaling by one keeps the width and is always representable; this matches
+  // the IntegerType overload, which also returns the same width for 1.
+  if (scale == 1)
+    return *this;
+  MLIRContext *ctx = getContext();
+  if (isF16() || isBF16()) {
+    if (scale == 2)
+      return FloatType::getF32(ctx);
+    if (scale == 4)
+      return FloatType::getF64(ctx);
+  }
+  if (isF32() && scale == 2)
+    return FloatType::getF64(ctx);
+  return FloatType();
+}
+
 //===----------------------------------------------------------------------===//
 // FunctionType
 //===----------------------------------------------------------------------===//
@@ -306,6 +328,18 @@ LogicalResult VectorType::verifyConstructionInvariants(Location loc,
 
 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
 
+/// Returns a vector of the same shape whose element bitwidth is scaled by
+/// `scale`. Only integer and float element types are handled; a null type is
+/// returned for `scale == 0`, for other element types, or when the element
+/// type itself cannot be scaled.
+VectorType VectorType::scaleElementBitwidth(unsigned scale) {
+  if (scale == 0)
+    return VectorType();
+  Type elementType = getElementType();
+  Type scaledElementType;
+  if (auto intType = elementType.dyn_cast<IntegerType>())
+    scaledElementType = intType.scaleElementBitwidth(scale);
+  else if (auto floatType = elementType.dyn_cast<FloatType>())
+    scaledElementType = floatType.scaleElementBitwidth(scale);
+  // Unsupported element kinds and unscalable elements both leave this null.
+  if (!scaledElementType)
+    return VectorType();
+  return VectorType::get(getShape(), scaledElementType);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index cdc9e2db9cd1..96568438a0a3 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -55,6 +55,25 @@ add_mlir_translation_library(MLIRTargetLLVMIR
   MLIRTargetLLVMIRModuleTranslation
   )
 
+add_mlir_translation_library(MLIRTargetArmNeon
+  LLVMIR/LLVMArmNeonIntr.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+
+  DEPENDS
+  MLIRLLVMArmNeonConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMArmNeon
+  MLIRLLVMIR
+  MLIRTargetLLVMIRModuleTranslation
+  )
+
 add_mlir_translation_library(MLIRTargetNVVMIR
   LLVMIR/ConvertToNVVMIR.cpp
 

diff  --git a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
new file mode 100644
index 000000000000..0bd40ef3c80c
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
@@ -0,0 +1,63 @@
+//===- LLVMArmNeonIntr.cpp - Translate LLVMArmNeon ops to LLVM IR ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation from the MLIR LLVM and LLVMArmNeon
+// dialects to LLVM IR with AArch64 Neon intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+
+namespace {
+/// Module translation that extends LLVM::ModuleTranslation with the generated
+/// conversions for LLVMArmNeon dialect operations.
+class LLVMArmNeonModuleTranslation : public LLVM::ModuleTranslation {
+  // NOTE(review): presumably needed so the base class's translateModule<>()
+  // helper can reach this subclass's non-public members; confirm.
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  /// Runs the generated LLVMArmNeon conversions first, then defers to the base
+  /// class translation.
+  LogicalResult convertOperation(Operation &opInst,
+                                 llvm::IRBuilder<> &builder) override {
+    // NOTE(review): the generated file is expanded in place; it presumably
+    // refers to `opInst` and `builder` by name and returns early for the ops
+    // it converts, so keep these parameter names unchanged.
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonConversions.inc"
+
+    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+  }
+};
+
+/// Translates `m` into a new llvm::Module named `name` in `llvmContext` using
+/// LLVMArmNeonModuleTranslation. The registered translation below treats a
+/// null result as failure.
+std::unique_ptr<llvm::Module>
+translateLLVMArmNeonModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
+                                   StringRef name) {
+  return LLVM::ModuleTranslation::translateModule<LLVMArmNeonModuleTranslation>(
+      m, llvmContext, name);
+}
+} // end namespace
+
+namespace mlir {
+/// Registers the "arm-neon-mlir-to-llvmir" translation. It converts a module
+/// of LLVM and LLVMArmNeon dialect ops to an llvm::Module and prints it.
+void registerArmNeonToLLVMIRTranslation() {
+  // Translation callback: build the LLVM module and print it to `output`.
+  auto toLLVMIR = [](ModuleOp module, raw_ostream &output) -> LogicalResult {
+    llvm::LLVMContext llvmContext;
+    std::unique_ptr<llvm::Module> llvmModule =
+        translateLLVMArmNeonModuleToLLVMIR(module, llvmContext,
+                                           "LLVMDialectModule");
+    if (!llvmModule)
+      return failure();
+    llvmModule->print(output, /*AAW=*/nullptr);
+    return success();
+  };
+  // Dialect registration hook supplied alongside the translation.
+  auto registerDialects = [](DialectRegistry &registry) {
+    registry.insert<LLVM::LLVMArmNeonDialect, LLVM::LLVMDialect>();
+  };
+  TranslateFromMLIRRegistration registration("arm-neon-mlir-to-llvmir",
+                                             toLLVMIR, registerDialects);
+}
+} // namespace mlir

diff  --git a/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir
new file mode 100644
index 000000000000..fe56052fe734
--- /dev/null
+++ b/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-neon" | mlir-opt | FileCheck %s
+
+// Lowers a chain of arm_neon.smull ops (i8->i16, i16->i32, i32->i64) and
+// verifies each one now operates on !llvm.vec operand and result types.
+// CHECK-LABEL: arm_neon_smull
+func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
+    -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
+  // CHECK: arm_neon.smull{{.*}}: (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16>
+  %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16>
+  %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
+    vector<8xi16> to vector<4xi16>
+
+  // CHECK: arm_neon.smull{{.*}}: (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32>
+  %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
+  %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
+    vector<4xi32> to vector<2xi32>
+
+  // CHECK: arm_neon.smull{{.*}}: (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64>
+  %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
+
+  return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
+}

diff  --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
new file mode 100644
index 000000000000..014da313a089
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// Round-trips arm_neon.smull through the parser and printer for each
+// supported widening (i8->i16, i16->i32, i32->i64).
+// CHECK-LABEL: arm_neon_smull
+func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
+    -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
+  // CHECK: arm_neon.smull {{.*}}: vector<8xi8> to vector<8xi16>
+  %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16>
+  %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
+    vector<8xi16> to vector<4xi16>
+
+  // CHECK: arm_neon.smull {{.*}}: vector<4xi16> to vector<4xi32>
+  %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
+  %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
+    vector<4xi32> to vector<2xi32>
+
+  // CHECK: arm_neon.smull {{.*}}: vector<2xi32> to vector<2xi64>
+  %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
+
+  return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
+}

diff  --git a/mlir/test/Target/arm-neon.mlir b/mlir/test/Target/arm-neon.mlir
new file mode 100644
index 000000000000..955b4aeb40a0
--- /dev/null
+++ b/mlir/test/Target/arm-neon.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate -arm-neon-mlir-to-llvmir | FileCheck %s
+
+// Translates a chain of llvm_arm_neon.smull ops to LLVM IR and verifies that
+// each becomes a call to the matching llvm.aarch64.neon.smull.* intrinsic.
+// CHECK-LABEL: arm_neon_smull
+llvm.func @arm_neon_smull(%arg0: !llvm.vec<8 x i8>, %arg1: !llvm.vec<8 x i8>) -> !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> {
+  //      CHECK: %[[V0:.*]] = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %{{.*}}, <8 x i8> %{{.*}})
+  // CHECK-NEXT: %[[V00:.*]] = shufflevector <8 x i16> %[[V0]], <8 x i16> %[[V0]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
+  %0 = "llvm_arm_neon.smull"(%arg0, %arg1) : (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16>
+  %1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : !llvm.vec<8 x i16>, !llvm.vec<8 x i16>
+
+  // CHECK-NEXT: %[[V1:.*]] = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %[[V00]], <4 x i16> %[[V00]])
+  // CHECK-NEXT: %[[V11:.*]] = shufflevector <4 x i32> %[[V1]], <4 x i32> %[[V1]], <2 x i32> <i32 1, i32 2>
+  %2 = "llvm_arm_neon.smull"(%1, %1) : (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32>
+  %3 = llvm.shufflevector %2, %2 [1, 2] : !llvm.vec<4 x i32>, !llvm.vec<4 x i32>
+
+  // CHECK-NEXT: %[[V2:.*]] = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %[[V11]], <2 x i32> %[[V11]])
+  %4 = "llvm_arm_neon.smull"(%3, %3) : (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64>
+
+  %5 = llvm.mlir.undef : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+  %6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+  %7 = llvm.insertvalue %2, %6[1] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+  %8 = llvm.insertvalue %4, %7[2] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+
+  //      CHECK: ret { <8 x i16>, <4 x i32>, <2 x i64> }
+  llvm.return %8 : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+}


        


More information about the llvm-branch-commits mailing list