[Mlir-commits] [mlir] 2bc4c3e - [mlir][Vector] NFC - Reorganize vector patterns

Nicolas Vasilache llvmlistbot at llvm.org
Thu Mar 23 11:30:33 PDT 2023


Author: Nicolas Vasilache
Date: 2023-03-23T11:30:25-07:00
New Revision: 2bc4c3e920ee078ef2879b00c40440e0867f0b9e

URL: https://github.com/llvm/llvm-project/commit/2bc4c3e920ee078ef2879b00c40440e0867f0b9e
DIFF: https://github.com/llvm/llvm-project/commit/2bc4c3e920ee078ef2879b00c40440e0867f0b9e.diff

LOG: [mlir][Vector] NFC - Reorganize vector patterns

Vector dialect patterns have grown enormously over the past year, to the point where they are now impenetrable.
Start reorganizing them towards finer-grained control.
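
For example, a pass that previously pulled in the monolithic pattern set now
includes the new header and selects individual lowerings. A minimal sketch
mirroring the updated ConvertVectorToLLVMPass.cpp below (the surrounding pass
boilerplate is elided):

    #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"

    RewritePatternSet patterns(&getContext());
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
    populateVectorShapeCastLoweringPatterns(patterns);
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));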

Differential Revision: https://reviews.llvm.org/D146736

Added: 
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 56f8b4bf22d21..4763b6525b934 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -110,43 +110,11 @@ void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
 void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                            PatternBenefit benefit = 1);
 
-/// Collect a set of transfer read/write lowering patterns.
-///
-/// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
-/// of a most `maxTransferRank` are lowered. This is useful when combined with
-/// VectorToSCF, which reduces the rank of vector transfer ops.
-void populateVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns,
-    std::optional<unsigned> maxTransferRank = std::nullopt,
-    PatternBenefit benefit = 1);
-
 /// These patterns materialize masks for various vector ops such as transfers.
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool force32BitVectorIndices,
                                                PatternBenefit benefit = 1);
 
-/// Collects patterns to progressively lower vector.broadcast ops on high-D
-/// vectors to low-D vector ops.
-void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
-                                             PatternBenefit benefit = 1);
-
-/// Collects patterns to progressively lower vector mask ops into elementary
-/// selection and insertion ops.
-void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
-
-/// Collects patterns to progressively lower vector.shape_cast ops on high-D
-/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
-/// ops.
-void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
-                                             PatternBenefit benefit = 1);
-
-/// Collects patterns that lower scalar vector transfer ops to memref loads and
-/// stores when beneficial.
-void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
-                                                  PatternBenefit benefit = 1);
-
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
@@ -214,8 +182,8 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
 /// Creates a vector.mask operation around a maskable operation. Returns the
 /// vector.mask operation if the mask provided is valid. Otherwise, returns the
 /// maskable operation itself.
-Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
-                         Value mask, Value passthru = Value());
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
+                         Value passthru = Value());
 
 /// Creates a vector select operation that picks values from `newValue` or
 /// `passthru` for each result vector lane based on `mask`. This utility is used

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
new file mode 100644
index 0000000000000..dfadffba3883b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -0,0 +1,248 @@
+//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
+
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace vector {
+
+//===----------------------------------------------------------------------===//
+// Lowering pattern populate functions
+//===----------------------------------------------------------------------===//
+
+/// Populate the pattern set with the following patterns:
+///
+/// [OuterProductOpLowering]
+/// Progressively lower a `vector.outerproduct` to linearized
+/// `vector.extract` + `vector.fma` + `vector.insert`.
+///
+/// [ContractionOpLowering]
+/// Progressive lowering of ContractionOp.
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///
+/// [ContractionOpToMatmulOpLowering]
+/// Progressively lower a `vector.contract` with row-major matmul semantics to
+/// linearized `vector.shape_cast` + `vector.matmul` on the way to
+/// `llvm.matrix.multiply`.
+///
+/// [ContractionOpToDotLowering]
+/// Progressively lower a `vector.contract` with row-major matmul semantics to
+/// linearized `vector.extract` + `vector.reduce` + `vector.insert`.
+///
+/// [ContractionOpToOuterProductOpLowering]
+/// Progressively lower a `vector.contract` with row-major matmul semantics to
+/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
+void populateVectorContractLoweringPatterns(
+    RewritePatternSet &patterns, VectorTransformsOptions options,
+    PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
+
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations.
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction), we lift them back up to 2-D with a simple
+/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
+/// thus fully exiting out of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [BroadcastOpLowering]
+/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
+/// BroadcastOp until dim 1.
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [CreateMaskOp]
+/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
+///
+/// [ConstantMaskOp]
+/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
+/// dim 1.
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Collects patterns that lower scalar vector transfer ops to memref loads and
+/// stores when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [ShapeCastOp2DDownCastRewritePattern]
+/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
+/// vectors progressively.
+///
+/// [ShapeCastOp2DUpCastRewritePattern]
+/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
+/// vectors progressively.
+///
+/// [ShapeCastOpRewritePattern]
+/// Reference lowering to fully unrolled sequences of single element ExtractOp +
+/// InsertOp. Note that applying this pattern can almost always be considered a
+/// performance bug.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [TransposeOpLowering]
+/// Progressive lowering of vector.transpose into vector.extract and
+/// vector.insert ops.
+///
+/// [TransposeOp2DToShuffleLowering]
+/// Lower 2-D vector.transpose to vector.shape_cast + vector.shuffle +
+/// vector.shape_cast.
+///
+void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
+                                             VectorTransformsOptions options,
+                                             PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [TransferReadToVectorLoadLowering]
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast`.
+///
+/// [TransferWriteToVectorStoreLowering]
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store`.
+///
+/// [VectorLoadToMemrefLoadLowering]
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+///
+/// [VectorStoreToMemrefStoreLowering]
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
+///
+/// These patterns lower transfer ops to simpler ops like `vector.load`,
+/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
+/// of at most `maxTransferRank` are lowered. This is useful when combined with
+/// VectorToSCF, which reduces the rank of vector transfer ops.
+void populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns,
+    std::optional<unsigned> maxTransferRank = std::nullopt,
+    PatternBenefit benefit = 1);
+
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes. More specifically:
+///
+/// [TransferReadPermutationLowering]
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2) -> (d1, 0)
+///     vector.transpose %v, [1, 0]
+///
+///     vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+///     vector.transpose %v, [0, 1, 3, 2, 4]
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+///
+/// [TransferWritePermutationLowering]
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+///     %tmp = vector.transpose %v, [2, 0, 1]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+///     %tmp = vector.transpose %v, [1, 0]
+///     %v = vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+///
+/// [TransferOpReduceRank]
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+///     %v = vector.transfer_read ...
+///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+///     vector.broadcast %v
+void populateVectorTransferPermutationMapLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [ScanToArithOps]
+/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
+/// vector.extract_strided_slice.
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenGather]
+/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// outermost dimension.
+///
+/// [Gather1DToConditionalLoads]
+/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
+/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
+/// loads/extracts are made conditional using `scf.if` ops.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void populateVectorMaskLoweringPatternsForSideEffectingOps(
+    RewritePatternSet &patterns);
+
+} // namespace vector
+} // namespace mlir
+#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
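
Each populate function declared above is independent, so clients can stage
lowerings one at a time instead of going through a single monolithic entry
point. A minimal sketch (the helper name is hypothetical; only the
populate/apply calls come from this patch):

    // Lower only vector.multi_reduction, using the inner-parallel strategy.
    static void lowerMultiReductionsOnly(Operation *root, MLIRContext *ctx) {
      RewritePatternSet patterns(ctx);
      vector::populateVectorMultiReductionLoweringPatterns(
          patterns, vector::VectorMultiReductionLowering::InnerParallel);
      (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
    }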

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index d0c06f69930d2..bf89b01e2b60c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -22,12 +22,6 @@ std::unique_ptr<Pass> createVectorBufferizePass();
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
-/// Populates instances of `MaskOpRewritePattern` to lower masked operations
-/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
-/// not its nested `MaskableOpInterface`.
-void populateVectorMaskLoweringPatternsForSideEffectingOps(
-    RewritePatternSet &patterns);
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index af68de7e0051e..a79bbd0be0975 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -9,8 +9,8 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
@@ -23,42 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace vector {
-
-//===----------------------------------------------------------------------===//
-// Vector transformation options exposed as auxiliary structs.
-//===----------------------------------------------------------------------===//
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
-  /// Option to control the lowering of vector.contract.
-  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
-  VectorTransformsOptions &
-  setVectorTransformsOptions(VectorContractLowering opt) {
-    vectorContractLowering = opt;
-    return *this;
-  }
-  /// Option to control the lowering of vector.multi_reduction.
-  VectorMultiReductionLowering vectorMultiReductionLowering =
-      VectorMultiReductionLowering::InnerParallel;
-  VectorTransformsOptions &
-  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
-    vectorMultiReductionLowering = opt;
-    return *this;
-  }
-  /// Option to control the lowering of vector.transpose.
-  VectorTransposeLowering vectorTransposeLowering =
-      VectorTransposeLowering::EltWise;
-  VectorTransformsOptions &
-  setVectorTransposeLowering(VectorTransposeLowering opt) {
-    vectorTransposeLowering = opt;
-    return *this;
-  }
-  /// Option to control the splitting of vector transfers.
-  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
-  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
-    vectorTransferSplit = opt;
-    return *this;
-  }
-};
+struct VectorTransformsOptions;
 
 /// Options that control the vector unrolling.
 struct UnrollVectorOptions {
@@ -109,45 +74,6 @@ struct UnrollVectorOptions {
 // Vector transformation exposed as populate functions over rewrite patterns.
 //===----------------------------------------------------------------------===//
 
-/// Insert TransposeLowering patterns into extraction/insertion.
-void populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorTransformsOptions options = VectorTransformsOptions(),
-    PatternBenefit benefit = 1);
-
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
-/// that all reduction dimensions are either innermost or outermost, by adding
-/// the proper vector.transpose operations.
-/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
-/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
-/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
-/// back.
-/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
-/// form, with an **outermost** reduction dimension, unroll the outer dimension
-/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
-/// tree-reduction (in the future).
-/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
-/// with an **innermost** reduction dimension, unroll the outer dimension to
-/// obtain a sequence of extract + vector.reduction + insert. This can further
-/// lower to horizontal reduction ops.
-/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
-/// reduction (and are thus missing either a parallel or a reduction), we lift
-/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
-/// the other patterns can kick in, thus fully exiting out of the
-/// vector.multi_reduction abstraction.
-void populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
-
-/// Collects patterns to progressively lower vector contraction ops on high-D
-/// into low-D reduction and product ops.
-void populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorTransformsOptions options = VectorTransformsOptions(),
-    PatternBenefit benefit = 1);
-
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction with MMT semantics (matrix matrix multiplication
 /// with the RHS transposed). This specific form is meant to have the vector
@@ -174,67 +100,43 @@ void populateVectorContractCanonicalizeMatmulToMMT(
 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
-/// Collect patterns to convert scan op
-void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
-                                        PatternBenefit benefit = 1);
-
-//===----------------------------------------------------------------------===//
-// Vector.transfer patterns.
-//===----------------------------------------------------------------------===//
-/// Collect a set of transfer read/write lowering patterns that simplify the
-/// permutation map (e.g., converting it to a minor identity map) by inserting
-/// broadcasts and transposes. More specifically:
-///
-/// [TransferReadPermutationLowering]
-/// Lower transfer_read op with permutation into a transfer_read with a
-/// permutation map composed of leading zeros followed by a minor identity +
-/// vector.transpose op.
-/// Ex:
-///     vector.transfer_read ...
-///         permutation_map: (d0, d1, d2) -> (0, d1)
-/// into:
-///     %v = vector.transfer_read ...
-///         permutation_map: (d0, d1, d2) -> (d1, 0)
-///     vector.transpose %v, [1, 0]
+/// Populate `patterns` with the following patterns:
 ///
-///     vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
-/// into:
-///     %v = vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
-///     vector.transpose %v, [0, 1, 3, 2, 4]
-/// Note that an alternative is to transform it to linalg.transpose +
-/// vector.transfer_read to do the transpose in memory instead.
+///   - VectorTransferFullPartialRewriter
 ///
-/// [TransferWritePermutationLowering]
-/// Lower transfer_write op with permutation into a transfer_write with a
-/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
-/// Ex:
-///     vector.transfer_write %v ...
-///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
-/// into:
-///     %tmp = vector.transpose %v, [2, 0, 1]
-///     vector.transfer_write %tmp ...
-///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fast path and a slow path.
 ///
-///     vector.transfer_write %v ...
-///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
-/// into:
-///     %tmp = vector.transpose %v, [1, 0]
-///     %v = vector.transfer_write %tmp ...
-///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      // fast path, direct cast
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slow path, not in-bounds vector.transfer or linalg.copy.
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a one-vector buffer alloca'ed at the top of the function.
 ///
-/// [TransferOpReduceRank]
-/// Lower transfer_read op with broadcast in the leading dimensions into
-/// transfer_read of lower rank + vector.broadcast.
-/// Ex: vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
-/// into:
-///     %v = vector.transfer_read ...
-///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
-///     vector.broadcast %v
-void populateVectorTransferPermutationMapLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Preconditions:
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+void populateVectorTransferFullPartialPatterns(
+    RewritePatternSet &patterns, const VectorTransformsOptions &options);
+
+//===----------------------------------------------------------------------===//
+// Vector.transfer patterns.
+//===----------------------------------------------------------------------===//
 
 /// Collect a set of patterns to reduce the rank of the operands of vector
 /// transfer ops to operate on the largest contiguous vector.
@@ -334,220 +236,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
-/// Expands `vector.gather` ops into a series of conditional scalar loads
-/// (`vector.load` for memrefs or `tensor.extract` for tensors). These loads are
-/// conditional to avoid out-of-bounds memory accesses and guarded with `scf.if`
-/// ops. This lowering path is intended for targets that do not feature
-/// dedicated gather ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
-
-//===----------------------------------------------------------------------===//
-// Finer-grained patterns exposed for more control over individual lowerings.
-//===----------------------------------------------------------------------===//
-/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
-/// may take an extra filter to perform selection at a finer granularity.
-struct VectorTransferFullPartialRewriter : public RewritePattern {
-  using FilterConstraintType =
-      std::function<LogicalResult(VectorTransferOpInterface op)>;
-
-  explicit VectorTransferFullPartialRewriter(
-      MLIRContext *context,
-      VectorTransformsOptions options = VectorTransformsOptions(),
-      FilterConstraintType filter =
-          [](VectorTransferOpInterface op) { return success(); },
-      PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
-        filter(std::move(filter)) {}
-
-  /// Performs the rewrite.
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  VectorTransformsOptions options;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-///    %flattened_a = vector.shape_cast %a
-///    %flattened_b = vector.shape_cast %b
-///    %flattened_d = vector.matmul %flattened_a, %flattened_b
-///    %d = vector.shape_cast %%flattened_d
-///    %e = add %c, %d
-/// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToMatmulOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToMatmulOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
-        filter(std::move(constraint)) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-///    %at = vector.transpose %a, [1, 0]
-///    %bRow0 = vector.extract %b[0]
-///    %atRow0 = vector.extract %at[0]
-///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
-///    ...
-///    %bRowK = vector.extract %b[K]
-///    %atRowK = vector.extract %at[K]
-///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToOuterProductOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToOuterProductOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
-        filter(std::move(constraint)) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to an output-size-unrolled sequence:
-/// ```
-///    %out = arith.constant ... : vector<MxNxelt_type>
-///    %bt = vector.transpose %b, [1, 0]
-///    %aRow0 = vector.extract %a[0]
-///    %btRow0 = vector.extract %bt[0]
-///    %c00 = vector.reduce %atRow0, %bRow0
-///    %out00 = vector.insert %c00, %out[0, 0]
-///    ...
-///    %aRowLast = vector.extract %at[M-1]
-///    %btRowLast = vector.extract %b[N-1]
-///    %cLastLast = vector.reduce %atRowLast, %bRowLast
-///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to Dot and
-/// the vector.contract op is a row-major matmul or matvec.
-class ContractionOpToDotLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToDotLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-///   %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-///   %a = vector.contract with one less free/batch dimension
-///   %b = vector.contract with one less free/batch dimension
-///   ..
-///   %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                        MLIRContext *context, PatternBenefit benefit = 1,
-                        FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
-        filter(std::move(constraint)) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-  // Lower one parallel dimension.
-  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
-                                 vector::ContractionOp op, int64_t lhsIndex,
-                                 int64_t rhsIndex, Value mask) const;
-  // Lower one reduction dimension.
-  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
-                                  vector::ContractionOp op, Value mask) const;
-};
-
 } // namespace vector
 } // namespace mlir
 

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 947911f9a3841..52a4c9cc368d8 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -24,17 +24,53 @@ class IfOp;
 
 namespace vector {
 
+//===----------------------------------------------------------------------===//
+// Vector transformation options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
+/// Structure to control the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+  /// Option to control the lowering of vector.contract.
+  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
+  VectorTransformsOptions &
+  setVectorTransformsOptions(VectorContractLowering opt) {
+    vectorContractLowering = opt;
+    return *this;
+  }
+  /// Option to control the lowering of vector.multi_reduction.
+  VectorMultiReductionLowering vectorMultiReductionLowering =
+      VectorMultiReductionLowering::InnerParallel;
+  VectorTransformsOptions &
+  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+    vectorMultiReductionLowering = opt;
+    return *this;
+  }
+  /// Option to control the lowering of vector.transpose.
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
+  VectorTransformsOptions &
+  setVectorTransposeLowering(VectorTransposeLowering opt) {
+    vectorTransposeLowering = opt;
+    return *this;
+  }
+  /// Option to control the splitting of vector transfers.
+  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+    vectorTransferSplit = opt;
+    return *this;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Standalone transformations and helpers.
 //===----------------------------------------------------------------------===//
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-/// At this time, only vector.transfer_read case is implemented.
+/// Split a vector.transfer operation into an in-bounds (i.e., no
+/// out-of-bounds masking) fastpath and a slowpath. If `ifOp` is not null and
+/// the result is `success`, the `ifOp` points to the newly created conditional
+/// upon function return. To accommodate the fact that the original
+/// vector.transfer indexing may be arbitrary and the slow path indexes
+/// @[0...0] in the temporary buffer, the scf.if op returns a view and values
+/// of type index. At this time, only vector.transfer_read case is
+/// implemented.
 ///
 /// Example (a 2-D vector.transfer_read):
 /// ```
@@ -51,15 +87,16 @@ namespace vector {
 ///      memref.cast %alloc: memref<B...> to compatibleMemRefType
 ///      scf.yield %4 : compatibleMemRefType, index, index
 ///    }
-///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ...
+///    true]}
 /// ```
 /// where `alloc` is a one-vector buffer alloca'ed at the top of the function.
 ///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
-///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
-///  must be equal. This will be relaxed in the future but requires
-///  rank-reducing subviews.
+///  2. the rank of the `xferOp.memref()` and the rank of the
+///  `xferOp.vector()` must be equal. This will be relaxed in the future but
+///  requires rank-reducing subviews.
 LogicalResult splitFullAndPartialTransfer(
     RewriterBase &b, VectorTransferOpInterface xferOp,
     VectorTransformsOptions options = VectorTransformsOptions(),
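
The setters on VectorTransformsOptions each return the struct itself, so a
configuration chains fluently. A sketch, assuming a RewritePatternSet
`patterns` is already in scope:

    auto options = vector::VectorTransformsOptions()
                       .setVectorTransformsOptions(
                           vector::VectorContractLowering::OuterProduct)
                       .setVectorMultiReductionLowering(
                           vector::VectorMultiReductionLowering::InnerParallel);
    vector::populateVectorContractLoweringPatterns(patterns, options);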

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c56d03f6f31d7..05def0f45d7fb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index fb544df18324b..3f1b107f6f8e0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -64,10 +65,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
-    populateVectorContractLoweringPatterns(patterns);
+    populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
-    populateVectorTransposeLoweringPatterns(patterns);
+    populateVectorTransposeLoweringPatterns(patterns,
+                                            VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index d8070b34a761d..ec2e2aa4c0624 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -10,8 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <type_traits>
 #include <optional>
+#include <type_traits>
 
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index eb97c6e168e5c..b7d9812ada0b1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -20,5 +20,5 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
   MLIRSideEffectInterfaces
   MLIRTransformDialect
   MLIRTransformDialectUtils
-  MLIRVectorDialect
+  MLIRVectorTransforms
   )

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d98eb3b781fc5..e3c1429ade54a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"

diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 60996b9add614..136d234742b8d 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -7,13 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
-
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/Parser/Parser.h"
@@ -82,10 +83,9 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
 
     // In the future we may want to more finely select particular stages.
     // Stage 1: contraction lowerings.
-    patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
-                 mlir::vector::ContractionOpToMatmulOpLowering,
-                 mlir::vector::ContractionOpLowering>(vectorTransformOptions,
-                                                      ctx);
+    populateVectorContractLoweringPatterns(
+        patterns, vectorTransformOptions, /*benefit=*/1,
+        /*disableOuterProductLowering=*/true);
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
 
     // Stage 2: multi-reduction lowerings.
@@ -93,8 +93,7 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
         patterns, vectorTransformOptions.vectorMultiReductionLowering);
 
     // Stage 3: Rewrite vector.transfer into full and partial parts.
-    patterns.add<vector::VectorTransferFullPartialRewriter>(
-        ctx, vectorTransformOptions);
+    populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
 
     // Stage 4: Lower vector transfers.
     vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
@@ -107,8 +106,8 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
     vector::populateVectorShapeCastLoweringPatterns(patterns);
 
     // Stage 7: Lower vector.transpose.
-    vector::populateVectorTransposeLoweringPatterns(patterns,
-                                                    vectorTransformOptions);
+    vector::populateVectorTransposeLoweringPatterns(
+        patterns, vectorTransformOptions, /*benefit=*/1);
     if (getTransposeAvx2Lowering())
       x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
           patterns, avx2LoweringOptions, /*benefit=*/10);

diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 6fb1b8c18a122..f17208e193b3c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,14 +1,20 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  LowerVectorBroadcast.cpp
+  LowerVectorContract.cpp
+  LowerVectorGather.cpp
   LowerVectorMask.cpp
+  LowerVectorMultiReduction.cpp
+  LowerVectorScan.cpp
+  LowerVectorShapeCast.cpp
+  LowerVectorTransfer.cpp
+  LowerVectorTranspose.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
-  VectorMultiDimReductionTransforms.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
-  VectorTransferPermutationMapRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp
 

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
new file mode 100644
index 0000000000000..ad538fe4a6828
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -0,0 +1,156 @@
+//===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.broadcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-broadcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Progressive lowering of BroadcastOp.
+class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    VectorType dstType = op.getResultVectorType();
+    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
+    Type eltType = dstType.getElementType();
+
+    // Scalar to any vector can use splat.
+    if (!srcType) {
+      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
+      return success();
+    }
+
+    // Determine rank of source and destination.
+    int64_t srcRank = srcType.getRank();
+    int64_t dstRank = dstType.getRank();
+
+    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+    if (srcRank <= 1 && dstRank == 1) {
+      Value ext;
+      if (srcRank == 0)
+        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+      else
+        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+      return success();
+    }
+
+    // Duplicate this rank.
+    // For example:
+    //   %x = broadcast %y  : k-D to n-D, k < n
+    // becomes:
+    //   %b = broadcast %y  : k-D to (n-1)-D
+    //   %x = [%b,%b,%b,%b] : n-D
+    // becomes:
+    //   %b = [%y,%y]       : (n-1)-D
+    //   %x = [%b,%b,%b,%b] : n-D
+    if (srcRank < dstRank) {
+      // Duplication.
+      VectorType resType =
+          VectorType::get(dstType.getShape().drop_front(), eltType);
+      Value bcst =
+          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, dstType, rewriter.getZeroAttr(dstType));
+      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
+        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    // Find non-matching dimension, if any.
+    assert(srcRank == dstRank);
+    int64_t m = -1;
+    for (int64_t r = 0; r < dstRank; r++)
+      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
+        m = r;
+        break;
+      }
+
+    // All dimensions match: simply pass through.
+    if (m == -1) {
+      rewriter.replaceOp(op, op.getSource());
+      return success();
+    }
+
+    // Any non-matching dimension forces a stretch along this rank.
+    // For example:
+    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
+    // becomes:
+    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
+    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
+    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
+    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
+    //   %x = [%a,%b,%c,%d]
+    // becomes:
+    //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
+    //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
+    //   %a = [%u, %v]
+    //   ..
+    //   %x = [%a,%b,%c,%d]
+    VectorType resType =
+        VectorType::get(dstType.getShape().drop_front(), eltType);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, dstType, rewriter.getZeroAttr(dstType));
+    if (m == 0) {
+      // Stretch at start.
+      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
+        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+    } else {
+      // Stretch not at start.
+      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
+        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
+        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+      }
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorBroadcastLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
+}
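
Like most of the populate entry points after this reorganization, the
broadcast lowering exposes a PatternBenefit, so a client can prioritize it
over competing patterns. A sketch (the benefit value is arbitrary, and `ctx`
is assumed to be the MLIRContext):

    RewritePatternSet patterns(ctx);
    vector::populateVectorBroadcastLoweringPatterns(patterns, /*benefit=*/2);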

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
new file mode 100644
index 0000000000000..1280cfef0b645
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -0,0 +1,1329 @@
+//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.contract' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-contract-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+// Helper to find an index in an affine map.
+static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    int64_t idx = map.getDimPosition(i);
+    if (idx == index)
+      return i;
+  }
+  return std::nullopt;
+}
+
+// Helper to construct iterator types with one index removed.
+static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
+                                         int64_t index) {
+  SmallVector<Attribute> results;
+  for (const auto &it : llvm::enumerate(iteratorTypes)) {
+    int64_t idx = it.index();
+    if (idx == index)
+      continue;
+    results.push_back(it.value());
+  }
+  return results;
+}
+
+// Helper to construct an affine map with one index removed.
+static AffineMap adjustMap(AffineMap map, int64_t index,
+                           PatternRewriter &rewriter) {
+  auto *ctx = rewriter.getContext();
+  SmallVector<AffineExpr> results;
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    int64_t idx = map.getDimPosition(i);
+    if (idx == index)
+      continue;
+    // Re-insert remaining indices, but renamed when occurring
+    // after the removed index.
+    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
+    results.push_back(targetExpr);
+  }
+  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
+}
+
+// Helper method to possibly drop a dimension in a load.
+// TODO
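+//
+// For example (an illustrative sketch), with index = 1 and pos = 0, a value
+// of type vector<2x3x4xf32> is reshaped to vector<2x4xf32>: the leading
+// dimension is unrolled and position 0 of dimension 1 is extracted.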
+static Value reshapeLoad(Location loc, Value val, VectorType type,
+                         int64_t index, int64_t pos,
+                         PatternRewriter &rewriter) {
+  if (index == -1)
+    return val;
+  Type lowType = VectorType::Builder(type).dropDim(0);
+  // At extraction dimension?
+  if (index == 0) {
+    auto posAttr = rewriter.getI64ArrayAttr(pos);
+    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
+  }
+  // Unroll leading dimensions.
+  VectorType vType = lowType.cast<VectorType>();
+  Type resType = VectorType::Builder(type).dropDim(index);
+  auto resVectorType = resType.cast<VectorType>();
+  Value result = rewriter.create<arith::ConstantOp>(
+      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
+  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
+    auto posAttr = rewriter.getI64ArrayAttr(d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
+    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
+                                               posAttr);
+  }
+  return result;
+}
+
+// Helper method to possibly drop a dimension in a store.
+// TODO
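+//
+// This is the inverse of reshapeLoad (an illustrative summary): it inserts
+// `val` at position `pos` of dimension `index` into `result`, unrolling
+// leading dimensions as needed.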
+static Value reshapeStore(Location loc, Value val, Value result,
+                          VectorType type, int64_t index, int64_t pos,
+                          PatternRewriter &rewriter) {
+  // Unmodified?
+  if (index == -1)
+    return val;
+  // At insertion dimension?
+  if (index == 0) {
+    auto posAttr = rewriter.getI64ArrayAttr(pos);
+    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
+  }
+  // Unroll leading dimensions.
+  Type lowType = VectorType::Builder(type).dropDim(0);
+  VectorType vType = lowType.cast<VectorType>();
+  Type insType = VectorType::Builder(vType).dropDim(0);
+  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
+    auto posAttr = rewriter.getI64ArrayAttr(d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
+    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+  }
+  return result;
+}
+
+/// Helper to create arithmetic operation associated with a kind of contraction.
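+/// As an illustrative sketch: for a floating-point ADD with a vector
+/// accumulator this emits a fused multiply-add,
+///   %0 = vector.fma %x, %y, %acc : vector<8xf32>
+/// while other kinds lower to a multiply followed by makeArithReduction.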
+static std::optional<Value>
+createContractArithOp(Location loc, Value x, Value y, Value acc,
+                      vector::CombiningKind kind, PatternRewriter &rewriter,
+                      bool isInt, Value mask = Value()) {
+  using vector::CombiningKind;
+  Value mul;
+
+  if (isInt) {
+    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
+      // Only valid for floating point types.
+      return std::nullopt;
+    mul = rewriter.create<arith::MulIOp>(loc, x, y);
+  } else {
+    // Float case.
+    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
+        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
+        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
+        kind == CombiningKind::XOR)
+      // Only valid for integer types.
+      return std::nullopt;
+    // Special case for fused multiply-add.
+    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
+      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+      if (mask)
+        // The fma op doesn't need explicit masking. However, fma ops used in
+        // reductions must preserve previous 'acc' values for masked-out lanes.
+        fma = selectPassthru(rewriter, mask, fma, acc);
+      return fma;
+    }
+    mul = rewriter.create<arith::MulFOp>(loc, x, y);
+  }
+
+  if (!acc)
+    return std::optional<Value>(mul);
+
+  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+}
+
+/// Return the positions of the reductions in the given map.
+static SmallVector<int64_t> getReductionIndex(AffineMap map,
+                                              ArrayAttr iteratorTypes) {
+  SmallVector<int64_t> dimsIdx;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
+      dimsIdx.push_back(i);
+  }
+  return dimsIdx;
+}
+
+/// Look for a given dimension in an affine map and return its position. Return
+/// std::nullopt if the dimension is not in the map results.
+static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    if (map.getDimPosition(i) == dim)
+      return i;
+  }
+  return std::nullopt;
+}
+
+/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
+/// arith::AddFOp, using operands `x` and `y`.
+static Value createAdd(Location loc, Value x, Value y, bool isInt,
+                       PatternRewriter &rewriter) {
+  if (isInt)
+    return rewriter.create<arith::AddIOp>(loc, x, y);
+  return rewriter.create<arith::AddFOp>(loc, x, y);
+}
+
+/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
+/// arith::MulFOp, using operands `x` and `y`.
+static Value createMul(Location loc, Value x, Value y, bool isInt,
+                       PatternRewriter &rewriter) {
+  if (isInt)
+    return rewriter.create<arith::MulIOp>(loc, x, y);
+  return rewriter.create<arith::MulFOp>(loc, x, y);
+}
+
+namespace {
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = vector.matmul %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToOuterProductOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  ContractionOpToOuterProductOpLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to an output-size-unrolled sequence:
+/// ```
+///    %out = arith.constant ... : vector<MxNxelt_type>
+///    %bt = vector.transpose %b, [1, 0]
+///    %aRow0 = vector.extract %a[0]
+///    %btRow0 = vector.extract %bt[0]
+///    %c00 = vector.reduction %aRow0, %btRow0
+///    %out00 = vector.insert %c00, %out[0, 0]
+///    ...
+///    %aRowLast = vector.extract %a[M-1]
+///    %btRowLast = vector.extract %bt[N-1]
+///    %cLastLast = vector.reduction %aRowLast, %btRowLast
+///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot and
+/// the vector.contract op is a row-major matmul or matvec.
+class ContractionOpToDotLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  ContractionOpToDotLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      const FilterConstraintType &constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of ContractionOp.
+///
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+                        MLIRContext *context, PatternBenefit benefit = 1,
+                        FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+  // Lower one parallel dimension.
+  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
+                                 vector::ContractionOp op, int64_t lhsIndex,
+                                 int64_t rhsIndex, Value mask) const;
+  // Lower one reduction dimension.
+  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
+                                  vector::ContractionOp op, Value mask) const;
+};
+
+/// Generate a vector implementation for matmat, matvec and tmatvec.
+/// This unrolls outer-products along the reduction dimension.
+struct UnrolledOuterProductGenerator
+    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
+  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
+      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
+        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
+        res(op.getAcc()), lhsType(op.getLhsType()) {
+    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      mask = maskableOp.getMaskingOp().getMask();
+  }
+
+  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
+    if (!v)
+      return v;
+    return rewriter.create<vector::TransposeOp>(loc, v, perm);
+  }
+
+  Value promote(Value v, Type dstElementType) {
+    Type elementType = v.getType();
+    auto vecType = elementType.dyn_cast<VectorType>();
+    if (vecType)
+      elementType = vecType.getElementType();
+    if (elementType == dstElementType)
+      return v;
+    Type promotedType = dstElementType;
+    if (vecType)
+      promotedType = VectorType::get(vecType.getShape(), promotedType);
+    if (dstElementType.isa<FloatType>())
+      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+  }
+
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+                             std::optional<Value> maybeMask = std::nullopt) {
+    assert(reductionSize > 0);
+    // Incremental support for masking.
+    if (mask && !maybeMask.has_value())
+      return failure();
+
+    Type resElementType = res.getType().cast<VectorType>().getElementType();
+    for (int64_t k = 0; k < reductionSize; ++k) {
+      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+      extractA = promote(extractA, resElementType);
+      extractB = promote(extractB, resElementType);
+      Value extractMask;
+      if (maybeMask.has_value() && maybeMask.value())
+        extractMask =
+            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+
+      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
+          loc, res.getType(), extractA, extractB, res, kind);
+      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
+    }
+    return res;
+  }
+
+  /// Two outer parallel, one inner reduction (matmat flavor).
+  FailureOr<Value> matmat() {
+    if (!iters({Par(), Par(), Red()}))
+      return failure();
+    // Set up the parallel/reduction structure in the right form.
+    AffineExpr m, n, k;
+    bindDims(rewriter.getContext(), m, n, k);
+    // Classical row-major matmul:  Just permute the lhs.
+    if (layout({{m, k}, {k, n}, {m, n}}))
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
+                       t(mask, {2, 0, 1}));
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    if (layout({{m, k}, {n, k}, {m, n}})) {
+      Value tlhs = t(lhs);
+      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+    }
+    // No need to permute anything.
+    if (layout({{k, m}, {k, n}, {m, n}}))
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+    // Just permute the rhs.
+    if (layout({{k, m}, {n, k}, {m, n}}))
+      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+    // Transposed output: swap RHS and LHS.
+    // Classical row-major matmul: permute the lhs.
+    if (layout({{m, k}, {k, n}, {n, m}}))
+      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+    if (layout({{m, k}, {n, k}, {n, m}})) {
+      Value trhs = t(rhs);
+      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+    }
+    if (layout({{k, m}, {k, n}, {n, m}}))
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+    if (layout({{k, m}, {n, k}, {n, m}}))
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+    return failure();
+  }
+
+  /// One outer parallel, one inner reduction (matvec flavor).
+  FailureOr<Value> matvec() {
+    if (!iters({Par(), Red()}))
+      return failure();
+    AffineExpr m, k;
+    bindDims(rewriter.getContext(), m, k);
+
+    // Case mat-vec: transpose.
+    if (layout({{m, k}, {k}, {m}}))
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
+    // Case mat-trans-vec: ready to go.
+    if (layout({{k, m}, {k}, {m}}))
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+    // Case vec-mat: swap and transpose.
+    if (layout({{k}, {m, k}, {m}}))
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+    // Case vec-mat-trans: swap and ready to go.
+    if (layout({{k}, {k, m}, {m}}))
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+    return failure();
+  }
+
+  /// One outer reduction, one inner parallel (tmatvec flavor).
+  FailureOr<Value> tmatvec() {
+    if (!iters({Red(), Par()}))
+      return failure();
+    AffineExpr k, m;
+    bindDims(rewriter.getContext(), k, m);
+
+    // Case mat-vec: transpose.
+    if (layout({{m, k}, {k}, {m}}))
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+    // Case mat-trans-vec: ready to go.
+    if (layout({{k, m}, {k}, {m}}))
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+    // Case vec-mat: swap and transpose.
+    if (layout({{k}, {m, k}, {m}}))
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+    // Case vec-mat-trans: swap and ready to go.
+    if (layout({{k}, {k, m}, {m}}))
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+    return failure();
+  }
+
+private:
+  vector::CombiningKind kind;
+  Value lhs, rhs, res, mask;
+  VectorType lhsType;
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
+LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  // TODO: Remove native masks from contraction op?
+  if (!op.getMasks().empty())
+    return failure();
+
+  if (vectorTransformOptions.vectorContractLowering !=
+      vector::VectorContractLowering::OuterProduct)
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  Operation *rootOp;
+  if (maskableOp.isMasked()) {
+    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+    rootOp = maskableOp.getMaskingOp();
+  } else {
+    rootOp = op;
+  }
+
+  UnrolledOuterProductGenerator e(rewriter, op);
+  FailureOr<Value> matmatRes = e.matmat();
+  if (succeeded(matmatRes)) {
+    rewriter.replaceOp(rootOp, *matmatRes);
+    return success();
+  }
+  FailureOr<Value> matvecRes = e.matvec();
+  if (succeeded(matvecRes)) {
+    rewriter.replaceOp(rootOp, *matvecRes);
+    return success();
+  }
+  FailureOr<Value> tmatvecRes = e.tmatvec();
+  if (succeeded(tmatvecRes)) {
+    rewriter.replaceOp(rootOp, *tmatvecRes);
+    return success();
+  }
+
+  return failure();
+}
+
+LogicalResult
+ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
+                                            PatternRewriter &rewriter) const {
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
+  if (!op.getMasks().empty())
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  if (vectorTransformOptions.vectorContractLowering !=
+      vector::VectorContractLowering::Dot)
+    return failure();
+
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  Location loc = op.getLoc();
+  Value lhs = op.getLhs(), rhs = op.getRhs();
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(rewriter.getContext(), m, n, k);
+  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
+  //
+  // In the following we wish to make the reduction dimension innermost so we
+  // can load vectors and just fmul + reduce into a scalar.
+  //
+  if (isParallelIterator(iteratorTypes[0]) &&
+      isParallelIterator(iteratorTypes[1]) &&
+      isReductionIterator(iteratorTypes[2])) {
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      // No need to permute anything.
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // Transposed output: swap the operands and transpose the new lhs.
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = tmp;
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      Value tmp = lhs;
+      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      Value tmp = rhs;
+      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+      lhs = tmp;
+    } else {
+      return failure();
+    }
+  } else if (isParallelIterator(iteratorTypes[0]) &&
+             isReductionIterator(iteratorTypes[1])) {
+    //
+    // One outer parallel, one inner reduction (matvec flavor)
+    //
+    if (maps == infer({{m, n}, {n}, {m}})) {
+      // No need to permute anything.
+    } else if (maps == infer({{n, m}, {n}, {m}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{n}, {m, n}, {m}})) {
+      std::swap(lhs, rhs);
+    } else if (maps == infer({{n}, {n, m}, {m}})) {
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else {
+      return failure();
+    }
+  } else {
+    return failure();
+  }
+
+  VectorType dstType = op.getResultType().cast<VectorType>();
+  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
+         "Expected dst type of rank 1 or 2");
+
+  unsigned rank = dstType.getRank();
+  unsigned dstRows = dstType.getShape()[0];
+  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
+
+  // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
+  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
+                                                 rewriter.getZeroAttr(dstType));
+  bool isInt = dstType.getElementType().isa<IntegerType>();
+  for (unsigned r = 0; r < dstRows; ++r) {
+    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
+    for (unsigned c = 0; c < dstColumns; ++c) {
+      Value b = rank == 1
+                    ? rhs
+                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
+      Value reduced = rewriter.create<vector::ReductionOp>(
+          op.getLoc(), vector::CombiningKind::ADD, m);
+
+      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
+                                              : SmallVector<int64_t, 2>{r, c};
+      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
+    }
+  }
+  if (auto acc = op.getAcc())
+    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+/// Lower vector.contract with all size one reduction dimensions to
+/// elementwise ops when possible.
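+///
+/// For example (an illustrative sketch, with the reduction dim k of size 1):
+///   %res = vector.contract ... %a, %b, %acc
+///     : vector<4x1xf32>, vector<1x4xf32> into vector<4x4xf32>
+/// is rewritten into broadcasts/transposes that bring both operands to
+/// vector<4x4xf32>, followed by a single elementwise multiply-add.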
+struct ContractOpToElementwise
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+  ContractOpToElementwise(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      const FilterConstraintType &constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Support vector.mask.
+    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
+    // TODO: Remove native masks from contraction op?
+    if (!contractOp.getMasks().empty())
+      return failure();
+
+    if (failed(filter(contractOp)))
+      return failure();
+
+    if (vectorTransformOptions.vectorContractLowering !=
+        vector::VectorContractLowering::ParallelArith)
+      return failure();
+
+    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
+    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
+    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
+    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
+    SmallVector<int64_t> lhsReductionDims =
+        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
+    SmallVector<int64_t> rhsReductionDims =
+        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
+    // All the reduction dimensions must be of size 1.
+    for (int64_t dim : lhsReductionDims) {
+      if (lhsShape[dim] != 1)
+        return failure();
+    }
+    for (int64_t dim : rhsReductionDims) {
+      if (rhsShape[dim] != 1)
+        return failure();
+    }
+    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
+    unsigned numParallelDims = accMap.getNumResults();
+    unsigned numLhsDimToBroadcast =
+        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
+    unsigned numRhsDimToBroadcast =
+        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
+    SmallVector<int64_t> lhsDims;
+    SmallVector<int64_t> lhsTranspose;
+    SmallVector<int64_t> rhsDims;
+    SmallVector<int64_t> rhsTranspose;
+    for (int64_t dim : lhsReductionDims)
+      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
+    for (int64_t dim : rhsReductionDims)
+      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
+    // Loop through the parallel dimensions to calculate the dimensions to
+    // broadcast and to permute in order to extract only parallel dimensions.
+    for (unsigned i = 0; i < numParallelDims; i++) {
+      std::optional<unsigned> lhsDim =
+          getDimPosition(lhsMap, accMap.getDimPosition(i));
+      if (lhsDim) {
+        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
+      } else {
+        // If the parallel dimension doesn't exist we will have to broadcast it.
+        lhsDims.push_back(
+            contractOp.getResultType().cast<VectorType>().getDimSize(i));
+        lhsTranspose.push_back(lhsDims.size() - 1);
+      }
+      std::optional<unsigned> rhsDim =
+          getDimPosition(rhsMap, accMap.getDimPosition(i));
+      if (rhsDim) {
+        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
+      } else {
+        // If the parallel dimension doesn't exist we will have to broadcast it.
+        rhsDims.push_back(
+            contractOp.getResultType().cast<VectorType>().getDimSize(i));
+        rhsTranspose.push_back(rhsDims.size() - 1);
+      }
+    }
+    Value newLhs = contractOp.getLhs();
+    Value newRhs = contractOp.getRhs();
+    Location loc = contractOp.getLoc();
+    if (!lhsDims.empty()) {
+      lhsDims.append(lhsShape.begin(), lhsShape.end());
+      auto expandedType =
+          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
+      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
+    }
+    if (!rhsDims.empty()) {
+      rhsDims.append(rhsShape.begin(), rhsShape.end());
+      auto expandedType =
+          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
+      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
+    }
+    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
+    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
+    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
+    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
+    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
+    newLhs = rewriter.create<vector::ExtractOp>(
+        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
+    newRhs = rewriter.create<vector::ExtractOp>(
+        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+    std::optional<Value> result =
+        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
+                              contractOp.getKind(), rewriter, isInt);
+    // Guard against combining kinds that have no elementwise arith op.
+    if (!result)
+      return failure();
+    rewriter.replaceOp(contractOp, {*result});
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of ContractionOp.
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
+//
+// TODO: break down into transpose/reshape/cast ops when they become
+// available to avoid code duplication.
+// TODO: investigate lowering order impact on performance
+LogicalResult
+ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
+                                       PatternRewriter &rewriter) const {
+  // TODO: Remove native masks from contraction op?
+  if (!op.getMasks().empty())
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  // TODO: support mixed mode contract lowering.
+  if (op.getLhsType().getElementType() !=
+          getElementTypeOrSelf(op.getAccType()) ||
+      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
+    return failure();
+
+  // TODO: the code below assumes the default contraction, make sure it supports
+  // other kinds before enabling this lowering.
+  if (op.getKind() != vector::CombiningKind::ADD) {
+    return rewriter.notifyMatchFailure(
+        op, "contractions other than 'add' not supported");
+  }
+
+  // TODO: implement benefits, cost models.
+  MLIRContext *ctx = op.getContext();
+  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
+    return success();
+  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
+    return success();
+  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
+    return success();
+  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
+    return success();
+
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Operation *rootOp = op;
+  Value mask;
+  if (op.isMasked()) {
+    rewriter.setInsertionPoint(op.getMaskingOp());
+    rootOp = op.getMaskingOp();
+    mask = op.getMaskingOp().getMask();
+  }
+
+  // Find first batch dimension in LHS/RHS, and lower when found.
+  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
+  if (!batchDimMap.empty()) {
+    int64_t lhsIndex = batchDimMap[0].first;
+    int64_t rhsIndex = batchDimMap[0].second;
+    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
+    if (failed(newOp))
+      return failure();
+    rewriter.replaceOp(rootOp, *newOp);
+    return success();
+  }
+
+  // Collect contracting dimensions.
+  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
+      op.getContractingDimMap();
+  DenseSet<int64_t> lhsContractingDimSet;
+  DenseSet<int64_t> rhsContractingDimSet;
+  for (auto &dimPair : contractingDimMap) {
+    lhsContractingDimSet.insert(dimPair.first);
+    rhsContractingDimSet.insert(dimPair.second);
+  }
+
+  // Find first free dimension in LHS, and lower when found.
+  VectorType lhsType = op.getLhsType();
+  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
+    if (lhsContractingDimSet.count(lhsIndex) == 0) {
+      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
+      if (failed(newOp))
+        return failure();
+      rewriter.replaceOp(rootOp, *newOp);
+      return success();
+    }
+  }
+
+  // Find first free dimension in RHS, and lower when found.
+  VectorType rhsType = op.getRhsType();
+  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
+    if (rhsContractingDimSet.count(rhsIndex) == 0) {
+      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
+      if (failed(newOp))
+        return failure();
+      rewriter.replaceOp(rootOp, *newOp);
+      return success();
+    }
+  }
+
+  // Lower the first remaining reduction dimension.
+  if (!contractingDimMap.empty()) {
+    auto newOp = lowerReduction(rewriter, op, mask);
+    if (failed(newOp))
+      return failure();
+    rewriter.replaceOp(rootOp, *newOp);
+    return success();
+  }
+
+  return failure();
+}
+
+// Lower one parallel dimension.
+// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
+// TODO: consider reusing existing contract unrolling
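+//
+// For example (an illustrative sketch), unrolling the batch dimension of
+//   vector.contract : vector<2x3x4xf32>, vector<2x4x5xf32> into vector<2x3x5xf32>
+// yields two extract / lower-rank vector.contract / insert sequences over
+// vector<3x4xf32> and vector<4x5xf32> operands.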
+FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
+                                                      vector::ContractionOp op,
+                                                      int64_t lhsIndex,
+                                                      int64_t rhsIndex,
+                                                      Value mask) const {
+  VectorType lhsType = op.getLhsType();
+  VectorType rhsType = op.getRhsType();
+  VectorType resType = op.getResultType().cast<VectorType>();
+  // Find the iterator type index and result index.
+  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
+  int64_t iterIndex = -1;
+  int64_t dimSize = -1;
+  if (lhsIndex >= 0) {
+    iterIndex = iMap[0].getDimPosition(lhsIndex);
+    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
+             << " to map to the same dimension";
+      });
+    dimSize = lhsType.getDimSize(lhsIndex);
+  } else if (rhsIndex >= 0) {
+    iterIndex = iMap[1].getDimPosition(rhsIndex);
+    dimSize = rhsType.getDimSize(rhsIndex);
+  }
+  if (iterIndex < 0)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected either lhsIndex=" << lhsIndex
+           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
+    });
+  // value_or(-1) means that we tolerate a dimension not appearing
+  // in the result map. That can't happen for actual parallel iterators, but
+  // the caller ContractionOpLowering::matchAndRewrite is currently calling
+  // lowerParallel also for the case of unit-size reduction dims appearing only
+  // on one of LHS or RHS, not both. At the moment, such cases are created by
+  // CastAwayContractionLeadingOneDim, so we need to either support that or
+  // modify that pattern.
+  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
+  if (resIndex == -1 && dimSize != 1)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected the dimension for iterIndex=" << iterIndex
+           << " to either appear in the result map, or to be a unit dimension";
+    });
+
+  // Construct new iterator types and affine map array attribute.
+  std::array<AffineMap, 3> lowIndexingMaps = {
+      adjustMap(iMap[0], iterIndex, rewriter),
+      adjustMap(iMap[1], iterIndex, rewriter),
+      adjustMap(iMap[2], iterIndex, rewriter)};
+  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+  auto lowIter =
+      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
+  // Unroll into a series of lower dimensional vector.contract ops.
+  Location loc = op.getLoc();
+  Value result = rewriter.create<arith::ConstantOp>(
+      loc, resType, rewriter.getZeroAttr(resType));
+
+  for (int64_t d = 0; d < dimSize; ++d) {
+    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
+    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
+    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
+
+    Value lowMask;
+    if (mask)
+      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *lowContract = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, acc, lowAffine, lowIter);
+    lowContract = maskOperation(rewriter, lowContract, lowMask);
+    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
+                          resIndex, d, rewriter);
+  }
+  return result;
+}
+
+// Lower one reduction dimension.
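+//
+// For example (an illustrative sketch), the rank-1 base case
+//   vector.contract : vector<4xf32>, vector<4xf32> into f32
+// becomes roughly:
+//   %m = arith.mulf %lhs, %rhs : vector<4xf32>
+//   %r = vector.reduction <add>, %m, %acc : vector<4xf32> into f32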
+FailureOr<Value> ContractionOpLowering::lowerReduction(
+    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
+  auto loc = op.getLoc();
+  VectorType lhsType = op.getLhsType();
+  VectorType rhsType = op.getRhsType();
+  Type resType = op.getResultType();
+  if (resType.isa<VectorType>())
+    return rewriter.notifyMatchFailure(op,
+                                       "did not expect a VectorType result");
+  bool isInt = resType.isa<IntegerType>();
+  // Use iterator index 0.
+  int64_t iterIndex = 0;
+  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
+  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
+  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
+  if (!lookupLhs.has_value())
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
+    });
+  if (!lookupRhs.has_value())
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
+    });
+  int64_t lhsIndex = *lookupLhs;
+  int64_t rhsIndex = *lookupRhs;
+  int64_t dimSize = lhsType.getDimSize(lhsIndex);
+  if (dimSize != rhsType.getDimSize(rhsIndex))
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "expect LHS dimension " << lhsIndex
+           << " to have the same size as RHS dimension " << rhsIndex;
+    });
+  // Base case.
+  if (lhsType.getRank() == 1) {
+    if (rhsType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "When LHS has rank 1, expected also RHS to have rank 1");
+    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
+    auto kind = vector::CombiningKind::ADD;
+
+    Value acc = op.getAcc();
+    Operation *reductionOp =
+        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+            : rewriter.create<vector::ReductionOp>(loc, kind, m);
+    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
+  }
+  // Construct new iterator types and affine map array attribute.
+  std::array<AffineMap, 3> lowIndexingMaps = {
+      adjustMap(iMap[0], iterIndex, rewriter),
+      adjustMap(iMap[1], iterIndex, rewriter),
+      adjustMap(iMap[2], iterIndex, rewriter)};
+  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+  auto lowIter =
+      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
+  // Unroll into a series of lower dimensional vector.contract ops.
+  // By feeding the initial accumulator into the first contraction,
+  // and the result of each contraction into the next, eventually
+  // the sum of all reductions is computed.
+  Value result = op.getAcc();
+  for (int64_t d = 0; d < dimSize; ++d) {
+    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
+    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
+    Value newMask;
+    if (mask)
+      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *newContract = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, result, lowAffine, lowIter);
+    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
+  }
+  return result;
+}
+
+/// Progressive lowering of OuterProductOp.
+/// One:
+///   %x = vector.outerproduct %lhs, %rhs, %acc
+/// is replaced by:
+///   %z = zero-result
+///   %0 = vector.extract %lhs[0]
+///   %1 = vector.broadcast %0
+///   %2 = vector.extract %acc[0]
+///   %3 = vector.fma %1, %rhs, %2
+///   %4 = vector.insert %3, %z[0]
+///   ..
+///   %x = vector.insert %.., %..[N-1]
+///
+class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType lhsType = op.getOperandVectorTypeLHS();
+    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
+    VectorType resType = op.getResultVectorType();
+    Type eltType = resType.getElementType();
+    bool isInt = eltType.isa<IntegerType, IndexType>();
+    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
+    vector::CombiningKind kind = op.getKind();
+
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = op;
+    }
+
+    if (!rhsType) {
+      // Special case: AXPY operation (scalar rhs). Broadcast the scalar to
+      // the lhs vector type and lower to a single elementwise operation.
+      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
+      std::optional<Value> mult = createContractArithOp(
+          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
+      if (!mult.has_value())
+        return failure();
+      rewriter.replaceOp(rootOp, *mult);
+      return success();
+    }
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
+      auto pos = rewriter.getI64ArrayAttr(d);
+      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
+      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
+      Value r = nullptr;
+      if (acc)
+        r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+      Value extrMask;
+      if (mask)
+        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+
+      std::optional<Value> m = createContractArithOp(
+          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
+      if (!m.has_value())
+        return failure();
+      result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
+    }
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %mta = maybe_transpose
+///    %mtb = maybe_transpose
+///    %flattened_a = vector.shape_cast %mta
+///    %flattened_b = vector.shape_cast %mtb
+///    %flattened_d = vector.matmul %flattened_a, %flattened_b
+///    %mtd = vector.shape_cast %flattened_d
+///    %d = maybe_untranspose %mtd
+///    %e = add %c, %d
+/// ```
+/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+LogicalResult
+ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
+                                                 PatternRewriter &rew) const {
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
+  if (!op.getMasks().empty())
+    return failure();
+  if (vectorTransformOptions.vectorContractLowering !=
+      vector::VectorContractLowering::Matmul)
+    return failure();
+  if (failed(filter(op)))
+    return failure();
+
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isParallelIterator(iteratorTypes[1]) ||
+      !isReductionIterator(iteratorTypes[2]))
+    return failure();
+
+  Type elementType = op.getLhsType().getElementType();
+  if (!elementType.isIntOrFloat())
+    return failure();
+
+  Type dstElementType = op.getType();
+  if (auto vecType = dstElementType.dyn_cast<VectorType>())
+    dstElementType = vecType.getElementType();
+  if (elementType != dstElementType)
+    return failure();
+
+  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
+  // Bail out if the contraction cannot be put in this form.
+  MLIRContext *ctx = op.getContext();
+  Location loc = op.getLoc();
+  AffineExpr m, n, k;
+  bindDims(rew.getContext(), m, n, k);
+  // LHS must be A(m, k) or A(k, m).
+  Value lhs = op.getLhs();
+  auto lhsMap = op.getIndexingMapsArray()[0];
+  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
+    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+    return failure();
+
+  // RHS must be B(k, n) or B(n, k).
+  Value rhs = op.getRhs();
+  auto rhsMap = op.getIndexingMapsArray()[1];
+  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
+    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+    return failure();
+
+  // At this point lhs and rhs are in row-major.
+  VectorType lhsType = lhs.getType().cast<VectorType>();
+  VectorType rhsType = rhs.getType().cast<VectorType>();
+  int64_t lhsRows = lhsType.getDimSize(0);
+  int64_t lhsColumns = lhsType.getDimSize(1);
+  int64_t rhsColumns = rhsType.getDimSize(1);
+
+  Type flattenedLHSType =
+      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
+  Type flattenedRHSType =
+      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
+                                           rhsColumns);
+  mul = rew.create<vector::ShapeCastOp>(
+      loc,
+      VectorType::get({lhsRows, rhsColumns},
+                      getElementTypeOrSelf(op.getAcc().getType())),
+      mul);
+
+  // ACC must be C(m, n) or C(n, m).
+  auto accMap = op.getIndexingMapsArray()[2];
+  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
+    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
+    llvm_unreachable("invalid contraction semantics");
+
+  Value res =
+      elementType.isa<IntegerType>()
+          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+          : static_cast<Value>(
+                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+
+  rew.replaceOp(op, res);
+  return success();
+}
+} // namespace
+
+void mlir::vector::populateVectorContractLoweringPatterns(
+    RewritePatternSet &patterns, VectorTransformsOptions options,
+    PatternBenefit benefit, bool disableOuterProductLowering) {
+  if (!disableOuterProductLowering)
+    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
+  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
+               ContractionOpToOuterProductOpLowering>(
+      options, patterns.getContext(), benefit);
+}
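+
+// Typical driver usage is a minimal sketch along these lines (the enclosing
+// pass and `ctx` are assumed; `applyPatternsAndFoldGreedily` comes from the
+// greedy pattern rewrite driver):
+//   VectorTransformsOptions options;
+//   options.vectorContractLowering = VectorContractLowering::OuterProduct;
+//   RewritePatternSet patterns(ctx);
+//   vector::populateVectorContractLoweringPatterns(patterns, options);
+//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));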

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
new file mode 100644
index 0000000000000..dc10cb6278cb8
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -0,0 +1,173 @@
+//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.gather' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-broadcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
+///        ... into vector<2x3xf32>
+///
+/// ==>
+///
+/// %0   = arith.constant dense<0.0> : vector<2x3xf32>
+/// %g0  = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
+/// %1   = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %g1  = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
+/// %g   = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
+struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = op.getType();
+    if (resultTy.getRank() < 2)
+      return rewriter.notifyMatchFailure(op, "already flat");
+
+    Location loc = op.getLoc();
+    Value indexVec = op.getIndexVec();
+    Value maskVec = op.getMask();
+    Value passThruVec = op.getPassThru();
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultTy, rewriter.getZeroAttr(resultTy));
+
+    Type subTy = VectorType::get(resultTy.getShape().drop_front(),
+                                 resultTy.getElementType());
+
+    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+      int64_t thisIdx[1] = {i};
+
+      Value indexSubVec =
+          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+      Value maskSubVec =
+          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+      Value passThruSubVec =
+          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+      Value subGather = rewriter.create<vector::GatherOp>(
+          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
+          passThruSubVec);
+      result =
+          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
+/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
+/// loads/extracts are made conditional using `scf.if` ops.
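+///
+/// For example (an illustrative sketch), each lane becomes roughly:
+///   %c   = vector.extract %mask[i] : vector<4xi1>
+///   %res = scf.if %c -> (vector<4xf32>) {
+///     ... load or extract one element, insert it into the running result ...
+///     scf.yield %updated : vector<4xf32>
+///   } else {
+///     scf.yield %previous : vector<4xf32>
+///   }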
+struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = op.getType();
+    if (resultTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "unsupported rank");
+
+    Location loc = op.getLoc();
+    Type elemTy = resultTy.getElementType();
+    // Vector type with a single element. Used to generate `vector.load`s.
+    VectorType elemVecTy = VectorType::get({1}, elemTy);
+
+    Value condMask = op.getMask();
+    Value base = op.getBase();
+    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
+        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
+        op.getIndexVec());
+    auto baseOffsets = llvm::to_vector(op.getIndices());
+    Value lastBaseOffset = baseOffsets.back();
+
+    Value result = op.getPassThru();
+
+    // Emit a conditional access for each vector element.
+    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
+      int64_t thisIdx[1] = {i};
+      Value condition =
+          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
+      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+      baseOffsets.back() =
+          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+
+      auto loadBuilder = [&](OpBuilder &b, Location loc) {
+        Value extracted;
+        if (isa<MemRefType>(base.getType())) {
+          // `vector.load` does not support scalar result; emit a vector load
+          // and extract the single result instead.
+          Value load =
+              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
+          int64_t zeroIdx[1] = {0};
+          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
+        } else {
+          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
+        }
+
+        Value newResult =
+            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
+        b.create<scf::YieldOp>(loc, newResult);
+      };
+      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, result);
+      };
+
+      result =
+          rewriter
+              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
+                                 /*elseBuilder=*/passThruBuilder)
+              .getResult(0);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorGatherLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
+                                                          benefit);
+}
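
For illustration, a minimal sketch (not part of this commit) of how the new
entry point might be driven; `applyPatternsAndFoldGreedily` re-applies
patterns to newly created ops, so `FlattenGather` peels one rank per
iteration until `Gather1DToConditionalLoads` can scalarize:

```
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper, not from the patch: lower all vector.gather ops
// nested under `root` down to conditional scalar accesses.
static LogicalResult lowerGathers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorGatherLoweringPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```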

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 7c66e65fdef8b..e318d4dc15915 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements target-independent rewrites and utilitites to lower the
+// This file implements target-independent rewrites and utilities to lower the
 // 'vector.mask' operation.
 //
 //===----------------------------------------------------------------------===//
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -30,6 +31,147 @@ namespace vector {
 using namespace mlir;
 using namespace mlir::vector;
 
+//===----------------------------------------------------------------------===//
+// populateVectorMaskOpLoweringPatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Progressive lowering of CreateMaskOp.
+/// One:
+///   %x = vector.create_mask %a, ... : vector<dx...>
+/// is replaced by:
+///   %l = vector.create_mask ... : vector<...>  ; one lower rank
+///   %0 = arith.cmpi "slt", %ci, %a     |
+///   %1 = arith.select %0, %l, %zeroes  |  repeated d times
+///   %r = vector.insert %1, %pr [i]     |
+///   %x = ....
+/// until a one-dimensional vector is reached.
+class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getResult().getType().cast<VectorType>();
+    int64_t rank = dstType.getRank();
+    if (rank <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "0-D and 1-D vectors are handled separately");
+
+    auto loc = op.getLoc();
+    auto eltType = dstType.getElementType();
+    int64_t dim = dstType.getDimSize(0);
+    Value idx = op.getOperand(0);
+
+    VectorType lowType =
+        VectorType::get(dstType.getShape().drop_front(), eltType);
+    Value trueVal = rewriter.create<vector::CreateMaskOp>(
+        loc, lowType, op.getOperands().drop_front());
+    Value falseVal = rewriter.create<arith::ConstantOp>(
+        loc, lowType, rewriter.getZeroAttr(lowType));
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, dstType, rewriter.getZeroAttr(dstType));
+    for (int64_t d = 0; d < dim; d++) {
+      Value bnd =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
+      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                 bnd, idx);
+      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
+      auto pos = rewriter.getI64ArrayAttr(d);
+      result =
+          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Progressive lowering of ConstantMaskOp.
+/// One:
+///   %x = vector.constant_mask [a,b]
+/// is replaced by:
+///   %z = arith.constant all-false mask
+///   %l = vector.constant_mask [b]
+///   %0 = vector.insert %l, %z[0]
+///   ..
+///   %x = vector.insert %l, %..[a-1]
+/// until a one-dimensional vector is reached. All these operations
+/// will be folded at LLVM IR level.
+class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto dstType = op.getType();
+    auto eltType = dstType.getElementType();
+    auto dimSizes = op.getMaskDimSizes();
+    int64_t rank = dstType.getRank();
+
+    if (rank == 0) {
+      assert(dimSizes.size() == 1 &&
+             "Expected exactly one dim size for a 0-D vector");
+      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, dstType,
+          DenseIntElementsAttr::get(
+              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
+              ArrayRef<bool>{value}));
+      return success();
+    }
+
+    // Scalable constant masks can only be lowered for the "none set" case.
+    if (dstType.cast<VectorType>().isScalable()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, DenseElementsAttr::get(dstType, false));
+      return success();
+    }
+
+    int64_t trueDim = std::min(dstType.getDimSize(0),
+                               dimSizes[0].cast<IntegerAttr>().getInt());
+
+    if (rank == 1) {
+      // Express constant 1-D case in explicit vector form:
+      //   [T,..,T,F,..,F].
+      SmallVector<bool> values(dstType.getDimSize(0));
+      for (int64_t d = 0; d < trueDim; d++)
+        values[d] = true;
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, dstType, rewriter.getBoolVectorAttr(values));
+      return success();
+    }
+
+    VectorType lowType =
+        VectorType::get(dstType.getShape().drop_front(), eltType);
+    SmallVector<int64_t> newDimSizes;
+    for (int64_t r = 1; r < rank; r++)
+      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
+        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, dstType, rewriter.getZeroAttr(dstType));
+    for (int64_t d = 0; d < trueDim; d++) {
+      auto pos = rewriter.getI64ArrayAttr(d);
+      result =
+          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorMaskOpLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
+      patterns.getContext(), benefit);
+}
+
+//===----------------------------------------------------------------------===//
+// populateVectorMaskLoweringPatternsForSideEffectingOps
+//===----------------------------------------------------------------------===//
+
 namespace {
 
 /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold

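As a scalar model of what `CreateMaskOpLowering` unrolls (plain C++ with an
assumed 4x8 shape and runtime bounds; not MLIR code): row `d` of the result
is the lower-rank mask when `d < bound0` and all-false otherwise:

```
#include <array>
#include <cstdio>

int main() {
  constexpr int kRows = 4, kCols = 8; // models vector<4x8xi1>
  int bound0 = 3, bound1 = 5;         // runtime operands of create_mask
  std::array<bool, kCols> lowMask{};  // %l: the one-lower-rank mask
  for (int j = 0; j < kCols; ++j)
    lowMask[j] = j < bound1;
  bool result[kRows][kCols] = {};     // the all-false arith.constant
  for (int d = 0; d < kRows; ++d)     // the unrolled insert loop
    if (d < bound0)                   // arith.cmpi slt + arith.select
      for (int j = 0; j < kCols; ++j)
        result[d][j] = lowMask[j];
  for (int d = 0; d < kRows; ++d, std::putchar('\n'))
    for (int j = 0; j < kCols; ++j)
      std::putchar(result[d][j] ? 'T' : 'F');
}
```
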
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
similarity index 98%
rename from mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b790d141415aa..1744c46db5886 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -1,4 +1,4 @@
-//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
+//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
 //
 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-/// This file implements target-independent rewrites of MultiDimReductionOp.
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.multi_reduction' operation.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
 
@@ -19,6 +20,7 @@
 
 using namespace mlir;
 
+namespace {
 /// This file implements the following transformations as composable atomic
 /// patterns.
 
@@ -441,6 +443,7 @@ struct OneDimMultiReductionToTwoDim
     return success();
   }
 };
+} // namespace
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
new file mode 100644
index 0000000000000..eb2deba7bc46b
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -0,0 +1,251 @@
+//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.scan' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-scan-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// This function constructs the appropriate integer or float
+/// operation given the vector combining kind and operands. The
+/// supported integer operations are: add, mul, min (signed/unsigned),
+/// max (signed/unsigned), and, or, xor. The supported float
+/// operations are: add, mul, min, and max.
+static Value genOperator(Location loc, Value x, Value y,
+                         vector::CombiningKind kind,
+                         PatternRewriter &rewriter) {
+  using vector::CombiningKind;
+
+  auto elType = x.getType().cast<VectorType>().getElementType();
+  bool isInt = elType.isIntOrIndex();
+
+  Value combinedResult{nullptr};
+  switch (kind) {
+  case CombiningKind::ADD:
+    if (isInt)
+      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
+    else
+      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
+    break;
+  case CombiningKind::MUL:
+    if (isInt)
+      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
+    else
+      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
+    break;
+  case CombiningKind::MINUI:
+    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
+    break;
+  case CombiningKind::MINSI:
+    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
+    break;
+  case CombiningKind::MAXUI:
+    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
+    break;
+  case CombiningKind::MAXSI:
+    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
+    break;
+  case CombiningKind::AND:
+    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
+    break;
+  case CombiningKind::OR:
+    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
+    break;
+  case CombiningKind::XOR:
+    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
+    break;
+  case CombiningKind::MINF:
+    combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
+    break;
+  case CombiningKind::MAXF:
+    combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
+    break;
+  }
+  return combinedResult;
+}
+
+/// This function checks to see if the vector combining kind
+/// is consistent with the integer or float element type.
+static bool isValidKind(bool isInt, vector::CombiningKind kind) {
+  using vector::CombiningKind;
+  enum class KindType { FLOAT, INT, INVALID };
+  KindType type{KindType::INVALID};
+  switch (kind) {
+  case CombiningKind::MINF:
+  case CombiningKind::MAXF:
+    type = KindType::FLOAT;
+    break;
+  case CombiningKind::MINUI:
+  case CombiningKind::MINSI:
+  case CombiningKind::MAXUI:
+  case CombiningKind::MAXSI:
+  case CombiningKind::AND:
+  case CombiningKind::OR:
+  case CombiningKind::XOR:
+    type = KindType::INT;
+    break;
+  case CombiningKind::ADD:
+  case CombiningKind::MUL:
+    type = isInt ? KindType::INT : KindType::FLOAT;
+    break;
+  }
+  bool isValidIntKind = (type == KindType::INT) && isInt;
+  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
+  return (isValidIntKind || isValidFloatKind);
+}
+
+namespace {
+/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
+/// vector.extract_strided_slice.
+///
+/// Example:
+///
+/// ```
+///   %0:2 = vector.scan <mul>, %arg0, %arg1
+///     {inclusive = true, reduction_dim = 1} :
+///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
+/// ```
+///
+/// is converted to:
+///
+/// ```
+///   %cst = arith.constant dense<0> : vector<2x3xi32>
+///   %0 = vector.extract_strided_slice %arg0
+///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
+///       : vector<2x3xi32> to vector<2x1xi32>
+///   %1 = vector.insert_strided_slice %0, %cst
+///     {offsets = [0, 0], strides = [1, 1]}
+///       : vector<2x1xi32> into vector<2x3xi32>
+///   %2 = vector.extract_strided_slice %arg0
+///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
+///       : vector<2x3xi32> to vector<2x1xi32>
+///   %3 = arith.muli %0, %2 : vector<2x1xi32>
+///   %4 = vector.insert_strided_slice %3, %1
+///     {offsets = [0, 1], strides = [1, 1]}
+///       : vector<2x1xi32> into vector<2x3xi32>
+///   %5 = vector.extract_strided_slice %arg0
+///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+///       : vector<2x3xi32> to vector<2x1xi32>
+///   %6 = arith.muli %3, %5 : vector<2x1xi32>
+///   %7 = vector.insert_strided_slice %6, %4
+///     {offsets = [0, 2], strides = [1, 1]}
+///       : vector<2x1xi32> into vector<2x3xi32>
+///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
+///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
+/// ```
+struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = scanOp.getLoc();
+    VectorType destType = scanOp.getDestType();
+    ArrayRef<int64_t> destShape = destType.getShape();
+    auto elType = destType.getElementType();
+    bool isInt = elType.isIntOrIndex();
+    if (!isValidKind(isInt, scanOp.getKind()))
+      return failure();
+
+    VectorType resType = VectorType::get(destShape, elType);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+    int64_t reductionDim = scanOp.getReductionDim();
+    bool inclusive = scanOp.getInclusive();
+    int64_t destRank = destType.getRank();
+    VectorType initialValueType = scanOp.getInitialValueType();
+    int64_t initialValueRank = initialValueType.getRank();
+
+    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
+    reductionShape[reductionDim] = 1;
+    VectorType reductionType = VectorType::get(reductionShape, elType);
+    SmallVector<int64_t> offsets(destRank, 0);
+    SmallVector<int64_t> strides(destRank, 1);
+    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
+    sizes[reductionDim] = 1;
+    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
+    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
+
+    Value lastOutput, lastInput;
+    for (int i = 0; i < destShape[reductionDim]; i++) {
+      offsets[reductionDim] = i;
+      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
+      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
+          scanStrides);
+      Value output;
+      if (i == 0) {
+        if (inclusive) {
+          output = input;
+        } else {
+          if (initialValueRank == 0) {
+            // ShapeCastOp cannot handle 0-D vectors
+            output = rewriter.create<vector::BroadcastOp>(
+                loc, input.getType(), scanOp.getInitialValue());
+          } else {
+            output = rewriter.create<vector::ShapeCastOp>(
+                loc, input.getType(), scanOp.getInitialValue());
+          }
+        }
+      } else {
+        Value y = inclusive ? input : lastInput;
+        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
+        assert(output != nullptr);
+      }
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, output, result, offsets, strides);
+      lastOutput = output;
+      lastInput = input;
+    }
+
+    Value reduction;
+    if (initialValueRank == 0) {
+      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
+      reduction =
+          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
+    } else {
+      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
+                                                       lastOutput);
+    }
+
+    rewriter.replaceOp(scanOp, {result, reduction});
+    return success();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorScanLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
+}
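
A scalar model of the semantics `ScanToArithOps` implements, shown for a 2x3
input scanned along dim 1 with `mul` (the values and the `inclusive` flag
are illustrative only):

```
#include <cstdio>

int main() {
  int src[2][3] = {{1, 2, 3}, {4, 5, 6}};
  int init[2] = {10, 20}; // initial value, consumed by exclusive scans
  bool inclusive = false;
  int out[2][3], red[2];
  for (int r = 0; r < 2; ++r) {
    int acc = inclusive ? src[r][0] : init[r]; // the i == 0 special case
    out[r][0] = acc;
    for (int c = 1; c < 3; ++c) {
      // Inclusive combines the current slice, exclusive the previous one.
      acc = acc * (inclusive ? src[r][c] : src[r][c - 1]);
      out[r][c] = acc;
    }
    red[r] = acc; // second result: the last scan output per row
  }
  for (int r = 0; r < 2; ++r)
    std::printf("row %d: %d %d %d | reduction %d\n", r, out[r][0], out[r][1],
                out[r][2], red[r]);
}
```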

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
new file mode 100644
index 0000000000000..bd9716cbca94c
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -0,0 +1,177 @@
+//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.shape_cast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-shape-cast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
+/// vectors progressively on the way to target llvm.matrix intrinsics.
+/// This iterates over the most major dimension of the 2-D vector and performs
+/// rewrites into:
+///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOp2DDownCastRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value desc = rewriter.create<arith::ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
+    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
+      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
+      desc = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, vec, desc,
+          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+    }
+    rewriter.replaceOp(op, desc);
+    return success();
+  }
+};
+
+/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
+/// vectors progressively.
+/// This iterates over the most major dimension of the 2-D vector and performs
+/// rewrites into:
+///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// Note that 1-D extract_strided_slice ops are lowered to efficient
+/// vector.shuffle ops.
+class ShapeCastOp2DUpCastRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value desc = rewriter.create<arith::ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
+    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
+      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
+          /*sizes=*/mostMinorVectorSize,
+          /*strides=*/1);
+      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+    }
+    rewriter.replaceOp(op, desc);
+    return success();
+  }
+};
+
+// We typically should not lower general shape cast operations into data
+// movement instructions, since the assumption is that these casts are
+// optimized away during progressive lowering. For completeness, however,
+// we fall back to a reference implementation that moves all elements
+// into the right place if we get here.
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+
+    // Special case 2D / 1D lowerings with better implementations.
+    // TODO: make this ND / 1D to allow generic ND -> 1D -> MD.
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+      return failure();
+
+    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
+    // extract/insert chains.
+    // TODO: consider evolving the semantics to only allow 1D source or dest and
+    // drop this potentially very expensive lowering.
+    // Compute number of elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t r = 0; r < srcRank; r++)
+      numElts *= sourceVectorType.getDimSize(r);
+    // Replace with data movement operations:
+    //    x[0,0,0] = y[0,0]
+    //    x[0,0,1] = y[0,1]
+    //    x[0,1,0] = y[0,2]
+    // etc., incrementing the two index vectors "row-major"
+    // within the source and result shape.
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank);
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    for (int64_t i = 0; i < numElts; i++) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, srcRank - 1);
+        incIdx(resIdx, resultVectorType, resRank - 1);
+      }
+      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
+    assert(0 <= r && r < tp.getRank());
+    if (++idx[r] == tp.getDimSize(r)) {
+      idx[r] = 0;
+      incIdx(idx, tp, r - 1);
+    }
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorShapeCastLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ShapeCastOp2DDownCastRewritePattern,
+               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
+      patterns.getContext(), benefit);
+}
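
The paired row-major index walk in `ShapeCastOpRewritePattern` can be
modeled standalone (plain C++, illustrative 2x3 -> 3x2 shapes): both index
vectors advance through the same linear element order:

```
#include <cstdio>
#include <vector>

// Mirrors the pattern's incIdx helper: row-major increment at dim r.
static void incIdx(std::vector<int> &idx, const std::vector<int> &shape,
                   int r) {
  if (++idx[r] == shape[r]) {
    idx[r] = 0;
    incIdx(idx, shape, r - 1);
  }
}

int main() {
  std::vector<int> srcShape{2, 3}, resShape{3, 2};
  std::vector<int> srcIdx(2, 0), resIdx(2, 0);
  for (int i = 0; i < 6; ++i) { // 6 = number of elements in either shape
    if (i != 0) {
      incIdx(srcIdx, srcShape, 1);
      incIdx(resIdx, resShape, 1);
    }
    std::printf("y[%d,%d] -> x[%d,%d]\n", srcIdx[0], srcIdx[1], resIdx[0],
                resIdx[1]);
  }
}
```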

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
similarity index 57%
rename from mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 68d9a349478bf..c2ce9aa10a850 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -14,7 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 
 using namespace mlir;
@@ -46,6 +46,11 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
 }
 
+//===----------------------------------------------------------------------===//
+// populateVectorTransferPermutationMapLoweringPatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identity +
 /// vector.transpose op.
@@ -332,6 +337,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+} // namespace
+
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
@@ -339,3 +346,239 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
            TransferOpReduceRank, TransferWriteNonPermutationLowering>(
           patterns.getContext(), benefit);
 }
+
+//===----------------------------------------------------------------------===//
+// populateVectorTransferLoweringPatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast` if all of the following hold:
+/// - Stride of most minor memref dimension must be 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+struct TransferReadToVectorLoadLowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadToVectorLoadLowering(MLIRContext *context,
+                                   std::optional<unsigned> maxRank,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        maxTransferRank(maxRank) {}
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
+      return failure();
+
+    SmallVector<unsigned> broadcastedDims;
+    // Permutations are handled by VectorToSCF or
+    // populateVectorTransferPermutationMapLoweringPatterns.
+    // We let the 0-d corner case pass through as it is supported.
+    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
+            &broadcastedDims))
+      return failure();
+
+    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+
+    // Non-unit strides are handled by VectorToSCF.
+    if (!vector::isLastMemrefDimUnitStride(memRefType))
+      return failure();
+
+    // If there is broadcasting involved then we first load the unbroadcasted
+    // vector, and then broadcast it with `vector.broadcast`.
+    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
+    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
+                                                  vectorShape.end());
+    for (unsigned i : broadcastedDims)
+      unbroadcastedVectorShape[i] = 1;
+    VectorType unbroadcastedVectorType = VectorType::get(
+        unbroadcastedVectorShape, read.getVectorType().getElementType());
+
+    // `vector.load` supports vector types as memref's elements only when the
+    // resulting vector type is the same as the element type.
+    auto memrefElTy = memRefType.getElementType();
+    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
+      return failure();
+
+    // Otherwise, element types of the memref and the vector must match.
+    if (!memrefElTy.isa<VectorType>() &&
+        memrefElTy != read.getVectorType().getElementType())
+      return failure();
+
+    // Out-of-bounds dims are handled by MaterializeTransferMask.
+    if (read.hasOutOfBoundsDim())
+      return failure();
+
+    // Create vector load op.
+    Operation *loadOp;
+    if (read.getMask()) {
+      Value fill = rewriter.create<vector::SplatOp>(
+          read.getLoc(), unbroadcastedVectorType, read.getPadding());
+      loadOp = rewriter.create<vector::MaskedLoadOp>(
+          read.getLoc(), unbroadcastedVectorType, read.getSource(),
+          read.getIndices(), read.getMask(), fill);
+    } else {
+      loadOp = rewriter.create<vector::LoadOp>(
+          read.getLoc(), unbroadcastedVectorType, read.getSource(),
+          read.getIndices());
+    }
+
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          read, read.getVectorType(), loadOp->getResult(0));
+    } else {
+      rewriter.replaceOp(read, loadOp->getResult(0));
+    }
+
+    return success();
+  }
+
+  std::optional<unsigned> maxTransferRank;
+};
+
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+// TODO: we shouldn't cross the vector/scalar domains just for this
+// but at the moment we lack the infra to avoid it. Possible solutions include:
+// - go directly to LLVM + bitcast
+// - introduce a bitcast op and likely a new pointer dialect
+// - let memref.load/store additionally support the 0-d vector case
+// There are still deeper data layout issues lingering even in this
+// trivial case (for architectures for which this matters).
+struct VectorLoadToMemrefLoadLowering
+    : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = loadOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    auto memrefLoad = rewriter.create<memref::LoadOp>(
+        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
+                                                     memrefLoad);
+    return success();
+  }
+};
+
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
+struct VectorStoreToMemrefStoreLowering
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = storeOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    Value extracted;
+    if (vecType.getRank() == 0) {
+      // TODO: Unify once ExtractOp supports 0-d vectors.
+      extracted = rewriter.create<vector::ExtractElementOp>(
+          storeOp.getLoc(), storeOp.getValueToStore());
+    } else {
+      SmallVector<int64_t> indices(vecType.getRank(), 0);
+      extracted = rewriter.create<vector::ExtractOp>(
+          storeOp.getLoc(), storeOp.getValueToStore(), indices);
+    }
+
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
+    return success();
+  }
+};
+
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store` if all of the following hold:
+/// - Stride of most minor memref dimension must be 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   type of the written value.
+/// - The permutation map is the minor identity map (neither permutation nor
+///   broadcasting is allowed).
+struct TransferWriteToVectorStoreLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToVectorStoreLowering(MLIRContext *context,
+                                     std::optional<unsigned> maxRank,
+                                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        maxTransferRank(maxRank) {}
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "rank exceeds maxTransferRank: " << write;
+      });
+
+    // Permutations are handled by VectorToSCF or
+    // populateVectorTransferPermutationMapLoweringPatterns.
+    if ( // pass-through for the 0-d corner case.
+        !write.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "permutation map is not minor identity: " << write;
+      });
+
+    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "not a memref type: " << write;
+      });
+
+    // Non-unit strides are handled by VectorToSCF.
+    if (!vector::isLastMemrefDimUnitStride(memRefType))
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "most minor stride is not 1: " << write;
+      });
+
+    // `vector.store` supports vector types as memref's elements only when the
+    // type of the vector value being written is the same as the element type.
+    auto memrefElTy = memRefType.getElementType();
+    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "elemental type mismatch: " << write;
+      });
+
+    // Otherwise, element types of the memref and the vector must match.
+    if (!memrefElTy.isa<VectorType>() &&
+        memrefElTy != write.getVectorType().getElementType())
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "elemental type mismatch: " << write;
+      });
+
+    // Out-of-bounds dims are handled by MaterializeTransferMask.
+    if (write.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+        diag << "out of bounds dim: " << write;
+      });
+    if (write.getMask()) {
+      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+          write, write.getSource(), write.getIndices(), write.getMask(),
+          write.getVector());
+    } else {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          write, write.getVector(), write.getSource(), write.getIndices());
+    }
+    return success();
+  }
+
+  std::optional<unsigned> maxTransferRank;
+};
+} // namespace
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
+    PatternBenefit benefit) {
+  patterns.add<TransferReadToVectorLoadLowering,
+               TransferWriteToVectorStoreLowering>(patterns.getContext(),
+                                                   maxTransferRank, benefit);
+  patterns
+      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+          patterns.getContext(), benefit);
+}
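
A sketch (assumptions: the greedy driver, names as introduced by this patch)
of combining the transfer entry points; capping `maxTransferRank` at 1
leaves higher-rank transfers to rank-reducing lowerings such as VectorToSCF:

```
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper: rewrite permutation maps first, then lower the
// resulting rank-1 transfers to vector.load / vector.store.
static LogicalResult lowerLowRankTransfers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  vector::populateVectorTransferLoweringPatterns(patterns,
                                                 /*maxTransferRank=*/1);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```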

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
new file mode 100644
index 0000000000000..f6e8b0c445c99
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -0,0 +1,210 @@
+//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.transpose' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-transpose-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
+/// transposed.
+static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
+                                   SmallVectorImpl<int64_t> &result) {
+  size_t numTransposedDims = transpose.size();
+  for (size_t transpDim : llvm::reverse(transpose)) {
+    if (transpDim != numTransposedDims - 1)
+      break;
+    numTransposedDims--;
+  }
+
+  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
+}
+
+namespace {
+/// Progressive lowering of TransposeOp.
+/// One:
+///   %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+///   %z = arith.constant dense<0.000000e+00>
+///   %0 = vector.extract %y[0, 0]
+///   %1 = vector.insert %0, %z [0, 0]
+///   ..
+///   %x = vector.insert .., .. [.., ..]
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+                      MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
+
+    // Set up convenience transposition table.
+    SmallVector<int64_t> transp;
+    for (auto attr : op.getTransp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+    if (vectorTransformOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Shuffle &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+      return rewriter.notifyMatchFailure(
+          op, "Options specifies lowering to shuffle");
+
+    // Handle a true 2-D matrix transpose differently when requested.
+    if (vectorTransformOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Flat &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+      Type flattenedType =
+          VectorType::get(resType.getNumElements(), resType.getElementType());
+      auto matrix =
+          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+      Value trans = rewriter.create<vector::FlatTransposeOp>(
+          loc, flattenedType, matrix, rows, columns);
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+      return success();
+    }
+
+    // Generate unrolled extract/insert ops. We do not unroll the rightmost
+    // (i.e., highest-order) dimensions that are not transposed and leave them
+    // in vector form to improve performance. Therefore, we prune those
+    // dimensions from the shape/transpose data structures used to generate the
+    // extract/insert ops.
+    SmallVector<int64_t> prunedTransp;
+    pruneNonTransposedDims(transp, prunedTransp);
+    size_t numPrunedDims = transp.size() - prunedTransp.size();
+    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
+    auto prunedInStrides = computeStrides(prunedInShape);
+
+    // Generates the extract/insert operations for every scalar/vector element
+    // of the leftmost transposed dimensions. We traverse every transpose
+    // element using a linearized index that we delinearize to generate the
+    // appropriate indices for the extract/insert operations.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
+
+    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
+         ++linearIdx) {
+      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
+      SmallVector<int64_t> insertIdxs(extractIdxs);
+      applyPermutationToVector(insertIdxs, prunedTransp);
+      Value extractOp =
+          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+};
+
+/// Rewrite a 2-D vector.transpose as a sequence of:
+///   vector.shape_cast 2D -> 1D
+///   vector.shuffle
+///   vector.shape_cast 1D -> 2D
+class TransposeOp2DToShuffleLowering
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeOp2DToShuffleLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType srcType = op.getSourceVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+
+    SmallVector<int64_t> transp;
+    for (auto attr : op.getTransp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    if (transp[0] != 1 && transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+
+    if (vectorTransformOptions.vectorTransposeLowering !=
+        VectorTransposeLowering::Shuffle)
+      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
+
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+    Value casted = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({m * n}, srcType.getElementType()),
+        op.getVector());
+    SmallVector<int64_t> mask;
+    mask.reserve(m * n);
+    for (int64_t j = 0; j < n; ++j)
+      for (int64_t i = 0; i < m; ++i)
+        mask.push_back(i * n + j);
+
+    Value shuffled =
+        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), shuffled);
+
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+};
+} // namespace
+
+void mlir::vector::populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns, VectorTransformsOptions options,
+    PatternBenefit benefit) {
+  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
+      options, patterns.getContext(), benefit);
+}
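
The shuffle mask built by `TransposeOp2DToShuffleLowering` can be checked in
isolation: element (i, j) of the flattened m x n input must land at position
j * m + i, i.e. mask[j * m + i] = i * n + j. A standalone check (plain C++,
illustrative m and n):

```
#include <cstdio>
#include <vector>

int main() {
  int m = 2, n = 3; // models vector<2x3xT> -> vector<3x2xT>
  std::vector<int> mask;
  mask.reserve(m * n);
  for (int j = 0; j < n; ++j) // same loop nest as the pattern
    for (int i = 0; i < m; ++i)
      mask.push_back(i * n + j);
  for (int v : mask)
    std::printf("%d ", v);    // prints: 0 3 1 4 2 5
  std::putchar('\n');
}
```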

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 38062b9893f1a..b0690f63422d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ee23b5494f707..caf5822256bc6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <type_traits>
 #include <optional>
+#include <type_traits>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -92,11 +92,11 @@ static Value createInBoundsCond(RewriterBase &b,
 }
 
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
+/// masking) fast path and a slow path.
 /// If `ifOp` is not null and the result is `success`, the `ifOp` points to
 /// the newly created conditional upon function return.
 /// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
 /// scf.if op returns a view and values of type index.
 /// At this time, only vector.transfer_read case is implemented.
 ///
@@ -107,11 +107,11 @@ static Value createInBoundsCond(RewriterBase &b,
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///      // fastpath, direct cast
+///      // fast path, direct cast
 ///      memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view : compatibleMemRefType, index, index
 ///    } else {
-///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      // slow path, not in-bounds vector.transfer or linalg.copy.
 ///      memref.cast %alloc: memref<B...> to compatibleMemRefType
 ///      scf.yield %4 : compatibleMemRefType, index, index
 ///    }
@@ -172,12 +172,10 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
   for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
     resShape[idx] =
         (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
-    resStrides[idx] = (aStrides[idx] == bStrides[idx])
-                          ? aStrides[idx]
-                          : ShapedType::kDynamic;
+    resStrides[idx] =
+        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
   }
-  resOffset =
-      (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
+  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
   return MemRefType::get(
       resShape, aT.getElementType(),
       StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
@@ -634,7 +632,34 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   return success();
 }
 
-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+namespace {
+/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
+/// may take an extra filter to perform selection at a finer granularity.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+  using FilterConstraintType =
+      std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+  explicit VectorTransferFullPartialRewriter(
+      MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
+      FilterConstraintType filter =
+          [](VectorTransferOpInterface op) { return success(); },
+      PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
+        filter(std::move(filter)) {}
+
+  /// Performs the rewrite.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  VectorTransformsOptions options;
+  FilterConstraintType filter;
+};
+
+} // namespace
+
+LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
@@ -642,3 +667,9 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
     return failure();
   return splitFullAndPartialTransfer(rewriter, xferOp, options);
 }
+
+void mlir::vector::populateVectorTransferFullPartialPatterns(
+    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
+  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
+                                                  options);
+}
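
A minimal sketch of wiring the new entry point; `setVectorTransferSplit` is
assumed to be the existing chainable setter on `VectorTransformsOptions`:

```
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;

// Illustrative helper: split transfers into an in-bounds fast path and a
// buffered slow path (option and enum names assumed from VectorTransforms.h).
static LogicalResult splitTransfers(Operation *root) {
  auto options = VectorTransformsOptions().setVectorTransferSplit(
      VectorTransferSplit::VectorTransfer);
  RewritePatternSet patterns(root->getContext());
  populateVectorTransferFullPartialPatterns(patterns, options);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```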

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index fe59143ebd55f..20fc59e874ab6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -51,102 +51,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Helper to find an index in an affine map.
-static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
-  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-    int64_t idx = map.getDimPosition(i);
-    if (idx == index)
-      return i;
-  }
-  return std::nullopt;
-}
-
-// Helper to construct iterator types with one index removed.
-static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
-                                         int64_t index) {
-  SmallVector<Attribute> results;
-  for (const auto &it : llvm::enumerate(iteratorTypes)) {
-    int64_t idx = it.index();
-    if (idx == index)
-      continue;
-    results.push_back(it.value());
-  }
-  return results;
-}
-
-// Helper to construct an affine map with one index removed.
-static AffineMap adjustMap(AffineMap map, int64_t index,
-                           PatternRewriter &rewriter) {
-  auto *ctx = rewriter.getContext();
-  SmallVector<AffineExpr> results;
-  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-    int64_t idx = map.getDimPosition(i);
-    if (idx == index)
-      continue;
-    // Re-insert remaining indices, but renamed when occurring
-    // after the removed index.
-    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
-    results.push_back(targetExpr);
-  }
-  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
-}
-
-// Helper method to possibly drop a dimension in a load.
-// TODO
-static Value reshapeLoad(Location loc, Value val, VectorType type,
-                         int64_t index, int64_t pos,
-                         PatternRewriter &rewriter) {
-  if (index == -1)
-    return val;
-  Type lowType = VectorType::Builder(type).dropDim(0);
-  // At extraction dimension?
-  if (index == 0) {
-    auto posAttr = rewriter.getI64ArrayAttr(pos);
-    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
-  }
-  // Unroll leading dimensions.
-  VectorType vType = lowType.cast<VectorType>();
-  Type resType = VectorType::Builder(type).dropDim(index);
-  auto resVectorType = resType.cast<VectorType>();
-  Value result = rewriter.create<arith::ConstantOp>(
-      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
-  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
-    auto posAttr = rewriter.getI64ArrayAttr(d);
-    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
-    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
-                                               posAttr);
-  }
-  return result;
-}
-
-// Helper method to possibly drop a dimension in a store.
-// TODO
-static Value reshapeStore(Location loc, Value val, Value result,
-                          VectorType type, int64_t index, int64_t pos,
-                          PatternRewriter &rewriter) {
-  // Unmodified?
-  if (index == -1)
-    return val;
-  // At insertion dimension?
-  if (index == 0) {
-    auto posAttr = rewriter.getI64ArrayAttr(pos);
-    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
-  }
-  // Unroll leading dimensions.
-  Type lowType = VectorType::Builder(type).dropDim(0);
-  VectorType vType = lowType.cast<VectorType>();
-  Type insType = VectorType::Builder(vType).dropDim(0);
-  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
-    auto posAttr = rewriter.getI64ArrayAttr(d);
-    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
-    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
-    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
-    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
-  }
-  return result;
-}
-
 template <typename IntType>
 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(llvm::map_range(
@@ -154,61 +58,11 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
 }
 
-/// Helper to create the arithmetic operation associated with a kind of
-/// contraction.
-static std::optional<Value>
-createContractArithOp(Location loc, Value x, Value y, Value acc,
-                      vector::CombiningKind kind, PatternRewriter &rewriter,
-                      bool isInt, Value mask = Value()) {
-  using vector::CombiningKind;
-  Value mul;
-
-  if (isInt) {
-    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
-      // Only valid for floating point types.
-      return std::nullopt;
-    mul = rewriter.create<arith::MulIOp>(loc, x, y);
-  } else {
-    // Float case.
-    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
-        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
-        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
-        kind == CombiningKind::XOR)
-      // Only valid for integer types.
-      return std::nullopt;
-    // Special case for fused multiply-add.
-    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
-      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
-      if (mask)
-        // The fma op doesn't need explicit masking. However, fma ops used in
-        // reductions must preserve previous 'acc' values for masked-out lanes.
-        fma = selectPassthru(rewriter, mask, fma, acc);
-      return fma;
-    }
-    mul = rewriter.create<arith::MulFOp>(loc, x, y);
-  }
-
-  if (!acc)
-    return std::optional<Value>(mul);
-
-  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
-}
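
The type/kind compatibility rule above can be stated on its own. The following standalone sketch (illustrative, not the dialect's API) encodes exactly the rejections performed by createContractArithOp:

    // MINF/MAXF only combine floating-point values; the bitwise and the
    // signed/unsigned integer min/max kinds only combine integers.
    enum class Kind { ADD, MUL, MINF, MAXF, AND, OR, XOR, MINSI, MINUI, MAXSI, MAXUI };

    static bool isValidCombination(Kind kind, bool isInt) {
      if (isInt)
        return kind != Kind::MINF && kind != Kind::MAXF;
      return kind != Kind::AND && kind != Kind::OR && kind != Kind::XOR &&
             kind != Kind::MINSI && kind != Kind::MINUI && kind != Kind::MAXSI &&
             kind != Kind::MAXUI;
    }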
-
-/// Return the positions of the reductions in the given map.
-static SmallVector<int64_t> getReductionIndex(AffineMap map,
-                                              ArrayAttr iteratorTypes) {
-  SmallVector<int64_t> dimsIdx;
-  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
-      dimsIdx.push_back(i);
-  }
-  return dimsIdx;
-}
-
-/// Look for a given dimension in an affine map and return its position. Return
-/// std::nullopt if the dimension is not in the map results.
-static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
-  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-    if (map.getDimPosition(i) == dim)
+// Helper to find an index in an affine map.
+static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    int64_t idx = map.getDimPosition(i);
+    if (idx == index)
       return i;
   }
   return std::nullopt;
@@ -264,735 +118,6 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
   }
 };
 
-/// Progressive lowering of BroadcastOp.
-class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::BroadcastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    VectorType dstType = op.getResultVectorType();
-    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
-    Type eltType = dstType.getElementType();
-
-    // Scalar to any vector can use splat.
-    if (!srcType) {
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
-      return success();
-    }
-
-    // Determine rank of source and destination.
-    int64_t srcRank = srcType.getRank();
-    int64_t dstRank = dstType.getRank();
-
-    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
-    if (srcRank <= 1 && dstRank == 1) {
-      Value ext;
-      if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
-      else
-        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
-      return success();
-    }
-
-    // Duplicate this rank.
-    // For example:
-    //   %x = broadcast %y  : k-D to n-D, k < n
-    // becomes:
-    //   %b = broadcast %y  : k-D to (n-1)-D
-    //   %x = [%b,%b,%b,%b] : n-D
-    // becomes:
-    //   %b = [%y,%y]       : (n-1)-D
-    //   %x = [%b,%b,%b,%b] : n-D
-    if (srcRank < dstRank) {
-      // Duplication.
-      VectorType resType =
-          VectorType::get(dstType.getShape().drop_front(), eltType);
-      Value bcst =
-          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
-      Value result = rewriter.create<arith::ConstantOp>(
-          loc, dstType, rewriter.getZeroAttr(dstType));
-      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
-        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
-      rewriter.replaceOp(op, result);
-      return success();
-    }
-
-    // Find non-matching dimension, if any.
-    assert(srcRank == dstRank);
-    int64_t m = -1;
-    for (int64_t r = 0; r < dstRank; r++)
-      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
-        m = r;
-        break;
-      }
-
-    // All trailing dimensions are the same. Simply pass through.
-    if (m == -1) {
-      rewriter.replaceOp(op, op.getSource());
-      return success();
-    }
-
-    // Any non-matching dimension forces a stretch along this rank.
-    // For example:
-    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
-    // becomes:
-    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
-    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
-    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
-    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
-    //   %x = [%a,%b,%c,%d]
-    // becomes:
-    //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
-    //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
-    //   %a = [%u, %v]
-    //   ..
-    //   %x = [%a,%b,%c,%d]
-    VectorType resType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, dstType, rewriter.getZeroAttr(dstType));
-    if (m == 0) {
-      // Stretch at start.
-      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
-      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
-      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
-        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
-    } else {
-      // Stretch not at start.
-      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
-        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
-        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
-        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
-      }
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
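
The rank-equal case hinges on locating the first dimension whose extents disagree. A standalone analogue of that search (illustration only):

    #include <cstdint>
    #include <vector>

    // Returns the first position where source and destination extents differ,
    // or -1 when the shapes match and the broadcast is a plain pass-through.
    static int64_t firstStretchDim(const std::vector<int64_t> &srcShape,
                                   const std::vector<int64_t> &dstShape) {
      for (size_t r = 0; r < dstShape.size(); ++r)
        if (srcShape[r] != dstShape[r])
          return static_cast<int64_t>(r);
      return -1;
    }
    // firstStretchDim({4, 1, 2}, {4, 2, 2}) == 1, matching the
    // vector<4x1x2xf32> -> vector<4x2x2xf32> example in the comment above.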
-
-/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
-/// transposed.
-void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
-                            SmallVectorImpl<int64_t> &result) {
-  size_t numTransposedDims = transpose.size();
-  for (size_t transpDim : llvm::reverse(transpose)) {
-    if (transpDim != numTransposedDims - 1)
-      break;
-    numTransposedDims--;
-  }
-
-  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
-}
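
A standalone copy of this pruning (illustration only): for the permutation [1, 0, 2, 3] it keeps only [1, 0], since dims 2 and 3 already sit at their own positions and can stay in vector form:

    #include <cstddef>
    #include <vector>

    static std::vector<long> pruneNonTransposed(const std::vector<long> &transpose) {
      std::size_t numTransposedDims = transpose.size();
      for (std::size_t i = transpose.size(); i-- > 0;) {
        if (transpose[i] != static_cast<long>(numTransposedDims) - 1)
          break;
        --numTransposedDims;
      }
      return std::vector<long>(transpose.begin(),
                               transpose.begin() + numTransposedDims);
    }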
-
-/// Progressive lowering of TransposeOp.
-/// One:
-///   %x = vector.transpose %y, [1, 0]
-/// is replaced by:
-///   %z = arith.constant dense<0.000000e+00>
-///   %0 = vector.extract %y[0, 0]
-///   %1 = vector.insert %0, %z [0, 0]
-///   ..
-///   %x = vector.insert .., .. [.., ..]
-class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                      MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-
-    Value input = op.getVector();
-    VectorType inputType = op.getSourceVectorType();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    SmallVector<int64_t> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(attr.cast<IntegerAttr>().getInt());
-
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Shuffle &&
-        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
-      return rewriter.notifyMatchFailure(
-          op, "Options specifies lowering to shuffle");
-
-    // Handle a true 2-D matrix transpose differently when requested.
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Flat &&
-        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
-      Type flattenedType =
-          VectorType::get(resType.getNumElements(), resType.getElementType());
-      auto matrix =
-          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
-      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
-      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
-      Value trans = rewriter.create<vector::FlatTransposeOp>(
-          loc, flattenedType, matrix, rows, columns);
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
-      return success();
-    }
-
-    // Generate unrolled extract/insert ops. We do not unroll the rightmost
-    // (i.e., highest-order) dimensions that are not transposed and leave them
-    // in vector form to improve performance. Therefore, we prune those
-    // dimensions from the shape/transpose data structures used to generate the
-    // extract/insert ops.
-    SmallVector<int64_t> prunedTransp;
-    pruneNonTransposedDims(transp, prunedTransp);
-    size_t numPrunedDims = transp.size() - prunedTransp.size();
-    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
-    auto prunedInStrides = computeStrides(prunedInShape);
-
-    // Generates the extract/insert operations for every scalar/vector element
-    // of the leftmost transposed dimensions. We traverse every transpose
-    // element using a linearized index that we delinearize to generate the
-    // appropriate indices for the extract/insert operations.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resType, rewriter.getZeroAttr(resType));
-    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
-
-    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
-         ++linearIdx) {
-      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
-      SmallVector<int64_t> insertIdxs(extractIdxs);
-      applyPermutationToVector(insertIdxs, prunedTransp);
-      Value extractOp =
-          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
-      result =
-          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-};
-
-/// Rewrite a 2-D vector.transpose as a sequence of:
-///   vector.shape_cast 2D -> 1D
-///   vector.shuffle
-///   vector.shape_cast 1D -> 2D
-class TransposeOp2DToShuffleLowering
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  TransposeOp2DToShuffleLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-
-    VectorType srcType = op.getSourceVectorType();
-    if (srcType.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
-
-    SmallVector<int64_t> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(attr.cast<IntegerAttr>().getInt());
-    if (transp[0] != 1 && transp[1] != 0)
-      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
-
-    if (vectorTransformOptions.vectorTransposeLowering !=
-        VectorTransposeLowering::Shuffle)
-      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
-
-    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
-    Value casted = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({m * n}, srcType.getElementType()),
-        op.getVector());
-    SmallVector<int64_t> mask;
-    mask.reserve(m * n);
-    for (int64_t j = 0; j < n; ++j)
-      for (int64_t i = 0; i < m; ++i)
-        mask.push_back(i * n + j);
-
-    Value shuffled =
-        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        op, op.getResultVectorType(), shuffled);
-
-    return success();
-  }
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-};
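
The shuffle mask built above is the column-major walk of the flattened source. Standalone (illustration only):

    #include <cstdint>
    #include <vector>

    // For an m x n source, element (i, j) of the transpose comes from linear
    // index i * n + j. With m = 2, n = 3 this yields [0, 3, 1, 4, 2, 5].
    static std::vector<int64_t> transposeShuffleMask(int64_t m, int64_t n) {
      std::vector<int64_t> mask;
      mask.reserve(m * n);
      for (int64_t j = 0; j < n; ++j)
        for (int64_t i = 0; i < m; ++i)
          mask.push_back(i * n + j);
      return mask;
    }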
-
-/// Progressive lowering of OuterProductOp.
-/// One:
-///   %x = vector.outerproduct %lhs, %rhs, %acc
-/// is replaced by:
-///   %z = zero-result
-///   %0 = vector.extract %lhs[0]
-///   %1 = vector.broadcast %0
-///   %2 = vector.extract %acc[0]
-///   %3 = vector.fma %1, %rhs, %2
-///   %4 = vector.insert %3, %z[0]
-///   ..
-///   %x = vector.insert %.., %..[N-1]
-///
-class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::OuterProductOp op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-
-    VectorType lhsType = op.getOperandVectorTypeLHS();
-    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
-    VectorType resType = op.getResultVectorType();
-    Type eltType = resType.getElementType();
-    bool isInt = eltType.isa<IntegerType, IndexType>();
-    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
-    vector::CombiningKind kind = op.getKind();
-
-    // Vector mask setup.
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-    Operation *rootOp;
-    Value mask;
-    if (maskableOp.isMasked()) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = op;
-    }
-
-    if (!rhsType) {
-      // Special case: AXPY operation.
-      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
-      std::optional<Value> mult = createContractArithOp(
-          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
-      if (!mult.has_value())
-        return failure();
-      rewriter.replaceOp(rootOp, *mult);
-      return success();
-    }
-
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resType, rewriter.getZeroAttr(resType));
-    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
-      auto pos = rewriter.getI64ArrayAttr(d);
-      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
-      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
-      Value r = nullptr;
-      if (acc)
-        r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
-      Value extrMask;
-      if (mask)
-        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
-
-      std::optional<Value> m = createContractArithOp(
-          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
-      if (!m.has_value())
-        return failure();
-      result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
-    }
-
-    rewriter.replaceOp(rootOp, result);
-    return success();
-  }
-};
-
-/// Lower vector.contract with all size one reduction dimensions to
-/// elementwise ops when possible.
-struct ContractOpToElementwise
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-  ContractOpToElementwise(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Support vector.mask.
-    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
-    // TODO: Remove native masks from contraction op?
-    if (!contractOp.getMasks().empty())
-      return failure();
-
-    if (failed(filter(contractOp)))
-      return failure();
-
-    if (vectorTransformOptions.vectorContractLowering !=
-        vector::VectorContractLowering::ParallelArith)
-      return failure();
-
-    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
-    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
-    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
-    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
-    SmallVector<int64_t> lhsReductionDims =
-        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
-    SmallVector<int64_t> rhsReductionDims =
-        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
-    // All the reduction dimensions must be a size 1.
-    for (int64_t dim : lhsReductionDims) {
-      if (lhsShape[dim] != 1)
-        return failure();
-    }
-    for (int64_t dim : rhsReductionDims) {
-      if (rhsShape[dim] != 1)
-        return failure();
-    }
-    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
-    unsigned numParallelDims = accMap.getNumResults();
-    unsigned numLhsDimToBroadcast =
-        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
-    unsigned numRhsDimToBroadcast =
-        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
-    SmallVector<int64_t> lhsDims;
-    SmallVector<int64_t> lhsTranspose;
-    SmallVector<int64_t> rhsDims;
-    SmallVector<int64_t> rhsTranspose;
-    for (int64_t dim : lhsReductionDims)
-      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
-    for (int64_t dim : rhsReductionDims)
-      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
-    // Loop through the parallel dimensions to calculate the dimensions to
-    // broadcast and to permute in order to extract only parallel dimensions.
-    for (unsigned i = 0; i < numParallelDims; i++) {
-      std::optional<unsigned> lhsDim =
-          getDimPosition(lhsMap, accMap.getDimPosition(i));
-      if (lhsDim) {
-        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
-      } else {
-        // If the parallel dimension doesn't exist we will have to broadcast it.
-        lhsDims.push_back(
-            contractOp.getResultType().cast<VectorType>().getDimSize(i));
-        lhsTranspose.push_back(lhsDims.size() - 1);
-      }
-      std::optional<unsigned> rhsDim =
-          getDimPosition(rhsMap, accMap.getDimPosition(i));
-      if (rhsDim) {
-        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
-      } else {
-        // If the parallel dimension doesn't exist we will have to broadcast it.
-        rhsDims.push_back(
-            contractOp.getResultType().cast<VectorType>().getDimSize(i));
-        rhsTranspose.push_back(rhsDims.size() - 1);
-      }
-    }
-    Value newLhs = contractOp.getLhs();
-    Value newRhs = contractOp.getRhs();
-    Location loc = contractOp.getLoc();
-    if (!lhsDims.empty()) {
-      lhsDims.append(lhsShape.begin(), lhsShape.end());
-      auto expandedType =
-          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
-      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
-    }
-    if (!rhsDims.empty()) {
-      rhsDims.append(rhsShape.begin(), rhsShape.end());
-      auto expandedType =
-          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
-      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
-    }
-    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
-    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
-    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
-    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
-    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
-    newLhs = rewriter.create<vector::ExtractOp>(
-        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
-    newRhs = rewriter.create<vector::ExtractOp>(
-        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
-    std::optional<Value> result =
-        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
-                              contractOp.getKind(), rewriter, isInt);
-    rewriter.replaceOp(contractOp, {*result});
-    return success();
-  }
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of ConstantMaskOp.
-/// One:
-///   %x = vector.constant_mask [a,b]
-/// is replaced by:
-///   %z = zero-result
-///   %l = vector.constant_mask [b]
-///   %4 = vector.insert %l, %z[0]
-///   ..
-///   %x = vector.insert %l, %..[a-1]
-/// until a one-dimensional vector is reached. All these operations
-/// will be folded at LLVM IR level.
-class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto dstType = op.getType();
-    auto eltType = dstType.getElementType();
-    auto dimSizes = op.getMaskDimSizes();
-    int64_t rank = dstType.getRank();
-
-    if (rank == 0) {
-      assert(dimSizes.size() == 1 &&
-             "Expected exactly one dim size for a 0-D vector");
-      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType,
-          DenseIntElementsAttr::get(
-              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
-              ArrayRef<bool>{value}));
-      return success();
-    }
-
-    // Scalable constant masks can only be lowered for the "none set" case.
-    if (dstType.cast<VectorType>().isScalable()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, DenseElementsAttr::get(dstType, false));
-      return success();
-    }
-
-    int64_t trueDim = std::min(dstType.getDimSize(0),
-                               dimSizes[0].cast<IntegerAttr>().getInt());
-
-    if (rank == 1) {
-      // Express constant 1-D case in explicit vector form:
-      //   [T,..,T,F,..,F].
-      SmallVector<bool> values(dstType.getDimSize(0));
-      for (int64_t d = 0; d < trueDim; d++)
-        values[d] = true;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType, rewriter.getBoolVectorAttr(values));
-      return success();
-    }
-
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    SmallVector<int64_t> newDimSizes;
-    for (int64_t r = 1; r < rank; r++)
-      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
-    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
-        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, dstType, rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0; d < trueDim; d++) {
-      auto pos = rewriter.getI64ArrayAttr(d);
-      result =
-          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
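
The rank-1 base case materializes the mask as a dense boolean constant; in isolation (illustration only):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // vector.constant_mask [2] : vector<4xi1> becomes [true, true, false, false].
    static std::vector<bool> constantMask1D(int64_t dimSize, int64_t maskDimSize) {
      std::vector<bool> values(dimSize, false);
      std::fill(values.begin(),
                values.begin() + std::min(dimSize, maskDimSize), true);
      return values;
    }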
-
-/// Progressive lowering of CreateMaskOp.
-/// One:
-///   %x = vector.create_mask %a, ... : vector<dx...>
-/// is replaced by:
-///   %l = vector.create_mask ... : vector<...>  ; one lower rank
-///   %0 = arith.cmpi "slt", %ci, %a       |
-///   %1 = select %0, %l, %zeroes    |
-///   %r = vector.insert %1, %pr [i] | d-times
-///   %x = ....
-/// until a one-dimensional vector is reached.
-class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
-    auto dstType = op.getResult().getType().cast<VectorType>();
-    int64_t rank = dstType.getRank();
-    if (rank <= 1)
-      return rewriter.notifyMatchFailure(
-          op, "0-D and 1-D vectors are handled separately");
-
-    auto loc = op.getLoc();
-    auto eltType = dstType.getElementType();
-    int64_t dim = dstType.getDimSize(0);
-    Value idx = op.getOperand(0);
-
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    Value trueVal = rewriter.create<vector::CreateMaskOp>(
-        loc, lowType, op.getOperands().drop_front());
-    Value falseVal = rewriter.create<arith::ConstantOp>(
-        loc, lowType, rewriter.getZeroAttr(lowType));
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, dstType, rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0; d < dim; d++) {
-      Value bnd =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
-      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                                                 bnd, idx);
-      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
-      auto pos = rewriter.getI64ArrayAttr(d);
-      result =
-          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
-      return failure();
-
-    auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
-    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
-      desc = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vec, desc,
-          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
-    }
-    rewriter.replaceOp(op, desc);
-    return success();
-  }
-};
-
-/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 1-D back
-/// into 2-D vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
-      return failure();
-
-    auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
-    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
-          /*sizes=*/mostMinorVectorSize,
-          /*strides=*/1);
-      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
-    }
-    rewriter.replaceOp(op, desc);
-    return success();
-  }
-};
-
-// We typically should not lower general shape cast operations into data
-// movement instructions, since the assumption is that these casts are
-// optimized away during progressive lowering. For completeness, however,
-// we fall back to a reference implementation that moves all elements
-// into the right place if we get here.
-class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-
-    // Special case 2D/1D lowerings with better implementations.
-    // TODO: make this ND/1D to allow generic ND->1D->MD.
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
-      return failure();
-
-    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
-    // extract/insert chains.
-    // TODO: consider evolving the semantics to only allow 1D source or dest and
-    // drop this potentially very expensive lowering.
-    // Compute number of elements involved in the reshape.
-    int64_t numElts = 1;
-    for (int64_t r = 0; r < srcRank; r++)
-      numElts *= sourceVectorType.getDimSize(r);
-    // Replace with data movement operations:
-    //    x[0,0,0] = y[0,0]
-    //    x[0,0,1] = y[0,1]
-    //    x[0,1,0] = y[0,2]
-    // etc., incrementing the two index vectors "row-major"
-    // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank);
-    SmallVector<int64_t> resIdx(resRank);
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    for (int64_t i = 0; i < numElts; i++) {
-      if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, srcRank - 1);
-        incIdx(resIdx, resultVectorType, resRank - 1);
-      }
-      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
-private:
-  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
-    assert(0 <= r && r < tp.getRank());
-    if (++idx[r] == tp.getDimSize(r)) {
-      idx[r] = 0;
-      incIdx(idx, tp, r - 1);
-    }
-  }
-};
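
The generic path walks both index vectors with a row-major "odometer" increment. A standalone analogue (illustration only):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Starting from [0, 0] with shape [2, 3], successive calls visit
    // [0,1] [0,2] [1,0] [1,1] [1,2] -- the row-major order used above.
    static void incIdx(std::vector<int64_t> &idx,
                       const std::vector<int64_t> &shape, int64_t r) {
      assert(0 <= r && r < static_cast<int64_t>(shape.size()));
      if (++idx[r] == shape[r]) {
        idx[r] = 0;
        incIdx(idx, shape, r - 1);
      }
    }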
-
 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
 /// Ex:
 /// ```
@@ -1425,967 +550,6 @@ struct ReorderElementwiseOpsOnTranspose final
   }
 };
 
-} // namespace
-
-/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
-/// arith::AddFOp, using operands `x` and `y`.
-static Value createAdd(Location loc, Value x, Value y, bool isInt,
-                       PatternRewriter &rewriter) {
-  if (isInt)
-    return rewriter.create<arith::AddIOp>(loc, x, y);
-  return rewriter.create<arith::AddFOp>(loc, x, y);
-}
-
-/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
-/// arith::MulFOp, using operands `x` and `y`.
-static Value createMul(Location loc, Value x, Value y, bool isInt,
-                       PatternRewriter &rewriter) {
-  if (isInt)
-    return rewriter.create<arith::MulIOp>(loc, x, y);
-  return rewriter.create<arith::MulFOp>(loc, x, y);
-}
-
-namespace mlir {
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-///    %mta = maybe_transpose
-///    %mtb = maybe_transpose
-///    %flattened_a = vector.shape_cast %mta
-///    %flattened_b = vector.shape_cast %mtb
-///    %flattened_d = vector.matmul %flattened_a, %flattened_b
-///    %mtd = vector.shape_cast %flattened_d
-///    %d = maybe_untranspose %mtd
-///    %e = add %c, %d
-/// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
-///
-/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
-/// vector.transpose operations are inserted if the vector.contract op is not a
-/// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                                 PatternRewriter &rew) const {
-  // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
-    return failure();
-
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Matmul)
-    return failure();
-  if (failed(filter(op)))
-    return failure();
-
-  auto iteratorTypes = op.getIteratorTypes().getValue();
-  if (!isParallelIterator(iteratorTypes[0]) ||
-      !isParallelIterator(iteratorTypes[1]) ||
-      !isReductionIterator(iteratorTypes[2]))
-    return failure();
-
-  Type elementType = op.getLhsType().getElementType();
-  if (!elementType.isIntOrFloat())
-    return failure();
-
-  Type dstElementType = op.getType();
-  if (auto vecType = dstElementType.dyn_cast<VectorType>())
-    dstElementType = vecType.getElementType();
-  if (elementType != dstElementType)
-    return failure();
-
-  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
-  // Bail out if the contraction cannot be put in this form.
-  MLIRContext *ctx = op.getContext();
-  Location loc = op.getLoc();
-  AffineExpr m, n, k;
-  bindDims(rew.getContext(), m, n, k);
-  // LHS must be A(m, k) or A(k, m).
-  Value lhs = op.getLhs();
-  auto lhsMap = op.getIndexingMapsArray()[0];
-  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
-    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
-  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
-    return failure();
-
-  // RHS must be B(k, n) or B(n, k).
-  Value rhs = op.getRhs();
-  auto rhsMap = op.getIndexingMapsArray()[1];
-  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
-    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
-  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
-    return failure();
-
-  // At this point lhs and rhs are in row-major.
-  VectorType lhsType = lhs.getType().cast<VectorType>();
-  VectorType rhsType = rhs.getType().cast<VectorType>();
-  int64_t lhsRows = lhsType.getDimSize(0);
-  int64_t lhsColumns = lhsType.getDimSize(1);
-  int64_t rhsColumns = rhsType.getDimSize(1);
-
-  Type flattenedLHSType =
-      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
-
-  Type flattenedRHSType =
-      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
-
-  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
-                                           rhsColumns);
-  mul = rew.create<vector::ShapeCastOp>(
-      loc,
-      VectorType::get({lhsRows, rhsColumns},
-                      getElementTypeOrSelf(op.getAcc().getType())),
-      mul);
-
-  // ACC must be C(m, n) or C(n, m).
-  auto accMap = op.getIndexingMapsArray()[2];
-  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
-    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
-  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
-    llvm_unreachable("invalid contraction semantics");
-
-  Value res =
-      elementType.isa<IntegerType>()
-          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
-          : static_cast<Value>(
-                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
-
-  rew.replaceOp(op, res);
-  return success();
-}
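
For readers unfamiliar with vector.matmul, a scalar reference of the flattened product the rewrite emits (assuming row-major 1-D operands, as the surrounding shape_casts imply; illustration only -- the accumulator is added back separately by the rewrite):

    #include <vector>

    // c[i * cols + j] = sum over k of a[i * inner + k] * b[k * cols + j].
    static std::vector<float> matmulRowMajor(const std::vector<float> &a,
                                             const std::vector<float> &b,
                                             int rows, int inner, int cols) {
      std::vector<float> c(rows * cols, 0.0f);
      for (int i = 0; i < rows; ++i)
        for (int k = 0; k < inner; ++k)
          for (int j = 0; j < cols; ++j)
            c[i * cols + j] += a[i * inner + k] * b[k * cols + j];
      return c;
    }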
-
-namespace {
-
-/// Generate a vector implementation for matmat, matvec and tmatvec.
-/// This unrolls outer-products along the reduction dimension.
-struct UnrolledOuterProductGenerator
-    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
-  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
-      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
-        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
-        res(op.getAcc()), lhsType(op.getLhsType()) {
-    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-    if (maskableOp.isMasked())
-      mask = maskableOp.getMaskingOp().getMask();
-  }
-
-  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
-    if (!v)
-      return v;
-    return rewriter.create<vector::TransposeOp>(loc, v, perm);
-  }
-
-  Value promote(Value v, Type dstElementType) {
-    Type elementType = v.getType();
-    auto vecType = elementType.dyn_cast<VectorType>();
-    if (vecType)
-      elementType = vecType.getElementType();
-    if (elementType == dstElementType)
-      return v;
-    Type promotedType = dstElementType;
-    if (vecType)
-      promotedType = VectorType::get(vecType.getShape(), promotedType);
-    if (dstElementType.isa<FloatType>())
-      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
-    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
-  }
-
-  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
-                             std::optional<Value> maybeMask = std::nullopt) {
-    assert(reductionSize > 0);
-    // Incremental support for masking.
-    if (mask && !maybeMask.has_value())
-      return failure();
-
-    Type resElementType = res.getType().cast<VectorType>().getElementType();
-    for (int64_t k = 0; k < reductionSize; ++k) {
-      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
-      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-      extractA = promote(extractA, resElementType);
-      extractB = promote(extractB, resElementType);
-      Value extractMask;
-      if (maybeMask.has_value() && maybeMask.value())
-        extractMask =
-            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
-
-      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
-          loc, res.getType(), extractA, extractB, res, kind);
-      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
-    }
-    return res;
-  }
-
-  /// Two outer parallel, one inner reduction (matmat flavor).
-  FailureOr<Value> matmat() {
-    if (!iters({Par(), Par(), Red()}))
-      return failure();
-    // Set up the parallel/reduction structure in the right form.
-    AffineExpr m, n, k;
-    bindDims(rewriter.getContext(), m, n, k);
-    // Classical row-major matmul:  Just permute the lhs.
-    if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
-                       t(mask, {2, 0, 1}));
-    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    if (layout({{m, k}, {n, k}, {m, n}})) {
-      Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
-    }
-    // No need to permute anything.
-    if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
-    // Just permute the rhs.
-    if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
-    // Transposed output: swap RHS and LHS.
-    // Classical row-major matmul: permute the lhs.
-    if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
-    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
-    if (layout({{m, k}, {n, k}, {n, m}})) {
-      Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
-    }
-    if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
-    if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
-    return failure();
-  }
-
-  /// One outer parallel, one inner reduction (matvec flavor)
-  FailureOr<Value> matvec() {
-    if (!iters({Par(), Red()}))
-      return failure();
-    AffineExpr m, k;
-    bindDims(rewriter.getContext(), m, k);
-
-    // Case mat-vec: transpose.
-    if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
-    // Case mat-trans-vec: ready to go.
-    if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
-    // Case vec-mat: swap and transpose.
-    if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
-    // Case vec-mat-trans: swap and ready to go.
-    if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
-    return failure();
-  }
-
-  /// One outer reduction, one inner parallel (tmatvec flavor)
-  FailureOr<Value> tmatvec() {
-    if (!iters({Red(), Par()}))
-      return failure();
-    AffineExpr k, m;
-    bindDims(rewriter.getContext(), k, m);
-
-    // Case mat-vec: transpose.
-    if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
-    // Case mat-trans-vec: ready to go.
-    if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
-    // Case vec-mat: swap and transpose.
-    if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
-    // Case vec-mat-trans: swap and ready to go.
-    if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
-    return failure();
-  }
-
-private:
-  vector::CombiningKind kind;
-  Value lhs, rhs, res, mask;
-  VectorType lhsType;
-};
-} // namespace
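
Each unrolled step in outerProd above accumulates one rank-1 update into the running result. Its scalar semantics (illustration only):

    #include <cstddef>
    #include <vector>

    // res (m x n, row-major) += outer product of colA (m elems) and rowB (n elems).
    static void outerProductStep(std::vector<float> &res,
                                 const std::vector<float> &colA,
                                 const std::vector<float> &rowB) {
      const std::size_t m = colA.size(), n = rowB.size();
      for (std::size_t i = 0; i < m; ++i)
        for (std::size_t j = 0; j < n; ++j)
          res[i * n + j] += colA[i] * rowB[j];
    }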
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-///    %at = vector.transpose %a, [1, 0]
-///    %bRow0 = vector.extract %b[0]
-///    %atRow0 = vector.extract %at[0]
-///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
-///    ...
-///    %bRowK = vector.extract %b[K]
-///    %atRowK = vector.extract %at[K]
-///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
-/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::OuterProduct)
-    return failure();
-
-  if (failed(filter(op)))
-    return failure();
-
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-  Operation *rootOp;
-  if (maskableOp.isMasked()) {
-    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-    rootOp = maskableOp.getMaskingOp();
-  } else {
-    rootOp = op;
-  }
-
-  UnrolledOuterProductGenerator e(rewriter, op);
-  FailureOr<Value> matmatRes = e.matmat();
-  if (succeeded(matmatRes)) {
-    rewriter.replaceOp(rootOp, *matmatRes);
-    return success();
-  }
-  FailureOr<Value> matvecRes = e.matvec();
-  if (succeeded(matvecRes)) {
-    rewriter.replaceOp(rootOp, *matvecRes);
-    return success();
-  }
-  FailureOr<Value> tmatvecRes = e.tmatvec();
-  if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(rootOp, *tmatvecRes);
-    return success();
-  }
-
-  return failure();
-}
-
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
-                                            PatternRewriter &rewriter) const {
-  // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
-    return failure();
-
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-
-  if (failed(filter(op)))
-    return failure();
-
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Dot)
-    return failure();
-
-  auto iteratorTypes = op.getIteratorTypes().getValue();
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  Location loc = op.getLoc();
-  Value lhs = op.getLhs(), rhs = op.getRhs();
-
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m, n, k;
-  bindDims(rewriter.getContext(), m, n, k);
-  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
-  //
-  // In the following we wish to make the reduction dimension innermost so we
-  // can load vectors and just fmul + reduce into a scalar.
-  //
-  if (isParallelIterator(iteratorTypes[0]) &&
-      isParallelIterator(iteratorTypes[1]) &&
-      isReductionIterator(iteratorTypes[2])) {
-    //
-    // Two outer parallel, one inner reduction (matmat flavor).
-    //
-    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
-      // No need to permute anything.
-    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      Value tmp = lhs;
-      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-      rhs = tmp;
-    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-      std::swap(lhs, rhs);
-    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-      Value tmp = lhs;
-      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
-    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = tmp;
-    } else {
-      return failure();
-    }
-  } else if (isParallelIterator(iteratorTypes[0]) &&
-             isReductionIterator(iteratorTypes[1])) {
-    //
-    // One outer parallel, one inner reduction (matvec flavor)
-    //
-    if (maps == infer({{m, n}, {n}, {m}})) {
-      // No need to permute anything.
-    } else if (maps == infer({{n, m}, {n}, {m}})) {
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n}, {m, n}, {m}})) {
-      std::swap(lhs, rhs);
-    } else if (maps == infer({{n}, {n, m}, {m}})) {
-      std::swap(lhs, rhs);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else {
-      return failure();
-    }
-  } else {
-    return failure();
-  }
-
-  VectorType dstType = op.getResultType().cast<VectorType>();
-  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
-         "Expected dst type of rank 1 or 2");
-
-  unsigned rank = dstType.getRank();
-  unsigned dstRows = dstType.getShape()[0];
-  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
-
-  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
-  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
-                                                 rewriter.getZeroAttr(dstType));
-  bool isInt = dstType.getElementType().isa<IntegerType>();
-  for (unsigned r = 0; r < dstRows; ++r) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
-    for (unsigned c = 0; c < dstColumns; ++c) {
-      Value b = rank == 1
-                    ? rhs
-                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
-      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
-      Value reduced = rewriter.create<vector::ReductionOp>(
-          op.getLoc(), vector::CombiningKind::ADD, m);
-
-      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
-                                              : SmallVector<int64_t, 2>{r, c};
-      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
-    }
-  }
-  if (auto acc = op.getAcc())
-    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
-  rewriter.replaceOp(op, res);
-  return success();
-}
-
-/// Progressive lowering of ContractionOp.
-/// One:
-///   %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-///   %a = vector.contract with one less free/batch dimension
-///   %b = vector.contract with one less free/batch dimension
-///   ..
-///   %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to DOT or when other contraction patterns fail.
-//
-// TODO: break down into transpose/reshape/cast ops when they become
-// available to avoid code duplication.
-// TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                       PatternRewriter &rewriter) const {
-  // TODO: Remove native masks from contraction op?
-  if (!op.getMasks().empty())
-    return failure();
-
-  if (failed(filter(op)))
-    return failure();
-
-  // TODO: support mixed mode contract lowering.
-  if (op.getLhsType().getElementType() !=
-          getElementTypeOrSelf(op.getAccType()) ||
-      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
-    return failure();
-
-  // TODO: the code below assumes the default contraction; make sure it
-  // supports other kinds before enabling this lowering.
-  if (op.getKind() != vector::CombiningKind::ADD) {
-    return rewriter.notifyMatchFailure(
-        op, "contractions other than 'add' not supported");
-  }
-
-  // TODO: implement benefits, cost models.
-  MLIRContext *ctx = op.getContext();
-  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
-  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
-    return success();
-  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
-  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
-    return success();
-  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
-  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
-    return success();
-  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
-  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
-    return success();
-
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  Operation *rootOp = op;
-  Value mask;
-  if (op.isMasked()) {
-    rewriter.setInsertionPoint(op.getMaskingOp());
-    rootOp = op.getMaskingOp();
-    mask = op.getMaskingOp().getMask();
-  }
-
-  // Find first batch dimension in LHS/RHS, and lower when found.
-  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
-  if (!batchDimMap.empty()) {
-    int64_t lhsIndex = batchDimMap[0].first;
-    int64_t rhsIndex = batchDimMap[0].second;
-    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
-    if (failed(newOp))
-      return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
-  }
-
-  // Collect contracting dimensions.
-  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
-      op.getContractingDimMap();
-  DenseSet<int64_t> lhsContractingDimSet;
-  DenseSet<int64_t> rhsContractingDimSet;
-  for (auto &dimPair : contractingDimMap) {
-    lhsContractingDimSet.insert(dimPair.first);
-    rhsContractingDimSet.insert(dimPair.second);
-  }
-
-  // Find first free dimension in LHS, and lower when found.
-  VectorType lhsType = op.getLhsType();
-  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
-    if (lhsContractingDimSet.count(lhsIndex) == 0) {
-      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
-      if (failed(newOp))
-        return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
-    }
-  }
-
-  // Find first free dimension in RHS, and lower when found.
-  VectorType rhsType = op.getRhsType();
-  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
-    if (rhsContractingDimSet.count(rhsIndex) == 0) {
-      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
-      if (failed(newOp))
-        return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
-    }
-  }
-
-  // Lower the first remaining reduction dimension.
-  if (!contractingDimMap.empty()) {
-    auto newOp = lowerReduction(rewriter, op, mask);
-    if (failed(newOp))
-      return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
-  }
-
-  return failure();
-}
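
To make the recursion above concrete, here is a minimal hand-written sketch (value names and shapes are illustrative; the rank-reduced indexing maps of the inner contract are elided) of one `lowerParallel` step on a matvec-shaped contraction:

```
// One parallel dim i, one reduction dim j: %x[i] += %lhs[i, j] * %rhs[j].
%x = vector.contract {
       indexing_maps = [affine_map<(i, j) -> (i, j)>,
                        affine_map<(i, j) -> (j)>,
                        affine_map<(i, j) -> (i)>],
       iterator_types = ["parallel", "reduction"],
       kind = #vector.kind<add>}
     %lhs, %rhs, %acc : vector<2x3xf32>, vector<3xf32> into vector<2xf32>

// After unrolling dim i, each slice is a pure reduction that the base
// case lowers to a dot product:
%cst = arith.constant dense<0.0> : vector<2xf32>
%l0  = vector.extract %lhs[0] : vector<2x3xf32>
%a0  = vector.extract %acc[0] : vector<2xf32>
%c0  = vector.contract {...} %l0, %rhs, %a0
         : vector<3xf32>, vector<3xf32> into f32
%r0  = vector.insert %c0, %cst [0] : f32 into vector<2xf32>
// ... and likewise for i = 1, accumulating into %r0 ...
```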
-
-// Lower one parallel dimension.
-// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
-// TODO: consider reusing existing contract unrolling
-FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
-                                                      vector::ContractionOp op,
-                                                      int64_t lhsIndex,
-                                                      int64_t rhsIndex,
-                                                      Value mask) const {
-  VectorType lhsType = op.getLhsType();
-  VectorType rhsType = op.getRhsType();
-  VectorType resType = op.getResultType().cast<VectorType>();
-  // Find the iterator type index and result index.
-  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
-  int64_t iterIndex = -1;
-  int64_t dimSize = -1;
-  if (lhsIndex >= 0) {
-    iterIndex = iMap[0].getDimPosition(lhsIndex);
-    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
-      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
-             << " to map to the same dimension";
-      });
-    dimSize = lhsType.getDimSize(lhsIndex);
-  } else if (rhsIndex >= 0) {
-    iterIndex = iMap[1].getDimPosition(rhsIndex);
-    dimSize = rhsType.getDimSize(rhsIndex);
-  }
-  if (iterIndex < 0)
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << "expected either lhsIndex=" << lhsIndex
-           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
-    });
-  // value_or(-1) means that we tolerate a dimension not appearing
-  // in the result map. That can't happen for actual parallel iterators, but
-  // the caller ContractionOpLowering::matchAndRewrite is currently calling
-  // lowerParallel also for the case of unit-size reduction dims appearing only
-  // on one of LHS or RHS, not both. At the moment, such cases are created by
-  // CastAwayContractionLeadingOneDim, so we need to either support that or
-  // modify that pattern.
-  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
-  if (resIndex == -1 && dimSize != 1)
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << "expected the dimension for iterIndex=" << iterIndex
-           << " to either appear in the result map, or to be a unit dimension";
-    });
-
-  // Construct new iterator types and affine map array attribute.
-  std::array<AffineMap, 3> lowIndexingMaps = {
-      adjustMap(iMap[0], iterIndex, rewriter),
-      adjustMap(iMap[1], iterIndex, rewriter),
-      adjustMap(iMap[2], iterIndex, rewriter)};
-  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
-  auto lowIter =
-      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
-  // Unroll into a series of lower dimensional vector.contract ops.
-  Location loc = op.getLoc();
-  Value result = rewriter.create<arith::ConstantOp>(
-      loc, resType, rewriter.getZeroAttr(resType));
-
-  for (int64_t d = 0; d < dimSize; ++d) {
-    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
-    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
-    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
-
-    Value lowMask;
-    if (mask)
-      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
-                            iterIndex, d, rewriter);
-
-    Operation *lowContract = rewriter.create<vector::ContractionOp>(
-        loc, lhs, rhs, acc, lowAffine, lowIter);
-    lowContract = maskOperation(rewriter, lowContract, lowMask);
-    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
-                          resIndex, d, rewriter);
-  }
-  return result;
-}
-
-// Lower one reduction dimension.
-FailureOr<Value> ContractionOpLowering::lowerReduction(
-    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
-  auto loc = op.getLoc();
-  VectorType lhsType = op.getLhsType();
-  VectorType rhsType = op.getRhsType();
-  Type resType = op.getResultType();
-  if (resType.isa<VectorType>())
-    return rewriter.notifyMatchFailure(op,
-                                       "did not expect a VectorType result");
-  bool isInt = resType.isa<IntegerType>();
-  // Use iterator index 0.
-  int64_t iterIndex = 0;
-  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
-  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
-  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
-  if (!lookupLhs.has_value())
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
-    });
-  if (!lookupRhs.has_value())
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
-    });
-  int64_t lhsIndex = *lookupLhs;
-  int64_t rhsIndex = *lookupRhs;
-  int64_t dimSize = lhsType.getDimSize(lhsIndex);
-  if (dimSize != rhsType.getDimSize(rhsIndex))
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << "expect LHS dimension " << lhsIndex
-           << " to have the same size as RHS dimension " << rhsIndex;
-    });
-  // Base case.
-  if (lhsType.getRank() == 1) {
-    if (rhsType.getRank() != 1)
-      return rewriter.notifyMatchFailure(
-          op, "When LHS has rank 1, expected also RHS to have rank 1");
-    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
-    auto kind = vector::CombiningKind::ADD;
-
-    Value acc = op.getAcc();
-    Operation *reductionOp =
-        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
-            : rewriter.create<vector::ReductionOp>(loc, kind, m);
-    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
-  }
-  // Construct new iterator types and affine map array attribute.
-  std::array<AffineMap, 3> lowIndexingMaps = {
-      adjustMap(iMap[0], iterIndex, rewriter),
-      adjustMap(iMap[1], iterIndex, rewriter),
-      adjustMap(iMap[2], iterIndex, rewriter)};
-  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
-  auto lowIter =
-      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
-  // Unroll into a series of lower dimensional vector.contract ops.
-  // By feeding the initial accumulator into the first contraction,
-  // and the result of each contraction into the next, eventually
-  // the sum of all reductions is computed.
-  Value result = op.getAcc();
-  for (int64_t d = 0; d < dimSize; ++d) {
-    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
-    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
-    Value newMask;
-    if (mask)
-      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
-                            iterIndex, d, rewriter);
-
-    Operation *newContract = rewriter.create<vector::ContractionOp>(
-        loc, lhs, rhs, result, lowAffine, lowIter);
-    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
-  }
-  return result;
-}
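
At the rank-1 base case, the lowering above amounts to an elementwise multiply followed by a single `vector.reduction`; schematically (illustrative types):

```
// Dot product: vector<4xf32> . vector<4xf32> -> f32, fused with %acc.
%m = arith.mulf %lhs, %rhs : vector<4xf32>
%x = vector.reduction <add>, %m, %acc : vector<4xf32> into f32
```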
-
-} // namespace mlir
-
-/// Progressive lowering of transfer_read. This pattern supports lowering of
-/// `vector.transfer_read` to a combination of `vector.load` and
-/// `vector.broadcast` if all of the following hold:
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-///   result type.
-/// - The permutation map doesn't perform permutation (broadcasting is allowed).
-struct TransferReadToVectorLoadLowering
-    : public OpRewritePattern<vector::TransferReadOp> {
-  TransferReadToVectorLoadLowering(MLIRContext *context,
-                                   std::optional<unsigned> maxRank,
-                                   PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
-        maxTransferRank(maxRank) {}
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp read,
-                                PatternRewriter &rewriter) const override {
-    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
-      return failure();
-
-    SmallVector<unsigned> broadcastedDims;
-    // Permutations are handled by VectorToSCF or
-    // populateVectorTransferPermutationMapLoweringPatterns.
-    // We let the 0-d corner case pass through as it is supported.
-    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
-            &broadcastedDims))
-      return failure();
-
-    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
-    if (!memRefType)
-      return failure();
-
-    // Non-unit strides are handled by VectorToSCF.
-    if (!vector::isLastMemrefDimUnitStride(memRefType))
-      return failure();
-
-    // If there is broadcasting involved then we first load the unbroadcasted
-    // vector, and then broadcast it with `vector.broadcast`.
-    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
-    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
-                                                  vectorShape.end());
-    for (unsigned i : broadcastedDims)
-      unbroadcastedVectorShape[i] = 1;
-    VectorType unbroadcastedVectorType = VectorType::get(
-        unbroadcastedVectorShape, read.getVectorType().getElementType());
-
-    // `vector.load` supports vector types as memref's elements only when the
-    // resulting vector type is the same as the element type.
-    auto memrefElTy = memRefType.getElementType();
-    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
-      return failure();
-
-    // Otherwise, element types of the memref and the vector must match.
-    if (!memrefElTy.isa<VectorType>() &&
-        memrefElTy != read.getVectorType().getElementType())
-      return failure();
-
-    // Out-of-bounds dims are handled by MaterializeTransferMask.
-    if (read.hasOutOfBoundsDim())
-      return failure();
-
-    // Create vector load op.
-    Operation *loadOp;
-    if (read.getMask()) {
-      Value fill = rewriter.create<vector::SplatOp>(
-          read.getLoc(), unbroadcastedVectorType, read.getPadding());
-      loadOp = rewriter.create<vector::MaskedLoadOp>(
-          read.getLoc(), unbroadcastedVectorType, read.getSource(),
-          read.getIndices(), read.getMask(), fill);
-    } else {
-      loadOp = rewriter.create<vector::LoadOp>(
-          read.getLoc(), unbroadcastedVectorType, read.getSource(),
-          read.getIndices());
-    }
-
-    // Insert a broadcasting op if required.
-    if (!broadcastedDims.empty()) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          read, read.getVectorType(), loadOp->getResult(0));
-    } else {
-      rewriter.replaceOp(read, loadOp->getResult(0));
-    }
-
-    return success();
-  }
-
-  std::optional<unsigned> maxTransferRank;
-};
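
A rough before/after for this pattern, assuming an in-bounds, minor-identity, unmasked transfer (shapes illustrative):

```
%v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
       : memref<8x16xf32>, vector<16xf32>

// becomes

%v = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>
```

When the permutation map broadcasts some dimensions, the load instead produces the unbroadcasted vector and a trailing `vector.broadcast` restores the requested shape.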
-
-/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
-// TODO: we shouldn't cross the vector/scalar domains just for this
-// but atm we lack the infra to avoid it. Possible solutions include:
-// - go directly to LLVM + bitcast
-// - introduce a bitcast op and likely a new pointer dialect
-// - let memref.load/store additionally support the 0-d vector case
-// There are still deeper data layout issues lingering even in this
-// trivial case (for architectures for which this matters).
-struct VectorLoadToMemrefLoadLowering
-    : public OpRewritePattern<vector::LoadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = loadOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return failure();
-    auto memrefLoad = rewriter.create<memref::LoadOp>(
-        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
-                                                     memrefLoad);
-    return success();
-  }
-};
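
Schematically, for a single-element load (illustrative types):

```
%v = vector.load %mem[%i] : memref<100xf32>, vector<1xf32>

// becomes

%s = memref.load %mem[%i] : memref<100xf32>
%v = vector.broadcast %s : f32 to vector<1xf32>
```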
-
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
-struct VectorStoreToMemrefStoreLowering
-    : public OpRewritePattern<vector::StoreOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = storeOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return failure();
-    Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unify once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
-
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
-    return success();
-  }
-};
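
And the store-side counterpart, again as a sketch:

```
vector.store %v, %mem[%i] : memref<100xf32>, vector<1xf32>

// becomes

%s = vector.extract %v[0] : vector<1xf32>
memref.store %s, %mem[%i] : memref<100xf32>
```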
-
-/// Progressive lowering of transfer_write. This pattern supports lowering of
-/// `vector.transfer_write` to `vector.store` if all of the following hold:
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-///   type of the written value.
-/// - The permutation map is the minor identity map (neither permutation nor
-///   broadcasting is allowed).
-struct TransferWriteToVectorStoreLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  TransferWriteToVectorStoreLowering(MLIRContext *context,
-                                     std::optional<unsigned> maxRank,
-                                     PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
-        maxTransferRank(maxRank) {}
-
-  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
-                                PatternRewriter &rewriter) const override {
-    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "rank exceeds maxTransferRank: " << write;
-      });
-
-    // Permutations are handled by VectorToSCF or
-    // populateVectorTransferPermutationMapLoweringPatterns.
-    if ( // pass-through for the 0-d corner case.
-        !write.getPermutationMap().isMinorIdentity())
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "permutation map is not minor identity: " << write;
-      });
-
-    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
-    if (!memRefType)
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "not a memref type: " << write;
-      });
-
-    // Non-unit strides are handled by VectorToSCF.
-    if (!vector::isLastMemrefDimUnitStride(memRefType))
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "most minor stride is not 1: " << write;
-      });
-
-    // `vector.store` supports vector types as memref's elements only when the
-    // type of the vector value being written is the same as the element type.
-    auto memrefElTy = memRefType.getElementType();
-    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "elemental type mismatch: " << write;
-      });
-
-    // Otherwise, element types of the memref and the vector must match.
-    if (!memrefElTy.isa<VectorType>() &&
-        memrefElTy != write.getVectorType().getElementType())
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "elemental type mismatch: " << write;
-      });
-
-    // Out-of-bounds dims are handled by MaterializeTransferMask.
-    if (write.hasOutOfBoundsDim())
-      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
-        diag << "out of bounds dim: " << write;
-      });
-    if (write.getMask()) {
-      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
-          write, write.getSource(), write.getIndices(), write.getMask(),
-          write.getVector());
-    } else {
-      rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          write, write.getVector(), write.getSource(), write.getIndices());
-    }
-    return success();
-  }
-
-  std::optional<unsigned> maxTransferRank;
-};
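
The corresponding write-side sketch, under the same assumptions (in-bounds, minor-identity map, no mask; shapes illustrative):

```
vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true]}
  : vector<16xf32>, memref<8x16xf32>

// becomes

vector.store %v, %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>
```

With a mask present, a `vector.maskedstore` is emitted instead.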
-
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -2863,202 +1027,6 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
-namespace {
-
-/// This function checks to see if the vector combining kind
-/// is consistent with the integer or float element type.
-static bool isValidKind(bool isInt, vector::CombiningKind kind) {
-  using vector::CombiningKind;
-  enum class KindType { FLOAT, INT, INVALID };
-  KindType type{KindType::INVALID};
-  switch (kind) {
-  case CombiningKind::MINF:
-  case CombiningKind::MAXF:
-    type = KindType::FLOAT;
-    break;
-  case CombiningKind::MINUI:
-  case CombiningKind::MINSI:
-  case CombiningKind::MAXUI:
-  case CombiningKind::MAXSI:
-  case CombiningKind::AND:
-  case CombiningKind::OR:
-  case CombiningKind::XOR:
-    type = KindType::INT;
-    break;
-  case CombiningKind::ADD:
-  case CombiningKind::MUL:
-    type = isInt ? KindType::INT : KindType::FLOAT;
-    break;
-  }
-  bool isValidIntKind = (type == KindType::INT) && isInt;
-  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
-  return (isValidIntKind || isValidFloatKind);
-}
-
-/// This function constructs the appropriate integer or float
-/// operation given the vector combining kind and operands. The
-/// supported int operations are: add, mul, min (signed/unsigned),
-/// max (signed/unsigned), and, or, xor. The supported float
-/// operations are: add, mul, min, and max.
-static Value genOperator(Location loc, Value x, Value y,
-                         vector::CombiningKind kind,
-                         PatternRewriter &rewriter) {
-  using vector::CombiningKind;
-
-  auto elType = x.getType().cast<VectorType>().getElementType();
-  bool isInt = elType.isIntOrIndex();
-
-  Value combinedResult{nullptr};
-  switch (kind) {
-  case CombiningKind::ADD:
-    if (isInt)
-      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
-    else
-      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
-    break;
-  case CombiningKind::MUL:
-    if (isInt)
-      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
-    else
-      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
-    break;
-  case CombiningKind::MINUI:
-    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
-    break;
-  case CombiningKind::MINSI:
-    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXUI:
-    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXSI:
-    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
-    break;
-  case CombiningKind::AND:
-    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
-    break;
-  case CombiningKind::OR:
-    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
-    break;
-  case CombiningKind::XOR:
-    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
-    break;
-  case CombiningKind::MINF:
-    combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXF:
-    combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
-    break;
-  }
-  return combinedResult;
-}
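
For example, the same ADD kind maps to different arith ops depending on the element type (illustrative operands):

```
%ri = arith.addi %xi, %yi : vector<2x1xi32>
%rf = arith.addf %xf, %yf : vector<2x1xf32>
```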
-
-/// Convert vector.scan op into arith ops and
-/// vector.insert_strided_slice/extract_strided_slice
-///
-/// Ex:
-/// ```
-///   %0:2 = vector.scan <mul>, %arg0, %arg1
-///     {inclusive = true, reduction_dim = 1} :
-///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
-/// ```
-/// Gets converted to:
-/// ```
-///   %cst = arith.constant dense<0> : vector<2x3xi32>
-///   %0 = vector.extract_strided_slice %arg0
-///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
-///     : vector<2x3xi32> to vector<2x1xi32>
-///   %1 = vector.insert_strided_slice %0, %cst
-///     {offsets = [0, 0], strides = [1, 1]}
-///     : vector<2x1xi32> into vector<2x3xi32>
-///   %2 = vector.extract_strided_slice %arg0
-///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
-///     : vector<2x3xi32> to vector<2x1xi32>
-///   %3 = arith.muli %0, %2 : vector<2x1xi32>
-///   %4 = vector.insert_strided_slice %3, %1
-///     {offsets = [0, 1], strides = [1, 1]}
-///     : vector<2x1xi32> into vector<2x3xi32>
-///   %5 = vector.extract_strided_slice %arg0
-///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
-///     : vector<2x3xi32> to vector<2x1xi32>
-///   %6 = arith.muli %3, %5 : vector<2x1xi32>
-///   %7 = vector.insert_strided_slice %6, %4
-///     {offsets = [0, 2], strides = [1, 1]}
-///     : vector<2x1xi32> into vector<2x3xi32>
-///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
-///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
-/// ```
-struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
-                                PatternRewriter &rewriter) const override {
-    auto loc = scanOp.getLoc();
-    VectorType destType = scanOp.getDestType();
-    ArrayRef<int64_t> destShape = destType.getShape();
-    auto elType = destType.getElementType();
-    bool isInt = elType.isIntOrIndex();
-    if (!isValidKind(isInt, scanOp.getKind()))
-      return failure();
-
-    VectorType resType = VectorType::get(destShape, elType);
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resType, rewriter.getZeroAttr(resType));
-    int64_t reductionDim = scanOp.getReductionDim();
-    bool inclusive = scanOp.getInclusive();
-    int64_t destRank = destType.getRank();
-    VectorType initialValueType = scanOp.getInitialValueType();
-    int64_t initialValueRank = initialValueType.getRank();
-
-    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
-    reductionShape[reductionDim] = 1;
-    VectorType reductionType = VectorType::get(reductionShape, elType);
-    SmallVector<int64_t> offsets(destRank, 0);
-    SmallVector<int64_t> strides(destRank, 1);
-    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
-    sizes[reductionDim] = 1;
-    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
-    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
-
-    Value lastOutput, lastInput;
-    for (int i = 0; i < destShape[reductionDim]; i++) {
-      offsets[reductionDim] = i;
-      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
-      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
-          scanStrides);
-      Value output;
-      if (i == 0) {
-        if (inclusive) {
-          output = input;
-        } else {
-          if (initialValueRank == 0) {
-            // ShapeCastOp cannot handle 0-D vectors
-            output = rewriter.create<vector::BroadcastOp>(
-                loc, input.getType(), scanOp.getInitialValue());
-          } else {
-            output = rewriter.create<vector::ShapeCastOp>(
-                loc, input.getType(), scanOp.getInitialValue());
-          }
-        }
-      } else {
-        Value y = inclusive ? input : lastInput;
-        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
-        assert(output != nullptr);
-      }
-      result = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, output, result, offsets, strides);
-      lastOutput = output;
-      lastInput = input;
-    }
-
-    Value reduction;
-    if (initialValueRank == 0) {
-      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
-      reduction =
-          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
-    } else {
-      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
-                                                       lastOutput);
-    }
-
-    rewriter.replaceOp(scanOp, {result, reduction});
-    return success();
-  }
-};
-
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
 /// with the RHS transposed) lowering.
@@ -3157,132 +1125,6 @@ struct CanonicalizeContractMatmulToMMT final
   FilterConstraintType filter;
 };
 
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
-/// outermost dimension. For example:
-/// ```
-/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
-///        ... into vector<2x3xf32>
-///
-/// ==>
-///
-/// %0   = arith.constant dense<0.0> : vector<2x3xf32>
-/// %g0  = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
-/// %1   = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
-/// %g1  = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
-/// %g   = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
-/// ```
-///
-/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::GatherOp op,
-                                PatternRewriter &rewriter) const override {
-    VectorType resultTy = op.getType();
-    if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already flat");
-
-    Location loc = op.getLoc();
-    Value indexVec = op.getIndexVec();
-    Value maskVec = op.getMask();
-    Value passThruVec = op.getPassThru();
-
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultTy, rewriter.getZeroAttr(resultTy));
-
-    Type subTy = VectorType::get(resultTy.getShape().drop_front(),
-                                 resultTy.getElementType());
-
-    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-      int64_t thisIdx[1] = {i};
-
-      Value indexSubVec =
-          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
-      Value maskSubVec =
-          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
-      Value passThruSubVec =
-          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
-      Value subGather = rewriter.create<vector::GatherOp>(
-          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
-          passThruSubVec);
-      result =
-          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
-/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
-/// loads/extracts are made conditional using `scf.if` ops.
-struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::GatherOp op,
-                                PatternRewriter &rewriter) const override {
-    VectorType resultTy = op.getType();
-    if (resultTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(op, "unsupported rank");
-
-    Location loc = op.getLoc();
-    Type elemTy = resultTy.getElementType();
-    // Vector type with a single element. Used to generate `vector.loads`.
-    VectorType elemVecTy = VectorType::get({1}, elemTy);
-
-    Value condMask = op.getMask();
-    Value base = op.getBase();
-    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
-        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
-        op.getIndexVec());
-    auto baseOffsets = llvm::to_vector(op.getIndices());
-    Value lastBaseOffset = baseOffsets.back();
-
-    Value result = op.getPassThru();
-
-    // Emit a conditional access for each vector element.
-    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
-      int64_t thisIdx[1] = {i};
-      Value condition =
-          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
-      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
-      baseOffsets.back() =
-          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
-
-      auto loadBuilder = [&](OpBuilder &b, Location loc) {
-        Value extracted;
-        if (isa<MemRefType>(base.getType())) {
-          // `vector.load` does not support scalar result; emit a vector load
-          // and extract the single result instead.
-          Value load =
-              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
-          int64_t zeroIdx[1] = {0};
-          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
-        } else {
-          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
-        }
-
-        Value newResult =
-            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
-        b.create<scf::YieldOp>(loc, newResult);
-      };
-      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, result);
-      };
-
-      result =
-          rewriter
-              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
-                                 /*elseBuilder=*/passThruBuilder)
-              .getResult(0);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
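
Schematically, one lane of the generated code looks as follows (illustrative; memref base, with the lane offset %off0 precomputed by `arith.addi`):

```
%c0 = vector.extract %mask[0] : vector<2xi1>
%r0 = scf.if %c0 -> (vector<2xf32>) {
  %l = vector.load %base[%off0] : memref<?xf32>, vector<1xf32>
  %e = vector.extract %l[0] : vector<1xf32>
  %n = vector.insert %e, %pass [0] : f32 into vector<2xf32>
  scf.yield %n : vector<2xf32>
} else {
  scf.yield %pass : vector<2xf32>
}
```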
-
 } // namespace
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
@@ -3307,33 +1149,6 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
                                                      benefit);
 }
 
-void mlir::vector::populateVectorBroadcastLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorMaskOpLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
-      patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorShapeCastLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
-      patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit) {
-  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
-  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
-               ContractionOpToOuterProductOpLowering>(
-      options, patterns.getContext(), benefit);
-}
-
 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
     RewritePatternSet &patterns,
     std::function<LogicalResult(vector::ContractionOp)> constraint,
@@ -3342,13 +1157,6 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
                                                 std::move(constraint));
 }
 
-void mlir::vector::populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit) {
-  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
-      options, patterns.getContext(), benefit);
-}
-
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
@@ -3363,28 +1171,6 @@ void mlir::vector::
   patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
 }
 
-void mlir::vector::populateVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
-    PatternBenefit benefit) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext(),
-                                                   maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorScanLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorGatherLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
-}
-
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f79ca2259fa38..7a4f9cf5e5101 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -148,8 +149,9 @@ struct TestVectorContractionLowering
     if (lowerToOuterProduct) {
       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
       VectorTransformsOptions options{lowering};
-      patterns.add<ContractionOpToOuterProductOpLowering>(options,
-                                                          &getContext());
+      populateVectorContractLoweringPatterns(
+          patterns, options, /*benefit=*/1,
+          /*disableOuterProductLowering=*/true);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
       return;
     }
@@ -469,7 +471,7 @@ struct TestVectorTransferFullPartialSplitPatterns
       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
     else
       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
-    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
+    populateVectorTransferFullPartialPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8538c3db59dcd..f565030d63d9f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8539,6 +8539,7 @@ cc_library(
         ":TransformDialect",
         ":TransformDialectUtils",
         ":TransformUtils",
+        ":VectorTransforms",
         "//llvm:Support",
     ],
 )


        

