[Mlir-commits] [mlir] 34ff857 - [mlir][X86Vector] Add specialized vector.transpose lowering patterns for AVX2

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 10 23:39:16 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-11T07:33:31Z
New Revision: 34ff8573505e04c75e84a0e515af462f223f2795

URL: https://github.com/llvm/llvm-project/commit/34ff8573505e04c75e84a0e515af462f223f2795
DIFF: https://github.com/llvm/llvm-project/commit/34ff8573505e04c75e84a0e515af462f223f2795.diff

LOG: [mlir][X86Vector] Add specialized vector.transpose lowering patterns for AVX2

This revision adds AVX2-specific lowerings of 2-D vector.transpose for 4x8xf32 and
8x8xf32 vectors and exposes them through the Linalg vector-lowering options.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D113347

Added: 
    mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/X86Vector/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
    mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/Dialect/Vector/vector-mem-transforms.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    mlir/test/Dialect/Vector/vector-flat-transforms.mlir
    mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e05dde330797b..6b816f9359c55 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Bufferize.h"
@@ -993,6 +994,12 @@ struct LinalgVectorLoweringOptions {
     transposeLowering = val;
     return *this;
   }
+  /// Enable AVX2-specific lowerings. The AVX2 transpose patterns are only
+  /// added when `transposeLowering` is enabled as well; their configuration
+  /// comes from `avx2LoweringOptions`.
+  bool avx2Lowering = false;
+  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
+    avx2Lowering = val;
+    return *this;
+  }
 
   /// Configure the post staged-patterns late vector.transfer to scf
   /// conversion.
@@ -1009,6 +1016,13 @@ struct LinalgVectorLoweringOptions {
     vectorTransformOptions = options;
     return *this;
   }
+  /// Configure the specialized AVX2 lowerings (e.g. which transpose shapes to
+  /// rewrite). Only consulted when `avx2Lowering` is enabled.
+  x86vector::avx2::LoweringOptions avx2LoweringOptions;
+  LinalgVectorLoweringOptions &
+  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
+    avx2LoweringOptions = options;
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 5e55a03e9b6c6..ebdc2f67e72d1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -9,13 +9,126 @@
 #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
 #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
 
+#include "mlir/IR/Value.h"
+
 namespace mlir {
 
+class ImplicitLocOpBuilder;
 class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
+namespace x86vector {
+
+/// Helper class to factor out the creation and extraction of 8-bit immediate
+/// masks built from 2-bit or 4-bit fields.
+struct MaskHelper {
+  /// Pack four 2-bit fields into an 8-bit immediate: b01 occupies bits [0:1]
+  /// (the lowest) and b67 occupies bits [6:7] (the highest).
+  /// Meant to be used with instructions such as mm256ShufflePs.
+  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
+  static char shuffle() {
+    static_assert(b01 <= 0x03, "overflow");
+    static_assert(b23 <= 0x03, "overflow");
+    static_assert(b45 <= 0x03, "overflow");
+    static_assert(b67 <= 0x03, "overflow");
+    return (b67 << 6) + (b45 << 4) + (b23 << 2) + b01;
+  }
+  /// Inverse of `shuffle`: unpack the four 2-bit fields of `mask`, b01 being
+  /// the lowest and b67 the highest.
+  static void extractShuffle(char mask, char &b01, char &b23, char &b45,
+                             char &b67) {
+    b67 = (mask & (0x03 << 6)) >> 6;
+    b45 = (mask & (0x03 << 4)) >> 4;
+    b23 = (mask & (0x03 << 2)) >> 2;
+    b01 = mask & 0x03;
+  }
+  /// Pack two 4-bit fields into an 8-bit immediate: b03 occupies bits [0:3]
+  /// and b47 occupies bits [4:7].
+  /// Meant to be used with instructions such as mm256Permute2f128Ps.
+  template <unsigned b47, unsigned b03>
+  static char permute() {
+    static_assert(b03 <= 0x0f, "overflow");
+    static_assert(b47 <= 0x0f, "overflow");
+    return (b47 << 4) + b03;
+  }
+  /// Inverse of `permute`: unpack the low (b03) and high (b47) 4-bit fields.
+  static void extractPermute(char mask, char &b03, char &b47) {
+    b47 = (mask & (0x0f << 4)) >> 4;
+    b03 = mask & 0x0f;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+/// Helpers derived from:
+///   - clang/lib/Headers/avxintrin.h
+///   - clang/test/CodeGen/X86/avx-builtins.c
+///   - clang/test/CodeGen/X86/avx2-builtins.c
+///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
+/// as well as the Intel Intrinsics Guide
+/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html),
+/// so that lowerings can be built from known-good sequences.
+/// All intrinsics correspond 1-1 to the Intel definition.
+//===----------------------------------------------------------------------===//
+
+namespace avx2 {
+
+/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
+Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
+Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Lower to the vector.shuffle equivalent of _mm256_shuffle_ps. `mask` packs
+/// four 2-bit selectors (see MaskHelper::shuffle). Within each 128-bit lane,
+/// b01/b23 select from v1 and b45/b67 select from v2, i.e. the result is
+///   [v1[b01],     v1[b23],     v2[b45],     v2[b67],
+///    v1[b01 + 4], v1[b23 + 4], v2[b45 + 4], v2[b67 + 4]].
+Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, char mask);
+
+/// Lower to the vector.shuffle equivalent of _mm256_permute2f128_ps. The low
+/// 4 bits of `mask` select the low 128 bits of the result and the high 4 bits
+/// select the high 128 bits, each among:
+///   0: v1[0:127], 1: v1[128:255], 2: v2[0:127], 3: v2[128:255].
+/// Selector values above 3 are not supported.
+Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                          char mask);
+
+/// 4x8xf32-specific AVX2 transpose lowering. `vs` must hold the 4
+/// vector<8xf32> rows of the input; it is overwritten in place with 4
+/// vector<8xf32> values holding the 8x4 transpose in row-major order.
+void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// 8x8xf32-specific AVX2 transpose lowering. `vs` must hold the 8
+/// vector<8xf32> rows of the input; it is overwritten in place with the rows
+/// of the transpose.
+void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// Structure to control the behavior of specialized AVX2 transpose lowering.
+struct TransposeLoweringOptions {
+  /// Rewrite [1, 0] transposes of vector<4x8xf32>.
+  bool lower4x8xf32_ = false;
+  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
+    lower4x8xf32_ = lower;
+    return *this;
+  }
+  /// Rewrite [1, 0] transposes of vector<8x8xf32>.
+  bool lower8x8xf32_ = false;
+  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
+    lower8x8xf32_ = lower;
+    return *this;
+  }
+};
+
+/// Options for controlling specialized AVX2 lowerings.
+struct LoweringOptions {
+  /// Configure the specialized transpose lowerings.
+  TransposeLoweringOptions transposeOptions;
+  LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
+    transposeOptions = options;
+    return *this;
+  }
+};
+
+/// Add the AVX2-specialized vector.transpose lowering pattern to `patterns`.
+/// Which shapes get rewritten is controlled by `options.transposeOptions`;
+/// the pattern is registered with the given `benefit`.
+void populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
+    int benefit = 10);
+
+} // namespace avx2
+} // namespace x86vector
+
 /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
 /// intrinsics.
 void populateX86VectorLegalizeForLLVMExportPatterns(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d75b5742b38d3..db91472c5a096 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -51,5 +51,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector
+  MLIRX86VectorTransforms
   MLIRVectorToSCF
 )

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index b9d8d719898cb..4a97b926655fa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -334,6 +334,9 @@ struct LinalgStrategyLowerVectorsPass
     if (options.transposeLowering) {
       vector::populateVectorTransposeLoweringPatterns(
           patterns, options.vectorTransformOptions);
+      if (options.avx2Lowering)
+        x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+            patterns, options.avx2LoweringOptions, /*benefit=*/10);
     }
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
new file mode 100644
index 0000000000000..783939583ae33
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -0,0 +1,208 @@
+//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements vector.transpose rewrites as AVX patterns for particular
+// sizes of interest.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+using namespace mlir::x86vector::avx2;
+
+/// Interleave the low halves of each 128-bit lane of v1 and v2
+/// (_mm256_unpacklo_ps), i.e. vector.shuffle [0, 8, 1, 9, 4, 12, 5, 13].
+Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
+                                             Value v2) {
+  const int64_t interleaveLo[] = {0, 8, 1, 9, 4, 12, 5, 13};
+  return b.create<vector::ShuffleOp>(v1, v2, ArrayRef<int64_t>(interleaveLo));
+}
+
+/// Interleave the high halves of each 128-bit lane of v1 and v2
+/// (_mm256_unpackhi_ps), i.e. vector.shuffle [2, 10, 3, 11, 6, 14, 7, 15].
+Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
+                                             Value v2) {
+  const int64_t interleaveHi[] = {2, 10, 3, 11, 6, 14, 7, 15};
+  return b.create<vector::ShuffleOp>(v1, v2, ArrayRef<int64_t>(interleaveHi));
+}
+/// Mimics _mm256_shuffle_ps: `mask` packs four 2-bit selectors (see
+/// MaskHelper::shuffle). Within each 128-bit lane, the first two result
+/// elements are picked from v1 by b01/b23 and the last two from v2 by
+/// b45/b67, i.e. the result is
+///   [v1[b01],     v1[b23],     v2[b45],     v2[b67],
+///    v1[b01 + 4], v1[b23 + 4], v2[b45 + 4], v2[b67 + 4]].
+Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
+                                            Value v2, char mask) {
+  char b01, b23, b45, b67;
+  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
+  // vector.shuffle indices [0, 8) address v1 and [8, 16) address v2.
+  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
+                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// Mimics _mm256_permute2f128_ps: the low 4 bits of `mask` pick the low 128
+/// bits of the result and the high 4 bits pick the high 128 bits, each among
+///   0: v1[0:127], 1: v1[128:255], 2: v2[0:127], 3: v2[128:255].
+/// Selector values above 3 are not supported.
+Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
+                                                 Value v1, Value v2,
+                                                 char mask) {
+  char lowSel, highSel;
+  MaskHelper::extractPermute(mask, lowSel, highSel);
+
+  SmallVector<int64_t> shuffleMask;
+  for (int selector : {static_cast<int>(lowSel), static_cast<int>(highSel)}) {
+    if (selector < 0 || selector > 3)
+      llvm_unreachable("control > 3 : overflow");
+    // Selector s names the 4-element quarter [4 * s, 4 * s + 4) of the
+    // concatenated [v1, v2] shuffle index space.
+    for (int64_t offset = 0; offset < 4; ++offset)
+      shuffleMask.push_back(4 * selector + offset);
+  }
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
+/// On entry `vs` holds the 4 vector<8xf32> input rows; on exit it holds the
+/// 8x4 transpose packed in row-major order into 4 vector<8xf32> values.
+void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
+                                             MutableArrayRef<Value> vs) {
+#ifndef NDEBUG
+  // `vt` is only used by the assertions; keep it inside the guard so that
+  // NDEBUG builds do not warn about an unused variable.
+  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
+  assert(vs.size() == 4 && "expects 4 vectors");
+  assert(llvm::all_of(ValueRange{vs}.getTypes(),
+                      [&](Type t) { return t == vt; }) &&
+         "expects all types to be vector<8xf32>");
+#endif
+
+  // Interleave rows (0, 1) and (2, 3) within each 128-bit lane.
+  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
+  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
+  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
+  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
+  // Gather 4-element columns: S0..S3 hold columns {0|4}, {1|5}, {2|6}, {3|7}
+  // (low lane | high lane).
+  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
+  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
+  // Recombine 128-bit halves so that vs[i] holds columns 2i and 2i + 1.
+  vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
+  vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
+  vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
+  vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
+}
+
+/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
+/// On entry `vs` holds the 8 vector<8xf32> input rows; on exit it holds the
+/// rows of the transpose. The work is done in three rounds: pairwise unpacks,
+/// 2-bit shuffles, then 128-bit half recombination between rows k and k + 4.
+void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
+                                             MutableArrayRef<Value> vs) {
+  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
+  (void)vt;
+  assert(vs.size() == 8 && "expects 8 vectors");
+  assert(llvm::all_of(ValueRange{vs}.getTypes(),
+                      [&](Type t) { return t == vt; }) &&
+         "expects all types to be vector<8xf32>");
+
+  const char lowPairs = MaskHelper::shuffle<1, 0, 1, 0>();
+  const char highPairs = MaskHelper::shuffle<3, 2, 3, 2>();
+  const char lowHalves = MaskHelper::permute<2, 0>();
+  const char highHalves = MaskHelper::permute<3, 1>();
+
+  // Round 1: unpacked[2k] / unpacked[2k + 1] are the lo / hi interleavings
+  // of rows 2k and 2k + 1.
+  Value unpacked[8];
+  for (int k = 0; k < 4; ++k) {
+    unpacked[2 * k] = mm256UnpackLoPs(ib, vs[2 * k], vs[2 * k + 1]);
+    unpacked[2 * k + 1] = mm256UnpackHiPs(ib, vs[2 * k], vs[2 * k + 1]);
+  }
+
+  // Round 2: within each group of four rows (base 0 and base 4), combine
+  // unpacked[base + h] with unpacked[base + h + 2].
+  Value shuffled[8];
+  for (int base : {0, 4}) {
+    for (int h = 0; h < 2; ++h) {
+      Value lhs = unpacked[base + h];
+      Value rhs = unpacked[base + h + 2];
+      shuffled[base + 2 * h] = mm256ShufflePs(ib, lhs, rhs, lowPairs);
+      shuffled[base + 2 * h + 1] = mm256ShufflePs(ib, lhs, rhs, highPairs);
+    }
+  }
+
+  // Round 3: stitch 128-bit halves of shuffled[k] and shuffled[k + 4]. All
+  // low-half results are emitted before the high-half ones.
+  for (int k = 0; k < 4; ++k)
+    vs[k] = mm256Permute2f128Ps(ib, shuffled[k], shuffled[k + 4], lowHalves);
+  for (int k = 0; k < 4; ++k)
+    vs[k + 4] =
+        mm256Permute2f128Ps(ib, shuffled[k], shuffled[k + 4], highHalves);
+}
+
+namespace {
+/// Rewrite a 2-D vector.transpose into AVX2-style unpack/shuffle/permute
+/// sequences, for the supported cases enabled in `TransposeLoweringOptions`.
+/// Only [1, 0] transposes of vector<4x8xf32> and vector<8x8xf32> are handled;
+/// everything else is left to other patterns.
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
+                      int benefit)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        loweringOptions(loweringOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType srcType = op.getVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");
+
+    // The transpose helpers operate on (and assert) vector<8xf32> rows, so
+    // any other element type must be rejected before rewriting.
+    Type elementType = srcType.getElementType();
+    if (!elementType.isF32())
+      return rewriter.notifyMatchFailure(op, "Unsupported element type");
+
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    // Only the [1, 0] permutation actually swaps the two dimensions.
+    if (transp[0] != 1 || transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");
+
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+    const TransposeLoweringOptions &options = loweringOptions.transposeOptions;
+    bool enabled = (options.lower4x8xf32_ && m == 4 && n == 8) ||
+                   (options.lower8x8xf32_ && m == 8 && n == 8);
+    if (!enabled)
+      return failure();
+
+    ImplicitLocOpBuilder ib(loc, rewriter);
+    SmallVector<Value> vs;
+    for (int64_t i = 0; i < m; ++i)
+      vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
+    if (m == 4)
+      transpose4x8xf32(ib, vs);
+    else
+      transpose8x8xf32(ib, vs);
+
+    Value res = ib.create<arith::ConstantOp>(srcType, ib.getZeroAttr(srcType));
+    for (int64_t i = 0; i < m; ++i)
+      res = ib.create<vector::InsertOp>(vs[i], res, i);
+    // The transposed data of a 4x8 input is still packed as 4x8 and needs to
+    // be reinterpreted as 8x4 via shape_casts; 8x8 is already in final form.
+    if (m == 4) {
+      auto flattenedType = VectorType::get({n * m}, elementType);
+      auto transposedType = VectorType::get({n, m}, elementType);
+      res = ib.create<vector::ShapeCastOp>(flattenedType, res);
+      res = ib.create<vector::ShapeCastOp>(transposedType, res);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  LoweringOptions loweringOptions;
+};
+} // namespace
+
+/// Register the AVX2 transpose rewrite, configured by `options`, with the
+/// requested `benefit`.
+void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<TransposeOpLowering>(options, ctx, benefit);
+}

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 679379a0a51b0..8781d1f5f1242 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
+  AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
 
   DEPENDS
@@ -10,4 +11,5 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   MLIRIR
   MLIRLLVMCommonConversion
   MLIRLLVMIR
+  MLIRVector
   )

diff  --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index 50e5ce8aaf9cf..93cdf05c77692 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s
 
 #matvec_accesses = [
   affine_map<(i, j) -> (i, j)>,

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 3c32008fed81c..01e6a23698cc1 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -149,8 +149,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
 // CHECK:    %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-// ... bunch of extract insert to transpose B into Bt
-// CHECK:    %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32>
+// CHECK:    %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK:    %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
 // CHECK:    %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
 // CHECK:    %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
@@ -399,28 +398,6 @@ func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) ->
    return %0: vector<16xi32>
 }
 
-// CHECK-LABEL: func @transpose23
-// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
-// CHECK:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
-// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
-// CHECK:      %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
-// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
-// CHECK:      %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
-// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
-// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
-// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
-// CHECK:      %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
-// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
-// CHECK:      %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
-// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
-// CHECK:      return %[[T11]] : vector<3x2xf32>
-
-func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
-  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
-  return %0 : vector<3x2xf32>
-}
-
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
 // CHECK:      return %[[A]] : vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
deleted file mode 100644
index 8d51d323a1a71..0000000000000
--- a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
+++ /dev/null
@@ -1,65 +0,0 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s
-
-// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
-//
-// TODO: having ShapeCastOp2DDownCastRewritePattern and
-//       ShapeCastOp2DUpCastRewritePattern too early in the greedy rewriting
-//       patterns misses opportunities to fold shape casts!
-
-// No shape cast folding expected.
-//
-// CHECK-LABEL: func @transpose44_44(
-// CHECK-SAME:  %[[A:.*]]: vector<4x4xf32>
-// CHECK:       %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
-// CHECK:       %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-// CHECK:       %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
-//
-func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
-  return %0 : vector<4x4xf32>
-}
-
-// Folds preceding shape cast as expected,
-// no following shape cast folding expected.
-//
-// FIXME: PR49590 - shape_cast not stable.
-//
-// CHECK-LABEL: func @transpose16_44(
-// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
-// HECK:       %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-// HECK:       %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
-//
-func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
-  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
-  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
-  return %1 : vector<4x4xf32>
-}
-
-// No preceding shape cast folding expected,
-// but FAILS to fold following cast.
-//
-// CHECK-LABEL: func @transpose44_16(
-// CHECK-SAME:  %[[A:.*]]: vector<4x4xf32>
-// CHECK:       %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
-// CHECK:       %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
-  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
-  return %1 : vector<16xf32>
-}
-
-// Folds preceding shape cast as expected,
-// but FAILS to fold following cast.
-//
-// FIXME: PR49590 - shape_cast not stable.
-//
-// CHECK-LABEL: func @transpose16_16(
-// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
-// HECK:       %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
-//
-func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
-  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
-  %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
-  return %2 : vector<16xf32>
-}

diff  --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index ba943873daec2..cfa63339a99af 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
 
 // CHECK-LABEL:   func @maskedload0(
 // CHECK-SAME:                      %[[A0:.*]]: memref<?xf32>,

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index f10f84bcd9891..50ed05a4b956e 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s
 
 // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
new file mode 100644
index 0000000000000..2bb66d8f8f757
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt %s -test-vector-transpose-lowering=eltwise=1 | FileCheck %s --check-prefix=ELTWISE
+// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 | FileCheck %s --check-prefix=SHUFFLE
+// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 | FileCheck %s --check-prefix=FLAT
+// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 | FileCheck %s --check-prefix=AVX2
+
+// ELTWISE-LABEL: func @transpose23(
+// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32>
+// ELTWISE:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// ELTWISE:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
+// ELTWISE:      %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// ELTWISE:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// ELTWISE:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// ELTWISE:      %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// ELTWISE:      %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
+// ELTWISE:      %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
+// ELTWISE:      return %[[T11]] : vector<3x2xf32>
+func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+// SHUFFLE-LABEL: func @transpose(
+// FLAT-LABEL: func @transpose(
+func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
+  //      SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
+  //            0 4
+  // 0 1 2 3    1 5
+  // 4 5 6 7 -> 2 6
+  //            3 7
+  // SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
+  // SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
+
+  // FLAT:       vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
+  // FLAT:       vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32>
+  // FLAT:       vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// AVX2-LABEL: func @transpose4x8xf32(
+func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
+  //      AVX2: vector.extract {{.*}}[0]
+  // AVX2-NEXT: vector.extract {{.*}}[1]
+  // AVX2-NEXT: vector.extract {{.*}}[2]
+  // AVX2-NEXT: vector.extract {{.*}}[3]
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.insert {{.*}}[0]
+  // AVX2-NEXT: vector.insert {{.*}}[1]
+  // AVX2-NEXT: vector.insert {{.*}}[2]
+  // AVX2-NEXT: vector.insert {{.*}}[3]
+  // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
+  // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
+  return %0 : vector<8x4xf32>
+}
+
+// AVX2-LABEL: func @transpose8x8xf32(
+func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
+  //      AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
deleted file mode 100644
index 1b65579b5c813..0000000000000
--- a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
+++ /dev/null
@@ -1,14 +0,0 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
-
-// CHECK-LABEL: func @transpose
-func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
-  // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
-  //            0 4
-  // 0 1 2 3    1 5
-  // 4 5 6 7 -> 2 6
-  //            3 7
-  // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-  return %0 : vector<4x2xf32>
-}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 12d57489af60b..ccd6bc5fe31a5 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -1,4 +1,4 @@
-//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
+//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -11,27 +11,31 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
+using namespace mlir::linalg;
 using namespace mlir::vector;
 
 namespace {
 
-struct TestVectorToVectorConversion
-    : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
-  TestVectorToVectorConversion() = default;
-  TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
+struct TestVectorToVectorLowering
+    : public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
+  TestVectorToVectorLowering() = default;
+  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {}
   StringRef getArgument() const final {
-    return "test-vector-to-vector-conversion";
+    return "test-vector-to-vector-lowering";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns between ops in the vector dialect";
+    return "Test lowering patterns between ops in the vector dialect";
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -95,31 +99,22 @@ struct TestVectorToVectorConversion
   }
 };
 
-struct TestVectorContractionConversion
-    : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
+struct TestVectorContractionLowering
+    : public PassWrapper<TestVectorContractionLowering, FunctionPass> {
   StringRef getArgument() const final {
-    return "test-vector-contraction-conversion";
+    return "test-vector-contraction-lowering";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns that lower contract ops in the vector "
+    return "Test lowering patterns that lower contract ops in the vector "
            "dialect";
   }
-  TestVectorContractionConversion() = default;
-  TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
-  }
+  TestVectorContractionLowering() = default;
+  TestVectorContractionLowering(const TestVectorContractionLowering &pass) {}
 
   Option<bool> lowerToFlatMatrix{
       *this, "vector-lower-matrix-intrinsics",
       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
       llvm::cl::init(false)};
-  Option<bool> lowerToFlatTranspose{
-      *this, "vector-flat-transpose",
-      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
-      llvm::cl::init(false)};
-  Option<bool> lowerToShuffleTranspose{
-      *this, "vector-shuffle-transpose",
-      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
-      llvm::cl::init(false)};
   Option<bool> lowerToOuterProduct{
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -165,31 +160,91 @@ struct TestVectorContractionConversion
       contractLowering = VectorContractLowering::Matmul;
     VectorMultiReductionLowering vectorMultiReductionLowering =
         VectorMultiReductionLowering::InnerParallel;
-    VectorTransposeLowering transposeLowering =
-        VectorTransposeLowering::EltWise;
-    if (lowerToFlatTranspose)
-      transposeLowering = VectorTransposeLowering::Flat;
-    if (lowerToShuffleTranspose)
-      transposeLowering = VectorTransposeLowering::Shuffle;
-    VectorTransformsOptions options{
-        contractLowering, vectorMultiReductionLowering, transposeLowering};
+    VectorTransformsOptions options{contractLowering,
+                                    vectorMultiReductionLowering,
+                                    VectorTransposeLowering()};
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, options);
     populateVectorMaskOpLoweringPatterns(patterns);
-    if (!lowerToShuffleTranspose)
-      populateVectorShapeCastLoweringPatterns(patterns);
-    populateVectorTransposeLoweringPatterns(patterns, options);
+    populateVectorShapeCastLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
 
+struct TestVectorTransposeLowering
+    : public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transpose-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that lower transpose ops in the vector "
+           "dialect";
+  }
+  TestVectorTransposeLowering() = default;
+  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {}
+
+  Option<bool> lowerToEltwise{
+      *this, "eltwise",
+      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToFlatTranspose{
+      *this, "flat",
+      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToShuffleTranspose{
+      *this, "shuffle",
+      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToAvx2{
+      *this, "avx2",
+      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
+      llvm::cl::init(false)};
+
+  void runOnFunction() override {
+    // Test transpose lowering in isolation; shape_cast lowering is disabled so
+    // the emitted vector.shape_cast ops remain visible to FileCheck.
+    // If several of eltwise/flat/shuffle are set, the last `if` below wins;
+    // avx2 additionally enables the 4x8xf32/8x8xf32 AVX2 transpose patterns.
+    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
+                                              .enableVectorTransposeLowering()
+                                              .enableShapeCastLowering(false);
+    if (lowerToEltwise) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::EltWise));
+    }
+    if (lowerToFlatTranspose) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::Flat));
+    }
+    if (lowerToShuffleTranspose) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::Shuffle));
+    }
+    if (lowerToAvx2) {
+      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
+          x86vector::avx2::LoweringOptions().setTransposeOptions(
+              x86vector::avx2::TransposeLoweringOptions()
+                  .lower4x8xf32()
+                  .lower8x8xf32()));
+    }
+
+    OpPassManager dynamicPM("builtin.func");
+    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
+    if (failed(runPipeline(dynamicPM, getFunction())))
+      return signalPassFailure();
+  }
+};
+
 struct TestVectorUnrollingPatterns
     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
   StringRef getArgument() const final {
     return "test-vector-unrolling-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to unroll contract ops in the vector "
+    return "Test lowering patterns to unroll contract ops in the vector "
            "dialect";
   }
   TestVectorUnrollingPatterns() = default;
@@ -248,7 +303,7 @@ struct TestVectorDistributePatterns
     return "test-vector-distribute-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to distribute vector ops in the vector "
+    return "Test lowering patterns to distribute vector ops in the vector "
            "dialect";
   }
   TestVectorDistributePatterns() = default;
@@ -302,7 +357,7 @@ struct TestVectorToLoopPatterns
     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
   StringRef getArgument() const final { return "test-vector-to-forloop"; }
   StringRef getDescription() const final {
-    return "Test conversion patterns to break up a vector op into a for loop";
+    return "Test lowering patterns to break up a vector op into a for loop";
   }
   TestVectorToLoopPatterns() = default;
   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
@@ -365,7 +420,7 @@ struct TestVectorTransferUnrollingPatterns
     return "test-vector-transfer-unrolling-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to unroll transfer ops in the vector "
+    return "Test lowering patterns to unroll transfer ops in the vector "
            "dialect";
   }
   void runOnFunction() override {
@@ -391,7 +446,7 @@ struct TestVectorTransferFullPartialSplitPatterns
     return "test-vector-transfer-full-partial-split";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to split "
+    return "Test lowering patterns to split "
            "transfer ops via scf.if + linalg ops";
   }
   TestVectorTransferFullPartialSplitPatterns() = default;
@@ -439,7 +494,7 @@ struct TestVectorTransferLoweringPatterns
     return "test-vector-transfer-lowering-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to lower transfer ops to other vector ops";
+    return "Test lowering patterns to lower transfer ops to other vector ops";
   }
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
@@ -462,7 +517,7 @@ struct TestVectorMultiReductionLoweringPatterns
     return "test-vector-multi-reduction-lowering-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to lower vector.multi_reduction to other "
+    return "Test lowering patterns to lower vector.multi_reduction to other "
            "vector ops";
   }
   Option<bool> useOuterReductions{
@@ -495,7 +550,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 
   StringRef getDescription() const final {
-    return "Test conversion patterns that reducedes the rank of the vector "
+    return "Test lowering patterns that reduces the rank of the vector "
            "transfer memory and vector operands.";
   }
 
@@ -527,10 +582,12 @@ struct TestVectorReduceToContractPatternsPatterns
 
 namespace mlir {
 namespace test {
-void registerTestVectorConversions() {
-  PassRegistration<TestVectorToVectorConversion>();
+void registerTestVectorLowerings() {
+  PassRegistration<TestVectorToVectorLowering>();
+
+  PassRegistration<TestVectorContractionLowering>();
 
-  PassRegistration<TestVectorContractionConversion>();
+  PassRegistration<TestVectorTransposeLowering>();
 
   PassRegistration<TestVectorUnrollingPatterns>();
 

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dd24b4e507523..285b06d7aa831 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -107,7 +107,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
-void registerTestVectorConversions();
+void registerTestVectorLowerings();
 } // namespace test
 } // namespace mlir
 
@@ -197,7 +197,7 @@ void registerTestPasses() {
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
-  mlir::test::registerTestVectorConversions();
+  mlir::test::registerTestVectorLowerings();
 }
 #endif
 

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index bb3e66695c5d0..1c8a45c7e9b74 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1458,6 +1458,7 @@ cc_library(
         ":LLVMCommonConversion",
         ":LLVMDialect",
         ":StandardOps",
+        ":VectorOps",
         ":X86Vector",
         "//llvm:Core",
         "//llvm:Support",
@@ -6401,6 +6402,7 @@ cc_library(
         ":TransformUtils",
         ":VectorOps",
         ":VectorToSCF",
+        ":X86VectorTransforms",
         "//llvm:Support",
     ],
 )

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 31c0a7303f2fb..a4556d9499060 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -484,6 +484,7 @@ cc_library(
         "//mlir:Affine",
         "//mlir:Analysis",
         "//mlir:LinalgOps",
+        "//mlir:LinalgTransforms",
         "//mlir:MemRefDialect",
         "//mlir:Pass",
         "//mlir:SCFDialect",


        


More information about the Mlir-commits mailing list