[Mlir-commits] [mlir] 7310501 - [mlir][ArmNeon][RFC] Add a Neon dialect
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Dec 11 05:50:46 PST 2020
Author: Nicolas Vasilache
Date: 2020-12-11T13:49:40Z
New Revision: 7310501f74037e2845529da7affd8710d058bd04
URL: https://github.com/llvm/llvm-project/commit/7310501f74037e2845529da7affd8710d058bd04
DIFF: https://github.com/llvm/llvm-project/commit/7310501f74037e2845529da7affd8710d058bd04.diff
LOG: [mlir][ArmNeon][RFC] Add a Neon dialect
This revision starts an Arm-specific ArmNeon dialect discussed in the [discourse RFC thread](https://llvm.discourse.group/t/rfc-vector-dialects-neon-and-sve/2284).
Differential Revision: https://reviews.llvm.org/D92171
Added:
mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h
mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td
mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h
mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp
mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt
mlir/lib/Dialect/ArmNeon/CMakeLists.txt
mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp
mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/ArmNeon/roundtrip.mlir
mlir/test/Target/arm-neon.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllTranslations.h
mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Target/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h b/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h
new file mode 100644
index 000000000000..41342c50d5ea
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h
@@ -0,0 +1,23 @@
+//===- ArmNeonToLLVM.h - Conversion Patterns from ArmNeon to LLVM ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_
+#define MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+
+/// Collect a set of patterns to convert from the ArmNeon dialect to LLVM.
+void populateArmNeonToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 53158afa0530..56169d90c849 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -396,12 +396,13 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
- (AVX512, Neon, SVE, etc.) in combination with the architectural-neutral
+ (AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral
vector dialect lowering.
}];
let constructor = "mlir::createConvertVectorToLLVMPass()";
- let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
+ // Override explicitly in C++ to allow conditional dialect dependence.
+ // let dependentDialects;
let options = [
Option<"reassociateFPReductions", "reassociate-fp-reductions",
"bool", /*default=*/"false",
@@ -413,6 +414,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
Option<"enableAVX512", "enable-avx512",
"bool", /*default=*/"false",
"Enables the use of AVX512 dialect while lowering the vector "
+ "dialect.">,
+ Option<"enableArmNeon", "enable-arm-neon",
+ "bool", /*default=*/"false",
+ "Enables the use of ArmNeon dialect while lowering the vector "
"dialect.">
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index e6a515aa4564..7ff061cb9d09 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,7 +23,7 @@ class OperationPass;
struct LowerVectorToLLVMOptions {
LowerVectorToLLVMOptions()
: reassociateFPReductions(false), enableIndexOptimizations(true),
- enableAVX512(false) {}
+ enableArmNeon(false), enableAVX512(false) {}
LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
reassociateFPReductions = b;
@@ -37,9 +37,14 @@ struct LowerVectorToLLVMOptions {
enableAVX512 = b;
return *this;
}
+ LowerVectorToLLVMOptions &setEnableArmNeon(bool b) {
+ enableArmNeon = b;
+ return *this;
+ }
bool reassociateFPReductions;
bool enableIndexOptimizations;
+ bool enableArmNeon;
bool enableAVX512;
};
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
new file mode 100644
index 000000000000..f38049be949f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -0,0 +1,60 @@
+//===-- ArmNeon.td - ArmNeon dialect op definitions --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the ArmNeon dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMNEON_OPS
+#define ARMNEON_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// ArmNeon dialect definition
+//===----------------------------------------------------------------------===//
+
+def ArmNeon_Dialect : Dialect {
+ let name = "arm_neon";
+ let cppNamespace = "::mlir::arm_neon";
+}
+
+//===----------------------------------------------------------------------===//
+// ArmNeon op definitions
+//===----------------------------------------------------------------------===//
+
+class ArmNeon_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<ArmNeon_Dialect, mnemonic, traits> {}
+
+def SMullOp : ArmNeon_Op<"smull", [NoSideEffect,
+ AllTypesMatch<["a", "b"]>,
+ TypesMatchWith<
+ "res has same vector shape and element bitwidth scaled by 2 as a",
+ "a", "res", "$_self.cast<VectorType>().scaleElementBitwidth(2)">]> {
+ let summary = "smull op";
+ let description = [{
+ Signed Multiply Long (vector). This instruction multiplies corresponding
+ signed integer values in the lower or upper half of the vectors of the two
+ source SIMD&FP registers, places the results in a vector, and writes the
+ vector to the destination SIMD&FP register.
+
+ Source:
+ https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
+ }];
+ // Supports either:
+ // (vector<8xi8>, vector<8xi8>) -> (vector<8xi16>)
+ // (vector<4xi16>, vector<4xi16>) -> (vector<4xi32>)
+ // (vector<2xi32>, vector<2xi32>) -> (vector<2xi64>)
+ let arguments = (ins VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$a,
+ VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$b);
+ let results = (outs VectorOfLengthAndType<[8, 4, 2], [I16, I32, I64]>:$res);
+ let assemblyFormat =
+ "$a `,` $b attr-dict `:` type($a) `to` type($res)";
+}
+
+#endif // ARMNEON_OPS
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
new file mode 100644
index 000000000000..76153af97689
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h
@@ -0,0 +1,25 @@
+//===- ArmNeonDialect.h - MLIR Dialect for ArmNeon --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmNeon in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
+#define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/ArmNeon.h.inc"
+
+#endif // MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_
diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 000000000000..46c79d373743
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(ArmNeon arm_neon)
+add_mlir_doc(ArmNeon -gen-dialect-doc ArmNeon Dialects/)
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 09c6ae569c18..0df95ea4e937 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
add_subdirectory(Affine)
add_subdirectory(Async)
+add_subdirectory(ArmNeon)
add_subdirectory(AVX512)
add_subdirectory(GPU)
add_subdirectory(Linalg)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index cc4fd1bafc72..809e4abe7e84 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -8,25 +8,32 @@ mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
-add_mlir_dialect(NVVMOps nvvm)
-add_mlir_doc(NVVMOps -gen-dialect-doc NVVMDialect Dialects/)
-add_mlir_dialect(ROCDLOps rocdl)
-add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/)
-
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMConversionsIncGen)
+
+add_mlir_dialect(NVVMOps nvvm)
+add_mlir_doc(NVVMOps -gen-dialect-doc NVVMDialect Dialects/)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRNVVMConversionsIncGen)
+
+add_mlir_dialect(ROCDLOps rocdl)
+add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/)
set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRROCDLConversionsIncGen)
add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512)
-
+add_mlir_doc(LLVMAVX512 -gen-dialect-doc LLVMAVX512 Dialects/)
set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td)
mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen)
+
+add_mlir_dialect(LLVMArmNeon llvm_arm_neon LLVMArmNeon)
+add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/)
+set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td)
+mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td
new file mode 100644
index 000000000000..f15c77451cbe
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td
@@ -0,0 +1,43 @@
+//===-- LLVMArmNeon.td - LLVMArmNeon dialect op definitions *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMArmNeon dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_ARMNEON_OPS
+#define LLVMIR_ARMNEON_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMArmNeon dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMArmNeon_Dialect : Dialect {
+ let name = "llvm_arm_neon";
+ let cppNamespace = "::mlir::LLVM";
+}
+
+//----------------------------------------------------------------------------//
+// MLIR LLVMArmNeon intrinsics using the MLIR LLVM Dialect type system
+//----------------------------------------------------------------------------//
+
+class LLVMArmNeon_IntrBinaryOverloadedOp<string mnemonic, list<OpTrait> traits = []> :
+ LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmNeon_Dialect,
+ /*string opName=*/mnemonic,
+ /*string enumName=*/"aarch64_neon_" # !subst(".", "_", mnemonic),
+ /*list<int> overloadedResults=*/[0],
+ /*list<int> overloadedOperands=*/[], // defined by result overload
+ /*list<OpTrait> traits=*/traits,
+ /*int numResults=*/1>;
+
+def LLVM_aarch64_arm_neon_smull :
+ LLVMArmNeon_IntrBinaryOverloadedOp<"smull">, Arguments<(ins LLVM_Type, LLVM_Type)>;
+
+#endif // LLVMIR_ARMNEON_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h
new file mode 100644
index 000000000000..4fa8c9796f89
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h
@@ -0,0 +1,24 @@
+//===- LLVMArmNeonDialect.h - MLIR Dialect for LLVMArmNeon ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for LLVMArmNeon in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmNeon.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h.inc"
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index ea231fb1d27d..8ce5e4045a3a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -150,6 +150,11 @@ class IntegerType
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }
+ /// Get or create a new IntegerType with the same signedness as `this` and a
+ /// bitwidth scaled by `scale`.
+ /// Return null if the scaled element type cannot be represented.
+ IntegerType scaleElementBitwidth(unsigned scale);
+
/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096;
};
@@ -174,6 +179,10 @@ class FloatType : public Type {
/// Return the bitwidth of this float type.
unsigned getWidth();
+ /// Get or create a new FloatType with bitwidth scaled by `scale`.
+ /// Return null if the scaled element type cannot be represented.
+ FloatType scaleElementBitwidth(unsigned scale);
+
/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics();
};
@@ -433,6 +442,11 @@ class VectorType
}
ArrayRef<int64_t> getShape() const;
+
+ /// Get or create a new VectorType with the same shape as `this` and an
+ /// element type of bitwidth scaled by `scale`.
+ /// Return null if the scaled element type cannot be represented.
+ VectorType scaleElementBitwidth(unsigned scale);
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index f0adcec2e664..3eb9fdd69c6c 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -16,9 +16,11 @@
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -44,11 +46,13 @@ inline void registerAllDialects(DialectRegistry &registry) {
// clang-format off
registry.insert<acc::OpenACCDialect,
AffineDialect,
+ arm_neon::ArmNeonDialect,
async::AsyncDialect,
avx512::AVX512Dialect,
gpu::GPUDialect,
LLVM::LLVMAVX512Dialect,
LLVM::LLVMDialect,
+ LLVM::LLVMArmNeonDialect,
linalg::LinalgDialect,
scf::SCFDialect,
omp::OpenMPDialect,
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index a1771dab144c..cafc931c2d9f 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -22,6 +22,7 @@ void registerToLLVMIRTranslation();
void registerToSPIRVTranslation();
void registerToNVVMIRTranslation();
void registerToROCDLIRTranslation();
+void registerArmNeonToLLVMIRTranslation();
void registerAVX512ToLLVMIRTranslation();
// This function should be called before creating any MLIRContext if one
@@ -35,6 +36,7 @@ inline void registerAllTranslations() {
registerToSPIRVTranslation();
registerToNVVMIRTranslation();
registerToROCDLIRTranslation();
+ registerArmNeonToLLVMIRTranslation();
registerAVX512ToLLVMIRTranslation();
return true;
}();
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 06a19b057f71..9e1603f083ee 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -22,111 +22,59 @@ using namespace mlir::vector;
using namespace mlir::avx512;
template <typename OpTy>
-static Type getSrcVectorElementType(OpTy op) {
- return op.src().getType().template cast<VectorType>().getElementType();
-}
-
-// TODO: Code is currently copy-pasted and adapted from the code
-// 1-1 LLVM conversion. It would better if it were properly exposed in core and
-// reusable.
-/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
-/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
-/// operands as is, preserve attributes.
-template <typename SourceOp, typename TargetOp>
-static LogicalResult
-matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op,
- ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) {
- unsigned numResults = op->getNumResults();
-
- Type packedType;
- if (numResults != 0) {
- packedType = typeConverter.packFunctionResults(op->getResultTypes());
- if (!packedType)
- return failure();
- }
-
- auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
- op->getAttrs());
-
- // If the operation produced 0 or 1 result, return them immediately.
- if (numResults == 0)
- return rewriter.eraseOp(op), success();
- if (numResults == 1)
- return rewriter.replaceOp(op, newOp->getResult(0)), success();
-
- // Otherwise, it had been converted to an operation producing a structure.
- // Extract individual results from the structure and return them as list.
- SmallVector<Value, 4> results;
- results.reserve(numResults);
- for (unsigned i = 0; i < numResults; ++i) {
- auto type = typeConverter.convertType(op->getResult(i).getType());
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
- }
- rewriter.replaceOp(op, results);
- return success();
+static Type getSrcVectorElementType(Operation *op) {
+ return cast<OpTy>(op)
+ .src()
+ .getType()
+ .template cast<VectorType>()
+ .getElementType();
}
namespace {
-// TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
-// MaskRndScaleOp) and different possible target ops. It would be better to take
-// a Functor so that all these conversions become 1-liners.
-struct MaskRndScaleOpPS512Conversion
- : public ConvertOpToLLVMPattern<MaskRndScaleOp> {
- using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(op).isF32())
- return failure();
- return matchAndRewriteOneToOne<MaskRndScaleOp,
- LLVM::x86_avx512_mask_rndscale_ps_512>(
- *getTypeConverter(), op, operands, rewriter);
- }
-};
-
-struct MaskRndScaleOpPD512Conversion
- : public ConvertOpToLLVMPattern<MaskRndScaleOp> {
- using ConvertOpToLLVMPattern<MaskRndScaleOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(MaskRndScaleOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(op).isF64())
- return failure();
- return matchAndRewriteOneToOne<MaskRndScaleOp,
- LLVM::x86_avx512_mask_rndscale_pd_512>(
- *getTypeConverter(), op, operands, rewriter);
- }
-};
-struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
- using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
+// TODO: turn these into simpler declarative templated patterns when we've had
+// enough.
+struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern {
+ explicit MaskRndScaleOp512Conversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
+ typeConverter) {}
LogicalResult
- matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(op).isF32())
- return failure();
- return matchAndRewriteOneToOne<MaskScaleFOp,
- LLVM::x86_avx512_mask_scalef_ps_512>(
- *getTypeConverter(), op, operands, rewriter);
+ Type elementType = getSrcVectorElementType<MaskRndScaleOp>(op);
+ if (elementType.isF32())
+ return LLVM::detail::oneToOneRewrite(
+ op, LLVM::x86_avx512_mask_rndscale_ps_512::getOperationName(),
+ operands, *getTypeConverter(), rewriter);
+ if (elementType.isF64())
+ return LLVM::detail::oneToOneRewrite(
+ op, LLVM::x86_avx512_mask_rndscale_pd_512::getOperationName(),
+ operands, *getTypeConverter(), rewriter);
+ return failure();
}
};
-struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
- using ConvertOpToLLVMPattern<MaskScaleFOp>::ConvertOpToLLVMPattern;
+struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
+ explicit ScaleFOp512Conversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
+ typeConverter) {}
LogicalResult
- matchAndRewrite(MaskScaleFOp op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!getSrcVectorElementType(op).isF64())
- return failure();
- return matchAndRewriteOneToOne<MaskScaleFOp,
- LLVM::x86_avx512_mask_scalef_pd_512>(
- *getTypeConverter(), op, operands, rewriter);
+ Type elementType = getSrcVectorElementType<MaskScaleFOp>(op);
+ if (elementType.isF32())
+ return LLVM::detail::oneToOneRewrite(
+ op, LLVM::x86_avx512_mask_scalef_ps_512::getOperationName(), operands,
+ *getTypeConverter(), rewriter);
+ if (elementType.isF64())
+ return LLVM::detail::oneToOneRewrite(
+ op, LLVM::x86_avx512_mask_scalef_pd_512::getOperationName(), operands,
+ *getTypeConverter(), rewriter);
+ return failure();
}
};
} // namespace
@@ -135,9 +83,7 @@ struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern<MaskScaleFOp> {
void mlir::populateAVX512ToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// clang-format off
- patterns.insert<MaskRndScaleOpPS512Conversion,
- MaskRndScaleOpPD512Conversion,
- ScaleFOpPS512Conversion,
- ScaleFOpPD512Conversion>(converter);
+ patterns.insert<MaskRndScaleOp512Conversion,
+ ScaleFOp512Conversion>(&converter.getContext(), converter);
// clang-format on
}
diff --git a/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp b/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp
new file mode 100644
index 000000000000..c3c815591df7
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp
@@ -0,0 +1,31 @@
+//===- ArmNeonToLLVM.cpp - ArmNeon to the LLVM dialect --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::arm_neon;
+
+using SMullOpLowering =
+ OneToOneConvertToLLVMPattern<SMullOp, LLVM::aarch64_arm_neon_smull>;
+
+/// Add the one-to-one SMullOp -> llvm_arm_neon.smull lowering to `patterns`.
+void mlir::populateArmNeonToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ patterns.insert<SMullOpLowering>(converter);
+}
diff --git a/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..de028f6322a6
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmNeonToLLVM
+ ArmNeonToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeonToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArmNeon
+ MLIRLLVMArmNeon
+ MLIRLLVMIR
+ MLIRStandardToLLVM
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bf1789505882..a0195486cfd6 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(AffineToStandard)
+add_subdirectory(ArmNeonToLLVM)
add_subdirectory(AsyncToLLVM)
add_subdirectory(AVX512ToLLVM)
add_subdirectory(GPUCommon)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 6314a5c91d0b..dd69924166a5 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -25,8 +25,9 @@ class GPUModuleOp;
} // end namespace gpu
namespace LLVM {
-class LLVMDialect;
+class LLVMArmNeonDialect;
class LLVMAVX512Dialect;
+class LLVMDialect;
} // end namespace LLVM
namespace NVVM {
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index c53839714af7..6d7f7aa04d52 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -13,8 +13,11 @@ add_mlir_conversion_library(MLIRVectorToLLVM
Core
LINK_LIBS PUBLIC
+ MLIRArmNeon
+ MLIRArmNeonToLLVM
MLIRAVX512
MLIRAVX512ToLLVM
+ MLIRLLVMArmNeon
MLIRLLVMAVX512
MLIRLLVMIR
MLIRStandardToLLVM
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4d5576ec9d9f..99f0bae05406 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -11,10 +11,13 @@
#include "../PassDetail.h"
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
+#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,8 +31,17 @@ struct LowerVectorToLLVMPass
LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
this->reassociateFPReductions = options.reassociateFPReductions;
this->enableIndexOptimizations = options.enableIndexOptimizations;
+ this->enableArmNeon = options.enableArmNeon;
this->enableAVX512 = options.enableAVX512;
}
+ // Override explicitly: LLVM is always needed, arch dialects only if enabled.
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<LLVM::LLVMDialect>();
+ if (enableArmNeon)
+ registry.insert<LLVM::LLVMArmNeonDialect>();
+ if (enableAVX512)
+ registry.insert<LLVM::LLVMAVX512Dialect>();
+ }
void runOnOperation() override;
};
} // namespace
@@ -56,6 +68,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
+ if (enableArmNeon) {
+ target.addLegalDialect<LLVM::LLVMArmNeonDialect>();
+ target.addIllegalDialect<arm_neon::ArmNeonDialect>();
+ populateArmNeonToLLVMConversionPatterns(converter, patterns);
+ }
if (enableAVX512) {
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
target.addIllegalDialect<avx512::AVX512Dialect>();
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
new file mode 100644
index 000000000000..12dda1aa5fe5
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmNeon
+ IR/ArmNeonDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon
+
+ DEPENDS
+ MLIRArmNeonIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+ )
diff --git a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
new file mode 100644
index 000000000000..b8b8ebd35e68
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
@@ -0,0 +1,29 @@
+//===- ArmNeonDialect.cpp - MLIRArmNeon ops implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmNeon dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+/// Register every operation of the ArmNeon dialect. The op list is expanded
+/// from the TableGen-generated ArmNeon.cpp.inc under GET_OP_LIST.
+void arm_neon::ArmNeonDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc"
+ >();
+}
+
+// Op class definitions, expanded from the same generated file under
+// GET_OP_CLASSES.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc"
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index bc44049e2ef6..252b05cf2664 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Affine)
+add_subdirectory(ArmNeon)
add_subdirectory(Async)
add_subdirectory(AVX512)
add_subdirectory(GPU)
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index db1a5c4c80aa..87ad7e965d2e 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -49,6 +49,27 @@ add_mlir_dialect_library(MLIRLLVMAVX512
MLIRSideEffectInterfaces
)
+# LLVMArmNeon dialect library: the LLVM-level ops that ArmNeon ops are
+# converted to. intrinsics_gen is listed because IR/LLVMArmNeonDialect.cpp
+# includes the generated llvm/IR/IntrinsicsAArch64.h header.
+add_mlir_dialect_library(MLIRLLVMArmNeon
+ IR/LLVMArmNeonDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+ DEPENDS
+ MLIRLLVMArmNeonIncGen
+ MLIRLLVMArmNeonConversionsIncGen
+ intrinsics_gen
+
+ LINK_COMPONENTS
+ AsmParser
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMIR
+ MLIRSideEffectInterfaces
+ )
+
add_mlir_dialect_library(MLIRNVVMIR
IR/NVVMDialect.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp
new file mode 100644
index 000000000000..d05b6584c39f
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp
@@ -0,0 +1,31 @@
+//===- LLVMArmNeonDialect.cpp - MLIR LLVMArmNeon ops implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMArmNeon dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+/// Register every operation of the LLVMArmNeon dialect. The op list is
+/// expanded from the TableGen-generated LLVMArmNeon.cpp.inc under GET_OP_LIST.
+void LLVM::LLVMArmNeonDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc"
+ >();
+}
+
+// Op class definitions, expanded from the same generated file under
+// GET_OP_CLASSES.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc"
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 50a5a64da69e..aa4eba07cf53 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -66,6 +66,12 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
return getImpl()->signedness;
}
+/// Return an integer type whose bitwidth is `scale` times this type's width,
+/// with the same signedness. Returns a null IntegerType when `scale` is 0 or
+/// when the scaled width would exceed IntegerType::kMaxWidth, rather than
+/// attempting to construct an integer type with an invalid width.
+IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
+  if (!scale)
+    return IntegerType();
+  unsigned width = getWidth();
+  // Compare via division so that `scale * width` below can neither overflow
+  // `unsigned` nor produce a width beyond the supported maximum.
+  if (width > IntegerType::kMaxWidth / scale)
+    return IntegerType();
+  return IntegerType::get(scale * width, getSignedness(), getContext());
+}
+
//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//
@@ -93,6 +99,22 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
llvm_unreachable("non-floating point type used");
}
+/// Return the float type that is `scale` times as wide as this one:
+/// f16/bf16 widen to f32 (x2) or f64 (x4), and f32 widens to f64 (x2).
+/// Every other combination, including a zero scale, yields a null FloatType.
+FloatType FloatType::scaleElementBitwidth(unsigned scale) {
+  if (scale == 0)
+    return FloatType();
+  MLIRContext *context = getContext();
+  const bool isHalfWidth = isF16() || isBF16();
+  if (isHalfWidth && scale == 2)
+    return FloatType::getF32(context);
+  if (isHalfWidth && scale == 4)
+    return FloatType::getF64(context);
+  if (isF32() && scale == 2)
+    return FloatType::getF64(context);
+  return FloatType();
+}
+
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
@@ -306,6 +328,18 @@ LogicalResult VectorType::verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
+/// Return a vector type with the same shape whose integer or float element
+/// type has been widened by `scale`. Returns a null VectorType when `scale`
+/// is 0, when the element type is neither integer nor float, or when the
+/// element type cannot be scaled by that factor.
+VectorType VectorType::scaleElementBitwidth(unsigned scale) {
+  if (scale == 0)
+    return VectorType();
+  Type elementType = getElementType();
+  Type scaledElementType;
+  if (auto intType = elementType.dyn_cast<IntegerType>())
+    scaledElementType = intType.scaleElementBitwidth(scale);
+  else if (auto floatType = elementType.dyn_cast<FloatType>())
+    scaledElementType = floatType.scaleElementBitwidth(scale);
+  return scaledElementType ? VectorType::get(getShape(), scaledElementType)
+                           : VectorType();
+}
+
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index cdc9e2db9cd1..96568438a0a3 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -55,6 +55,25 @@ add_mlir_translation_library(MLIRTargetLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
+# Translation library for LLVMIR/LLVMArmNeonIntr.cpp, which registers the
+# `arm-neon-mlir-to-llvmir` translation. It DEPENDS on the generated
+# LLVMArmNeon conversions that file includes.
+add_mlir_translation_library(MLIRTargetArmNeon
+ LLVMIR/LLVMArmNeonIntr.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+
+ DEPENDS
+ MLIRLLVMArmNeonConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMArmNeon
+ MLIRLLVMIR
+ MLIRTargetLLVMIRModuleTranslation
+ )
+
add_mlir_translation_library(MLIRTargetNVVMIR
LLVMIR/ConvertToNVVMIR.cpp
diff --git a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
new file mode 100644
index 000000000000..0bd40ef3c80c
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp
@@ -0,0 +1,63 @@
+//===- LLVMArmNeonIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation from the MLIR LLVM and LLVMArmNeon
+// dialects to LLVM IR with Arm Neon (AArch64) intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+
+namespace {
+/// Module translation that extends the base LLVM dialect translation with the
+/// generated LLVMArmNeon conversions.
+class LLVMArmNeonModuleTranslation : public LLVM::ModuleTranslation {
+ friend LLVM::ModuleTranslation;
+
+public:
+ using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+ /// Convert `opInst` to LLVM IR. The generated snippet below is tried first.
+ /// It presumably handles the LLVMArmNeon ops and returns early; confirm
+ /// against LLVMArmNeonConversions.inc. Anything it does not handle falls
+ /// through to the base LLVM dialect translation.
+ LogicalResult convertOperation(Operation &opInst,
+ llvm::IRBuilder<> &builder) override {
+#include "mlir/Dialect/LLVMIR/LLVMArmNeonConversions.inc"
+
+ return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+ }
+};
+
+/// Translate the module operation `m` into an llvm::Module named `name` in
+/// `llvmContext`, using the ArmNeon-aware module translation. The caller
+/// treats a null result as failure.
+std::unique_ptr<llvm::Module>
+translateLLVMArmNeonModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
+                                   StringRef name) {
+  using ArmNeonTranslation = LLVMArmNeonModuleTranslation;
+  std::unique_ptr<llvm::Module> translated =
+      LLVM::ModuleTranslation::translateModule<ArmNeonTranslation>(
+          m, llvmContext, name);
+  return translated;
+}
+} // end namespace
+
+namespace mlir {
+/// Register the `arm-neon-mlir-to-llvmir` translation. It converts a module
+/// using the LLVM and LLVMArmNeon dialects into textual LLVM IR, and reports
+/// failure when the module translation returns null.
+void registerArmNeonToLLVMIRTranslation() {
+  TranslateFromMLIRRegistration reg(
+      "arm-neon-mlir-to-llvmir",
+      [](ModuleOp module, raw_ostream &output) {
+        llvm::LLVMContext llvmContext;
+        auto llvmModule = translateLLVMArmNeonModuleToLLVMIR(
+            module, llvmContext, "LLVMDialectModule");
+        if (!llvmModule)
+          return failure();
+
+        llvmModule->print(output, nullptr);
+        return success();
+      },
+      // Dialects this translation makes available for the input module.
+      [](DialectRegistry &registry) {
+        registry.insert<LLVM::LLVMArmNeonDialect, LLVM::LLVMDialect>();
+      });
+}
+} // namespace mlir
diff --git a/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir
new file mode 100644
index 000000000000..fe56052fe734
--- /dev/null
+++ b/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-neon" | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: arm_neon_smull
+func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
+    -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
+  // Each arm_neon.smull must be rewritten to the LLVM-level
+  // llvm_arm_neon.smull op operating on !llvm.vec values.
+  // CHECK: llvm_arm_neon.smull{{.*}}: (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16>
+  %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16>
+  %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}:
+    vector<8xi16> to vector<4xi16>
+
+  // CHECK: llvm_arm_neon.smull{{.*}}: (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32>
+  %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32>
+  %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}:
+    vector<4xi32> to vector<2xi32>
+
+  // CHECK: llvm_arm_neon.smull{{.*}}: (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64>
+  %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64>
+
+  return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64>
+}
diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
new file mode 100644
index 000000000000..014da313a089
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: arm_neon_smull
+func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>)
+    -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) {
+  // Chain the three widening smull forms (i8->i16, i16->i32, i32->i64),
+  // slicing each result down so it can feed the next, narrower-input smull.
+  // CHECK: arm_neon.smull {{.*}}: vector<8xi8> to vector<8xi16>
+  %wide16 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16>
+  %slice16 = vector.extract_strided_slice %wide16 {offsets = [3], sizes = [4], strides = [1]}:
+    vector<8xi16> to vector<4xi16>
+
+  // CHECK: arm_neon.smull {{.*}}: vector<4xi16> to vector<4xi32>
+  %wide32 = arm_neon.smull %slice16, %slice16 : vector<4xi16> to vector<4xi32>
+  %slice32 = vector.extract_strided_slice %wide32 {offsets = [1], sizes = [2], strides = [1]}:
+    vector<4xi32> to vector<2xi32>
+
+  // CHECK: arm_neon.smull {{.*}}: vector<2xi32> to vector<2xi64>
+  %wide64 = arm_neon.smull %slice32, %slice32 : vector<2xi32> to vector<2xi64>
+
+  return %wide16, %wide32, %wide64 : vector<8xi16>, vector<4xi32>, vector<2xi64>
+}
diff --git a/mlir/test/Target/arm-neon.mlir b/mlir/test/Target/arm-neon.mlir
new file mode 100644
index 000000000000..955b4aeb40a0
--- /dev/null
+++ b/mlir/test/Target/arm-neon.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate -arm-neon-mlir-to-llvmir | FileCheck %s
+
+// CHECK-LABEL: arm_neon_smull
+llvm.func @arm_neon_smull(%arg0: !llvm.vec<8 x i8>, %arg1: !llvm.vec<8 x i8>) -> !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> {
+  // Each llvm_arm_neon.smull must become a call to the matching
+  // llvm.aarch64.neon.smull intrinsic; results are tracked through captured
+  // FileCheck variables rather than hard-coded SSA numbers.
+  // CHECK: %[[V0:.*]] = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %{{.*}}, <8 x i8> %{{.*}})
+  // CHECK-NEXT: %[[V00:.*]] = shufflevector <8 x i16> %[[V0]], <8 x i16> %[[V0]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
+  %0 = "llvm_arm_neon.smull"(%arg0, %arg1) : (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16>
+  %1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : !llvm.vec<8 x i16>, !llvm.vec<8 x i16>
+
+  // CHECK-NEXT: %[[V1:.*]] = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %[[V00]], <4 x i16> %[[V00]])
+  // CHECK-NEXT: %[[V11:.*]] = shufflevector <4 x i32> %[[V1]], <4 x i32> %[[V1]], <2 x i32> <i32 1, i32 2>
+  %2 = "llvm_arm_neon.smull"(%1, %1) : (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32>
+  %3 = llvm.shufflevector %2, %2 [1, 2] : !llvm.vec<4 x i32>, !llvm.vec<4 x i32>
+
+  // CHECK-NEXT: %[[V2:.*]] = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %[[V11]], <2 x i32> %[[V11]])
+  %4 = "llvm_arm_neon.smull"(%3, %3) : (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64>
+
+  %5 = llvm.mlir.undef : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+  %6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+  %7 = llvm.insertvalue %2, %6[1] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+  %8 = llvm.insertvalue %4, %7[2] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+
+  // CHECK: ret { <8 x i16>, <4 x i32>, <2 x i64> }
+  llvm.return %8 : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)>
+}
More information about the Mlir-commits
mailing list