[Mlir-commits] [mlir] 0b63e32 - [mlir] X86Vector: Add AVX Rsqrt
Aart Bik
llvmlistbot at llvm.org
Tue Apr 13 08:44:01 PDT 2021
Author: Emilio Cota
Date: 2021-04-13T08:43:48-07:00
New Revision: 0b63e3222b2de3ec24aade18c99513a5ae3f30d2
URL: https://github.com/llvm/llvm-project/commit/0b63e3222b2de3ec24aade18c99513a5ae3f30d2
DIFF: https://github.com/llvm/llvm-project/commit/0b63e3222b2de3ec24aade18c99513a5ae3f30d2.diff
LOG: [mlir] X86Vector: Add AVX Rsqrt
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D99818
Added:
mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir
Modified:
mlir/include/mlir/Dialect/X86Vector/X86Vector.td
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
mlir/test/Dialect/X86Vector/roundtrip.mlir
mlir/test/Target/LLVMIR/x86vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 0c9aed89bc188..9f5e7577d2d32 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -267,4 +267,32 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
VectorOfLengthAndType<[8], [I64]>:$b);
}
+//===----------------------------------------------------------------------===//
+// AVX op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for user-facing AVX ops; every mnemonic is prefixed with "avx.".
+class AVX_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}
+
+// Base class for AVX ops that map 1:1 onto an LLVM intrinsic. The op is named
+// "avx.intr.<mnemonic>" and the intrinsic enum is "x86_avx_<mnemonic>" with
+// every '.' of the mnemonic replaced by '_'.
+class AVX_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
+                  "x86_avx_" # !subst(".", "_", mnemonic),
+                  [], [], traits, numResults>;
+
+//----------------------------------------------------------------------------//
+// AVX Rsqrt
+//----------------------------------------------------------------------------//
+
+def RsqrtOp : AVX_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Approximate reciprocal square root";
+  let description = [{
+    Computes an approximation of `1 / sqrt(x)` for each element of an
+    8 x f32 vector. The op is legalized to `x86vector.avx.intr.rsqrt.ps.256`,
+    which translates to the LLVM intrinsic `llvm.x86.avx.rsqrt.ps.256`
+    (the AVX `vrsqrtps` instruction on 256-bit vectors).
+
+    The result is only an approximation: Intel documents a maximum relative
+    error below 1.5 * 2^-12, and the exact bits returned may differ between
+    processor implementations. Do not use this op where a correctly rounded
+    reciprocal square root is required.
+
+    Example:
+
+    ```mlir
+    %r = x86vector.avx.rsqrt %a : vector<8xf32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
+  let assemblyFormat = "$a attr-dict `:` type($a)";
+}
+
+// LLVM intrinsic counterpart of RsqrtOp: "x86vector.avx.intr.rsqrt.ps.256",
+// bound to the intrinsic enum x86_avx_rsqrt_ps_256.
+def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [NoSideEffect,
+                             SameOperandsAndResultType]> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+}
+
#endif // X86VECTOR_OPS
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 38f264ac0988b..9e2a743450ff0 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -90,6 +90,20 @@ struct MaskCompressOpConversion
}
};
+/// Lowers `x86vector.avx.rsqrt` onto its LLVM intrinsic counterpart,
+/// `x86vector.avx.intr.rsqrt.ps.256`.
+struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
+  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Both ops are SameOperandsAndResultType, so the operand's type doubles
+    // as the result type of the new intrinsic op.
+    Value source = RsqrtOp::Adaptor(operands).a();
+    rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, source.getType(), source);
+    return success();
+  }
+};
+
/// An entry associating the "main" AVX512 op with its instantiations for
/// vectors of 32-bit and 64-bit elements.
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
@@ -131,7 +145,7 @@ using Registry = RegistryImpl<
+/// Adds the X86Vector-to-LLVM legalization patterns: the Registry-driven
+/// AVX512 patterns plus the hand-written MaskCompress and Rsqrt conversions.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.add<MaskCompressOpConversion>(converter);
+ patterns.add<MaskCompressOpConversion, RsqrtOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -139,4 +153,6 @@ void mlir::configureX86VectorLegalizeForExportTarget(
Registry::configureTarget(target);
target.addLegalOp<MaskCompressIntrOp>();
target.addIllegalOp<MaskCompressOp>();
+ target.addLegalOp<RsqrtIntrOp>();
+ target.addIllegalOp<RsqrtOp>();
}
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 477ba558523e9..6f23153d41db2 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -42,3 +42,11 @@ func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
+
+// avx.rsqrt must be legalized into the avx.intr.rsqrt.ps.256 intrinsic op.
+// CHECK-LABEL: func @avx_rsqrt
+func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
+{
+ // CHECK: x86vector.avx.intr.rsqrt.ps.256
+ %0 = x86vector.avx.rsqrt %a : vector<8xf32>
+ return %0 : vector<8xf32>
+}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 57cfe9800c06a..4dfd934c59385 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -46,3 +46,11 @@ func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
}
+
+// avx.rsqrt must print back in its custom assembly form (`$a : type`).
+// CHECK-LABEL: func @avx_rsqrt
+func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
+{
+ // CHECK: x86vector.avx.rsqrt {{.*}} : vector<8xf32>
+ %0 = x86vector.avx.rsqrt %a : vector<8xf32>
+ return %0 : vector<8xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir
new file mode 100644
index 0000000000000..e43c4b05ecc95
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --jit-kind=mcjit --entry-function=entry --mattr="avx" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+// TODO: drop lli's --jit-kind flag once PR#49906 (https://bugs.llvm.org/show_bug.cgi?id=49906) is fixed.
+
+// Exercises x86vector.avx.rsqrt end to end. The op lowers to the AVX
+// intrinsic llvm.x86.avx.rsqrt.ps.256, so the JIT only needs the "avx"
+// feature, not an AVX-512 one that many AVX-capable hosts lack.
+//
+// NOTE: rsqrt is an approximation (relative error below 1.5 * 2^-12 per
+// Intel's documentation); the exact digits expected below may vary across
+// processor implementations (e.g. between vendors).
+func @entry() -> i32 {
+ %i0 = constant 0 : i32
+
+ %v = std.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32>
+ %r = x86vector.avx.rsqrt %v : vector<8xf32>
+ // CHECK: ( 2.82764, 1.99951, 1.41382, 0.999756, 0.706909, 0.499878, 0.353455, 0.249939 )
+ vector.print %r : vector<8xf32>
+
+ return %i0 : i32
+}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 9c202a6e3ab58..190732868cb7a 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -59,3 +59,11 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
(vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
}
+
+// The generic intrinsic op must translate into a direct call of the
+// llvm.x86.avx.rsqrt.ps.256 intrinsic.
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
+llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector<8xf32>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
+ %0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>)
+ llvm.return %0 : vector<8xf32>
+}
More information about the Mlir-commits mailing list