[Mlir-commits] [mlir] 0b63e32 - [mlir] X86Vector: Add AVX Rsqrt

Aart Bik llvmlistbot at llvm.org
Tue Apr 13 08:44:01 PDT 2021


Author: Emilio Cota
Date: 2021-04-13T08:43:48-07:00
New Revision: 0b63e3222b2de3ec24aade18c99513a5ae3f30d2

URL: https://github.com/llvm/llvm-project/commit/0b63e3222b2de3ec24aade18c99513a5ae3f30d2
DIFF: https://github.com/llvm/llvm-project/commit/0b63e3222b2de3ec24aade18c99513a5ae3f30d2.diff

LOG: [mlir] X86Vector: Add AVX Rsqrt

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D99818

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir

Modified: 
    mlir/include/mlir/Dialect/X86Vector/X86Vector.td
    mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
    mlir/test/Dialect/X86Vector/roundtrip.mlir
    mlir/test/Target/LLVMIR/x86vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 0c9aed89bc188..9f5e7577d2d32 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -267,4 +267,32 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
                    VectorOfLengthAndType<[8], [I64]>:$b);
 }
 
+//===----------------------------------------------------------------------===//
+// AVX op definitions
+//===----------------------------------------------------------------------===//
+
+class AVX_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<X86Vector_Dialect, "avx." # mnemonic, traits> {}
+
+class AVX_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
+                  "x86_avx_" # !subst(".", "_", mnemonic),
+                  [], [], traits, numResults>;
+
+//----------------------------------------------------------------------------//
+// AVX Rsqrt
+//----------------------------------------------------------------------------//
+
+def RsqrtOp : AVX_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Rsqrt";
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
+  let assemblyFormat = "$a attr-dict `:` type($a)";
+}
+
+def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [NoSideEffect,
+  SameOperandsAndResultType]> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+}
+
 #endif // X86VECTOR_OPS

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 38f264ac0988b..9e2a743450ff0 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -90,6 +90,20 @@ struct MaskCompressOpConversion
   }
 };
 
+struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
+  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    RsqrtOp::Adaptor adaptor(operands);
+
+    auto opType = adaptor.a().getType();
+    rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a());
+    return success();
+  }
+};
+
 /// An entry associating the "main" AVX512 op with its instantiations for
 /// vectors of 32-bit and 64-bit elements.
 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
@@ -131,7 +145,7 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion>(converter);
+  patterns.add<MaskCompressOpConversion, RsqrtOpConversion>(converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -139,4 +153,6 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   Registry::configureTarget(target);
   target.addLegalOp<MaskCompressIntrOp>();
   target.addIllegalOp<MaskCompressOp>();
+  target.addLegalOp<RsqrtIntrOp>();
+  target.addIllegalOp<RsqrtOp>();
 }

diff  --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 477ba558523e9..6f23153d41db2 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -42,3 +42,11 @@ func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
   %2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
   return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }
+
+// CHECK-LABEL: func @avx_rsqrt
+func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
+{
+  // CHECK: x86vector.avx.intr.rsqrt.ps.256
+  %0 = x86vector.avx.rsqrt %a : vector<8xf32>
+  return %0 : vector<8xf32>
+}

diff  --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 57cfe9800c06a..4dfd934c59385 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -46,3 +46,11 @@ func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
   %2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
   return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
 }
+
+// CHECK-LABEL: func @avx_rsqrt
+func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
+{
+  // CHECK: x86vector.avx.rsqrt {{.*}} : vector<8xf32>
+  %0 = x86vector.avx.rsqrt %a : vector<8xf32>
+  return %0 : vector<8xf32>
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir
new file mode 100644
index 0000000000000..e43c4b05ecc95
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --jit-kind=mcjit --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+// TODO: drop lli's --jit-kind flag once PR#49906 (https://bugs.llvm.org/show_bug.cgi?id=49906) is fixed.
+
+func @entry() -> i32 {
+  %i0 = constant 0 : i32
+
+  %v = std.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32>
+  %r = x86vector.avx.rsqrt %v : vector<8xf32>
+  // CHECK: ( 2.82764, 1.99951, 1.41382, 0.999756, 0.706909, 0.499878, 0.353455, 0.249939 )
+  vector.print %r : vector<8xf32>
+
+  return %i0 : i32
+}

diff  --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 9c202a6e3ab58..190732868cb7a 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -59,3 +59,11 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
     (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
   llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
 }
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
+llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
+  %0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>)
+  llvm.return %0 : vector<8xf32>
+}


        


More information about the Mlir-commits mailing list