[Mlir-commits] [mlir] a9e236b - [mlir][Vector] Add a vblendps-based impl for transpose8x8 (both intrin and inline_asm)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 22 02:32:38 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-22T10:32:34Z
New Revision: a9e236bed835c58be381dadb973a1db0681e4795

URL: https://github.com/llvm/llvm-project/commit/a9e236bed835c58be381dadb973a1db0681e4795
DIFF: https://github.com/llvm/llvm-project/commit/a9e236bed835c58be381dadb973a1db0681e4795.diff

LOG: [mlir][Vector] Add a vblendps-based impl for transpose8x8 (both intrin and inline_asm)

This revision follows up on the conversation titled:

```[llvm-dev] Understanding and controlling some of the AVX shuffle emission paths```

The revision adds a vblendps-based implementation for transpose8x8 and further distinguishes between an intrinsics-based and an inline_asm-based implementation.

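For reference, the lane-selection rule shared by both new code paths ("if bit i of `mask` is zero, take f32 at i from v1 else take it from v2") can be sketched in scalar C++; `blendPsRef` is an illustrative name, not part of the patch:

```
#include <array>
#include <cstdint>

// Scalar reference semantics of an 8 x f32 vblendps: bit i of `mask`
// selects lane i of the result from v1 (bit clear) or v2 (bit set).
std::array<float, 8> blendPsRef(const std::array<float, 8> &v1,
                                const std::array<float, 8> &v2, uint8_t mask) {
  std::array<float, 8> result;
  for (int i = 0; i < 8; ++i)
    result[i] = (mask & (1u << i)) ? v2[i] : v1[i];
  return result;
}
```
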
The inline_asm version takes roughly 17% fewer cycles than the intrinsics version (2015 vs. 2415 total cycles) as reported by llvm-mca:

After this revision (intrinsics version; as discussed on llvm-dev, this resolves to virtually the same assembly as before, and no vblendps instruction is emitted):
```
Iterations:        100
Instructions:      5900
Total Cycles:      2415
Total uOps:        7300

Dispatch Width:    6
uOps Per Cycle:    3.02
IPC:               2.44
Block RThroughput: 24.0

Cycles with backend pressure increase [ 89.90% ]
Throughput Bottlenecks:
  Resource Pressure       [ 89.65% ]
  - SKXPort1  [ 0.04% ]
  - SKXPort2  [ 12.42% ]
  - SKXPort3  [ 12.42% ]
  - SKXPort5  [ 89.52% ]
  Data Dependencies:      [ 37.06% ]
  - Register Dependencies [ 37.06% ]
  - Memory Dependencies   [ 0.00% ]
```

After this revision (inline_asm version; vblendps instructions are indeed emitted):
```
Iterations:        100
Instructions:      6300
Total Cycles:      2015
Total uOps:        7700

Dispatch Width:    6
uOps Per Cycle:    3.82
IPC:               3.13
Block RThroughput: 20.0

Cycles with backend pressure increase [ 83.47% ]
Throughput Bottlenecks:
  Resource Pressure       [ 83.18% ]
  - SKXPort0  [ 14.49% ]
  - SKXPort1  [ 14.54% ]
  - SKXPort2  [ 19.70% ]
  - SKXPort3  [ 19.70% ]
  - SKXPort5  [ 83.03% ]
  - SKXPort6  [ 14.49% ]
  Data Dependencies:      [ 39.75% ]
  - Register Dependencies [ 39.75% ]
  - Memory Dependencies   [ 0.00% ]
```

An accessible copy of the conversation is available [here](https://gist.github.com/nicolasvasilache/68c7f34012584b0e00f335bcb374ede0).

Reviewed By: ftynse, dcaballe

Differential Revision: https://reviews.llvm.org/D114335

Added: 
    mlir/test/Integration/Dialect/LLVMIR/CPU/X86/test-inline-asm-vector.mlir

Modified: 
    mlir/include/mlir/Dialect/X86Vector/Transforms.h
    mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
    mlir/test/lib/Dialect/Vector/CMakeLists.txt
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 66749d55d594d..187ee79c3686d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -23,19 +23,43 @@ namespace x86vector {
 
 /// Helper class to factor out the creation and extraction of masks from nibs.
 struct MaskHelper {
+  /// b0 captures the lowest bit, b7 captures the highest bit.
+  /// Meant to be used with instructions such as mm256BlendPs.
+  template <unsigned b0, unsigned b1, unsigned b2, unsigned b3, unsigned b4,
+            unsigned b5, unsigned b6, unsigned b7>
+  static uint8_t blend() {
+    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
+    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
+    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
+                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
+  }
+  /// b0 captures the lowest bit, b7 captures the highest bit.
+  /// Meant to be used with instructions such as mm256BlendPs.
+  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
+                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
+                           uint8_t &b7) {
+    b7 = mask & (1 << 7);
+    b6 = mask & (1 << 6);
+    b5 = mask & (1 << 5);
+    b4 = mask & (1 << 4);
+    b3 = mask & (1 << 3);
+    b2 = mask & (1 << 2);
+    b1 = mask & (1 << 1);
+    b0 = mask & 1;
+  }
   /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
   /// Meant to be used with instructions such as mm256ShufflePs.
   template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
-  static int8_t shuffle() {
+  static uint8_t shuffle() {
     static_assert(b01 <= 0x03, "overflow");
     static_assert(b23 <= 0x03, "overflow");
     static_assert(b45 <= 0x03, "overflow");
     static_assert(b67 <= 0x03, "overflow");
-    return static_cast<int8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
+    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
   }
   /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
-  static void extractShuffle(int8_t mask, int8_t &b01, int8_t &b23, int8_t &b45,
-                             int8_t &b67) {
+  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
+                             uint8_t &b45, uint8_t &b67) {
     b67 = (mask & (0x03 << 6)) >> 6;
     b45 = (mask & (0x03 << 4)) >> 4;
     b23 = (mask & (0x03 << 2)) >> 2;
@@ -44,13 +68,13 @@ struct MaskHelper {
   /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
   /// Meant to be used with instructions such as mm256Permute2f128Ps.
   template <unsigned b47, unsigned b03>
-  static int8_t permute() {
+  static uint8_t permute() {
     static_assert(b03 <= 0x0f, "overflow");
     static_assert(b47 <= 0x0f, "overflow");
-    return static_cast<int8_t>((b47 << 4) + b03);
+    return static_cast<uint8_t>((b47 << 4) + b03);
   }
   /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
-  static void extractPermute(int8_t mask, int8_t &b03, int8_t &b47) {
+  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
     b47 = (mask & (0x0f << 4)) >> 4;
     b03 = mask & 0x0f;
   }
@@ -70,6 +94,20 @@ struct MaskHelper {
 
 namespace avx2 {
 
+namespace inline_asm {
+//===----------------------------------------------------------------------===//
+/// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
+//===----------------------------------------------------------------------===//
+/// If bit i of `mask` is zero, take f32 at i from v1 else take it from v2.
+Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                      uint8_t mask);
+
+} // namespace inline_asm
+
+namespace intrin {
+//===----------------------------------------------------------------------===//
+/// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
+//===----------------------------------------------------------------------===//
 /// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
 Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
 
@@ -80,7 +118,7 @@ Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
 /// Take an 8 bit mask, 2 bit for each position of a[0, 3)  **and** b[0, 4):
 ///                                 0:127    |         128:255
 ///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
-Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, int8_t mask);
+Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
 
 // imm[0:1] out of imm[0:3] is:
 //    0             1           2             3
@@ -89,8 +127,15 @@ Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, int8_t mask);
 //             0             1           2             3
 // imm[0:1] out of imm[4:7].
 Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
-                          int8_t mask);
+                          uint8_t mask);
+
+/// If bit i of `mask` is zero, take f32 at i from v1 else take it from v2.
+Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
+} // namespace intrin
 
+//===----------------------------------------------------------------------===//
+/// Generic lowerings may either use intrin or inline_asm depending on needs.
+//===----------------------------------------------------------------------===//
 /// 4x8xf32-specific AVX2 transpose lowering.
 void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
 

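The bit order of the new `MaskHelper::blend` above is LSB-first, so the two masks used by `transpose8x8xf32` come out as exactly the `0xcc` and `0x33` immediates the updated tests expect. A minimal standalone sketch (it restates the packing locally rather than including the MLIR header):

```
#include <cassert>
#include <cstdint>

// Local restatement of MaskHelper::blend for a standalone check; the real
// helper lives in mlir/include/mlir/Dialect/X86Vector/Transforms.h.
template <unsigned b0, unsigned b1, unsigned b2, unsigned b3, unsigned b4,
          unsigned b5, unsigned b6, unsigned b7>
static uint8_t blend() {
  return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                              (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
}

int main() {
  // These are the two blend masks transpose8x8xf32 passes to vblendps.
  assert((blend<0, 0, 1, 1, 0, 0, 1, 1>() == 0xCC));
  assert((blend<1, 1, 0, 0, 1, 1, 0, 0>() == 0x33));
  return 0;
}
```
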
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index 981413b160d31..38088e17bd4f4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -11,25 +11,43 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 using namespace mlir::vector;
 using namespace mlir::x86vector;
 using namespace mlir::x86vector::avx2;
+using namespace mlir::x86vector::avx2::inline_asm;
+using namespace mlir::x86vector::avx2::intrin;
+
+Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
+    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
+  auto asmTp = "vblendps $0, $1, $2, {0}";
+  auto asmCstr = "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
+  SmallVector<Value> asmVals{v1, v2};
+  auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
+  auto asmOp = b.create<LLVM::InlineAsmOp>(
+      v1.getType(), asmVals, asmStr, asmCstr, false, false, asmDialectAttr);
+  return asmOp.getResult(0);
+}
 
-Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
-                                             Value v2) {
+Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
+                                                     Value v1, Value v2) {
   return b.create<vector::ShuffleOp>(
       v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
 }
 
-Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
-                                             Value v2) {
+Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
+                                                     Value v1, Value v2) {
   return b.create<vector::ShuffleOp>(
       v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
 }
@@ -37,9 +55,10 @@ Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
 /// Takes an 8 bit mask, 2 bit for each position of a[0, 3)  **and** b[0, 4):
 ///                                 0:127    |         128:255
 ///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
-Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
-                                            Value v2, int8_t mask) {
-  int8_t b01, b23, b45, b67;
+Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
+                                                    Value v1, Value v2,
+                                                    uint8_t mask) {
+  uint8_t b01, b23, b45, b67;
   MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
   SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
                                    b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
@@ -52,11 +71,10 @@ Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
 //          a[0:127] or a[128:255] or b[0:127] or b[128:255]
 //             0             1           2             3
 // imm[0:1] out of imm[4:7].
-Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
-                                                 Value v1, Value v2,
-                                                 int8_t mask) {
+Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
+    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
   SmallVector<int64_t> shuffleMask;
-  auto appendToMask = [&](int8_t control) {
+  auto appendToMask = [&](uint8_t control) {
     if (control == 0)
       llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
     else if (control == 1)
@@ -68,13 +86,25 @@ Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
     else
       llvm_unreachable("control > 3 : overflow");
   };
-  int8_t b03, b47;
+  uint8_t b03, b47;
   MaskHelper::extractPermute(mask, b03, b47);
   appendToMask(b03);
   appendToMask(b47);
   return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
 }
 
+/// If bit i of `mask` is zero, take f32 at i from v1 else take it from v2.
+Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
+                                                  Value v1, Value v2,
+                                                  uint8_t mask) {
+  SmallVector<int64_t, 8> shuffleMask;
+  for (int i = 0; i < 8; ++i) {
+    bool isSet = mask & (1 << i);
+    shuffleMask.push_back(!isSet ? i : i + 8);
+  }
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
 /// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
 void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                              MutableArrayRef<Value> vs) {
@@ -118,14 +148,30 @@ void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
   Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
   Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
   Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
-  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
-  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
-  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
-  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
-  Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>());
-  Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>());
-  Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>());
-  Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>());
+
+  using inline_asm::mm256BlendPsAsm;
+  Value sh0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 3, 2>());
+  Value sh2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 3, 2>());
+  Value sh4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 3, 2>());
+  Value sh6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 3, 2>());
+
+  Value S0 =
+      mm256BlendPsAsm(ib, T0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
+  Value S1 =
+      mm256BlendPsAsm(ib, T2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
+  Value S2 =
+      mm256BlendPsAsm(ib, T1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
+  Value S3 =
+      mm256BlendPsAsm(ib, T3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
+  Value S4 =
+      mm256BlendPsAsm(ib, T4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
+  Value S5 =
+      mm256BlendPsAsm(ib, T6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
+  Value S6 =
+      mm256BlendPsAsm(ib, T5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
+  Value S7 =
+      mm256BlendPsAsm(ib, T7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
+
   vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
   vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
   vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());

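The intrinsics-flavor `mm256BlendPs` above emulates vblendps by turning the 8-bit immediate into a `vector.shuffle` mask over the 16 concatenated lanes. Restated as a standalone sketch (`blendToShuffleMask` is an illustrative name; the real code directly builds a `vector::ShuffleOp`), mask `0xCC` reproduces the indices checked by the new integration test:

```
#include <cassert>
#include <cstdint>
#include <vector>

// Bit i of `mask` selects lane i of the result: 0 picks lane i of v1
// (index i), 1 picks lane i of v2 (index i + 8).
std::vector<int64_t> blendToShuffleMask(uint8_t mask) {
  std::vector<int64_t> shuffleMask;
  for (int i = 0; i < 8; ++i) {
    bool isSet = mask & (1 << i);
    shuffleMask.push_back(!isSet ? i : i + 8);
  }
  return shuffleMask;
}

int main() {
  // Matches vector.shuffle %a, %b [0, 1, 10, 11, 4, 5, 14, 15] in
  // test-inline-asm-vector.mlir.
  assert((blendToShuffleMask(0xCC) ==
          std::vector<int64_t>{0, 1, 10, 11, 4, 5, 14, 15}));
  return 0;
}
```
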
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 2bb66d8f8f757..cc62eaaf04b05 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -80,22 +80,17 @@ func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
   // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
   // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
   // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+  // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
   %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
   return %0 : vector<8x8xf32>
 }

diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/test-inline-asm-vector.mlir b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/test-inline-asm-vector.mlir
new file mode 100644
index 0000000000000..5d5bbdc25f2e4
--- /dev/null
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/test-inline-asm-vector.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm |  \
+// RUN: mlir-cpu-runner -e entry_point_with_all_constants -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext
+
+module {
+  llvm.func @function_to_run(%a: vector<8xf32>, %b: vector<8xf32>)  {
+    // CHECK: ( 8, 10, 12, 14, 16, 18, 20, 22 )
+    %r0 = llvm.inline_asm asm_dialect = intel
+        "vaddps $0, $1, $2", "=x,x,x" %a, %b:
+      (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+    vector.print %r0: vector<8xf32>
+
+    // vblendps implemented with inline_asm.
+    // CHECK: ( 0, 1, 10, 11, 4, 5, 14, 15 )
+    %r1 = llvm.inline_asm asm_dialect = intel
+        "vblendps $0, $1, $2, 0xCC", "=x,x,x" %a, %b:
+      (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+    vector.print %r1: vector<8xf32>
+
+    // vblendps 0xCC via vector.shuffle (emulates clang intrinsics impl)
+    // CHECK: ( 0, 1, 10, 11, 4, 5, 14, 15 )
+    %r2 = vector.shuffle %a, %b[0, 1, 10, 11, 4, 5, 14, 15]
+      : vector<8xf32>, vector<8xf32>
+    vector.print %r2: vector<8xf32>
+
+    // vblendps 0x33 implemented with inline_asm.
+    // CHECK: ( 8, 9, 2, 3, 12, 13, 6, 7 )
+    %r3 = llvm.inline_asm asm_dialect = intel
+        "vblendps $0, $1, $2, 0x33", "=x,x,x" %a, %b:
+      (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+    vector.print %r3: vector<8xf32>
+
+    // vblendps 0x33 via vector.shuffle (emulates clang intrinsics impl)
+    // CHECK: ( 8, 9, 2, 3, 12, 13, 6, 7 )
+    %r4 = vector.shuffle %a, %b[8, 9, 2, 3, 12, 13, 6, 7]
+      : vector<8xf32>, vector<8xf32>
+    vector.print %r4: vector<8xf32>
+
+    llvm.return
+  }
+
+  // Solely exists to prevent inlining and get the expected assembly.
+  llvm.func @entry_point(%a: vector<8xf32>, %b: vector<8xf32>)  {
+    llvm.call @function_to_run(%a, %b) : (vector<8xf32>, vector<8xf32>) -> ()
+    llvm.return
+  }
+
+  llvm.func @entry_point_with_all_constants()  {
+    %a = llvm.mlir.constant(dense<[0.0, 1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0]>
+      : vector<8xf32>) : vector<8xf32>
+    %b = llvm.mlir.constant(dense<[8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]>
+      : vector<8xf32>) : vector<8xf32>
+    llvm.call @function_to_run(%a, %b) : (vector<8xf32>, vector<8xf32>) -> ()
+    llvm.return
+  }
+}

diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
index b3c0807b102eb..7629595876215 100644
--- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRVectorTestPasses
   MLIRAnalysis
   MLIRLinalg
   MLIRLinalgTransforms
+  MLIRLLVMIR
   MLIRMemRef
   MLIRPass
   MLIRSCF
@@ -16,4 +17,5 @@ add_mlir_library(MLIRVectorTestPasses
   MLIRTransformUtils
   MLIRVector
   MLIRVectorToSCF
+  MLIRX86Vector
   )

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ccd6bc5fe31a5..0acab9c87a3f0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -200,6 +201,10 @@ struct TestVectorTransposeLowering
       llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
       llvm::cl::init(false)};
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
 

diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index a4556d9499060..b0ed24a7cdeca 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -483,6 +483,7 @@ cc_library(
     deps = [
         "//mlir:Affine",
         "//mlir:Analysis",
+        "//mlir:LLVMDialect",
         "//mlir:LinalgOps",
         "//mlir:LinalgTransforms",
         "//mlir:MemRefDialect",
@@ -492,6 +493,7 @@ cc_library(
         "//mlir:TransformUtils",
         "//mlir:VectorOps",
         "//mlir:VectorToSCF",
+        "//mlir:X86Vector",
     ],
 )
 

