[Mlir-commits] [mlir] [Will be divided up] Vector linearization via strided ops (PR #142672)

James Newling llvmlistbot at llvm.org
Tue Jun 3 14:07:03 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/142672

This is the end point of a few changes I have made to vector linearization. I plan to split this into multiple PRs, but posting here as a reference. 

Changes/improvements:

- Make all patterns optional
- Option to lower to 1D strided ops (although 1D shuffle remains default)
- Support for scalar insert/extract


>From 6ad62b41cdc7350cc28593424c3d41e5733ed1ef Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 13:56:03 -0700
Subject: [PATCH] squash commits

---
 .../Vector/Transforms/VectorLinearize.h       | 354 +++++++
 .../Vector/Transforms/VectorRewritePatterns.h |  33 -
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  20 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 967 +++++++++++++-----
 .../linearize-insert-extract-preference.mlir  | 287 ++++++
 .../linearize-subject-to-bitwidth.mlir        |   0
 .../Vector/{ => linearize}/linearize.mlir     |   3 +-
 .../linearize/rank-reduce-strided-ops.mlir    | 135 +++
 mlir/test/lib/Dialect/Vector/CMakeLists.txt   |   1 +
 .../Dialect/Vector/TestVectorLinearize.cpp    | 248 +++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 159 ---
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 12 files changed, 1771 insertions(+), 438 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
 create mode 100644 mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
 rename mlir/test/Dialect/Vector/{ => linearize}/linearize-subject-to-bitwidth.mlir (100%)
 rename mlir/test/Dialect/Vector/{ => linearize}/linearize.mlir (99%)
 create mode 100644 mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
 create mode 100644 mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..6fc5bb13f9b07
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,354 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` with source and target materialization logic
+/// using shape_casts to/from 1D vectors.
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// This enum controls the patterns used for linearization of insert,
+/// insert_strided_slice, extract, and extract_strided_slice operations.
+enum class InsertExtractLinearizePreference {
+
+  /// The lowerings are
+  ///   insert,  insert_strided_slice  -> 1D shuffle
+  ///   extract, extract_strided_slice -> 1D shuffle
+  ///
+  /// Even 1D insert_strided_slice and extract_strided_slice are converted to 1D
+  /// shuffles. Insert and extract ops on scalar elements are not converted to
+  /// 1D shuffles.
+  Shuffle = 0,
+
+  /// The preferred lowerings are
+  ///   insert,  insert_strided_slice  -> 1D insert_strided_slice
+  ///   extract, extract_strided_slice -> 1D extract_strided_slice
+  ///
+  /// When these lowerings are not possible because the slices are not
+  /// contiguous, 1D shuffles are used.
+  Strided
+};
+
+/// Initialize `conversionTarget`, and `patterns` for linearization. Here
+/// linearization means converting a single operation with 1+ vector
+/// operand/result of rank>1, into a new single operation whose vector operands
+/// and results are all of rank<=1.
+///
+/// This function initializes `conversionTarget` with the set of operations that
+/// are illegal and consequently must be converted to a linearized form. It
+/// also populates the set of patterns that can be run to convert illegal
+/// operations, and what priority/benefit they have. The patterns and legality
+/// rules depend on `preference`, which controls the benefit associated to the
+/// patterns based on whether 1D shuffles or 1D strided ops are preferred.
+///
+/// Note: the set of legal operations can be extended by a user if, for example,
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to make it easy to reuse generic structural type
+/// conversions for linearizing scf/cf/func operations
+void populateForFullVectorLinearize(
+    const TypeConverter &, ConversionTarget &conversionTarget,
+    RewritePatternSet &patterns,
+    InsertExtractLinearizePreference preference =
+        InsertExtractLinearizePreference::Shuffle);
+
+enum class LinearizePattern {
+
+  /// BEFORE
+  /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+  LinearizeConstantLike = 0,
+
+  /// BEFORE
+  /// %2 = math.sin %arg0 : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+  /// %1 = math.sin %0 : vector<4xf32>
+  /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+  LinearizeVectorizable,
+
+  /// BEFORE
+  /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+  /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+  /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+  LinearizeVectorBitCast,
+
+  /// This pattern currently only supports 2D masks with a unit outer
+  /// dimension.
+  ///
+  /// BEFORE
+  /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+  ///
+  /// AFTER
+  /// [...]
+  /// %mask_1d = vector.create_mask %mul : vector<4xi1>
+  /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+  ///
+  /// where `%mul` is a function of `%arg0` and `%arg1`.
+  LinearizeVectorCreateMask,
+
+  /// BEFORE
+  /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+  ///
+  /// AFTER
+  /// %v1_1d = vector.shape_cast %v1_3d : [...]
+  /// %v2_1d = vector.shape_cast %v2_3d : [...]
+  /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+  /// %shuffle_3d = vector.shape_cast %shuffle_1d :  [...]
+  ///
+  /// Where `shuffle_indices_1d` are computed by expanding `shuffle_indices`.
+  LinearizeVectorShuffle,
+
+  /// BEFORE
+  /// %1 = vector.splat %value : vector<4x4xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.splat %value : vector<16xf32>
+  /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+  LinearizeVectorSplat,
+
+  /// Reduce the rank of a vector.extract_strided_slice to the lowest rank
+  /// possible. For extract_strided_slice ops that slice contiguous elements,
+  /// the reduced-rank op is 1D, otherwise it is higher dimensional.
+  ///
+  /// BEFORE
+  /// %2 = vector.extract_strided_slice %arg0 {
+  ///    offsets = [1, 0, 1, 0],
+  ///      sizes = [1, 2, 1, 2],
+  ///    strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2x2x2xi8> to vector<4x4xi8>
+  /// %1 = vector.extract_strided_slice %0 {
+  ///    offsets = [2, 2],
+  ///      sizes = [2, 2],
+  ///    strides = [1, 1]} : vector<4x4xi8> to vector<2x2xi8>
+  /// %2 = vector.shape_cast %1 : vector<2x2xi8> to vector<1x2x1x2xi8>
+  RankReduceExtractStridedSlice,
+
+  /// Similar to RankReduceExtractStridedSlice, but both the operands have
+  /// their rank reduced.
+  ///
+  /// BEFORE
+  /// %3 = vector.insert_strided_slice %arg1, %arg0 {[...]}
+  ///                vector<1x2x1x2xi8> into vector<2x2x2x2xi8>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2x2x2xi8> to vector<4x4xi8>
+  /// %1 = vector.shape_cast %arg1 : vector<1x2x1x2xi8> to vector<2x2xi8>
+  /// %2 = vector.insert_strided_slice %1, %0 {[...]}
+  /// %3 = vector.shape_cast %2 : vector<4x4xi8> to vector<2x2x2x2xi8>
+  RankReduceInsertStridedSlice,
+
+  /// BEFORE
+  /// %extract = vector.extract %src [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position`.
+  VectorExtractToRankOneShuffle,
+
+  /// BEFORE
+  /// %out_nd = vector.extract_strided_slice %source_nd
+  ///         { offsets = [..], strides = [..], sizes = [..] }
+  ///
+  /// AFTER
+  /// %source_1d = vector.shape_cast %source_nd [...]
+  /// %out_1d    = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+  /// %out_nd    = vector.shape_cast %out_1d [...]
+  ///
+  /// `shuffle_indices_1d` is computed using the offsets and sizes of the
+  /// original vector.extract_strided_slice operation.
+  VectorExtractStridedSliceToRankOneShuffle,
+
+  /// BEFORE
+  /// %1 = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<4x3x2x1xi8> to vector<24xi8>
+  /// %1 = vector.extract_strided_slice %0 {offsets = [10], sizes = [2] [...]
+  /// %2 = vector.shape_cast %1 : vector<2xi8> to vector<2x1xi8>
+  VectorExtractToRankOneStrided,
+
+  /// BEFORE
+  /// %insert = vector.insert %src %dst [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %dst_1d = vector.shape_cast %dst : [...]
+  /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position`.
+  VectorInsertToRankOneShuffle,
+
+  /// This pattern converts a vector.insert_strided_slice operation into a
+  /// vector.shuffle operation that has rank-1 (linearized) operands and result.
+  ///
+  /// BEFORE
+  /// %0 = vector.insert_strided_slice %to_store, %into
+  ///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
+  ///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
+  /// AFTER
+  /// %to_store_1d
+  ///          = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+  /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+  /// %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+  /// %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+  ///
+  /// where shuffle_indices_1d in this case is
+  ///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+  ///                        ^^^^^^^^^^^^^^
+  ///                          to_store_1d
+  VectorInsertStridedSliceToRankOneShuffle,
+
+  /// Similar to VectorExtractToRankOneStrided, but for insert_strided_slice.
+  VectorInsertToRankOneStrided,
+
+  /// The number of patterns in this enum.
+  N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority.
+struct VectorLinearizePatterns {
+public:
+  /// By default all patterns are enabled and have benefit 1.
+  VectorLinearizePatterns() {
+    enabled.fill(true);
+    benefits.fill(PatternBenefit(1));
+  }
+
+  /// Add the patterns enabled for the conversion to `patterns`.
+  void addToPatternSet(const TypeConverter &,
+                       RewritePatternSet &patterns) const;
+
+  VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+    enabled[static_cast<unsigned>(id)] = e;
+    return *this;
+  }
+
+  VectorLinearizePatterns &enableAll(bool e = true) {
+    enabled.fill(e);
+    return *this;
+  }
+
+  bool isEnabled(LinearizePattern id) const {
+    return enabled[static_cast<unsigned>(id)];
+  }
+
+  PatternBenefit getBenefit(LinearizePattern id) const {
+    return benefits[static_cast<unsigned>(id)];
+  }
+
+  VectorLinearizePatterns &setBenefit(LinearizePattern id,
+                                      PatternBenefit benefit) {
+    getBenefitRef(id) = benefit;
+    return *this;
+  }
+
+  /// Increase the benefit of pattern `id` by `inc` (1 by default).
+  VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+                                            unsigned inc = 1) {
+    getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+    return *this;
+  }
+
+private:
+  std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+  std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+      benefits;
+
+  PatternBenefit &getBenefitRef(LinearizePattern id) {
+    unsigned idInt = static_cast<unsigned>(id);
+    assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+           "invalid linearization pattern id");
+    return benefits[idInt];
+  }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                                     ArrayRef<int64_t> large,
+                                                     ArrayRef<int64_t> offsets);
+
+/// Return the strided slice with the lowest rank that is equivalent to the
+/// strided slice of `small` from `large`, starting at `offsets`. The result is
+/// a tuple of three vectors:
+///
+/// 0) The shape of the new small vector.
+/// 1) The shape of the new large vector.
+/// 2) The offsets of the new large vector.
+///
+/// Example 1 (contiguous slices can always be represented in 1-D).
+///
+/// Input:
+///  small  = (1, 3, 4)
+///  large  = (3, 3, 4)
+///  offset = (2, 0, 0)
+///
+/// Output:
+///  small  = (12)
+///  large  = (36)
+///  offset = (24)
+///
+/// Example 2 (a non-contiguous slice)
+///
+/// Input:
+///  small  =    (2, 2, 1, 2)
+///  large  = (2, 2, 2, 2, 2)
+///  offset = (1, 1, 0, 1)
+///
+///
+/// Output:
+///  small  =  (4, 2)
+///  large  =  (8, 4)
+///  offset = (24, 2)
+std::array<SmallVector<int64_t>, 3>
+getCollapsedStridedSliceShape(ArrayRef<int64_t> small, ArrayRef<int64_t> large,
+                              ArrayRef<int64_t> offsets);
+
+std::optional<std::array<SmallVector<int64_t>, 3>>
+getCollapsedExtractStridedSliceShape(vector::ExtractStridedSliceOp extractOp);
+
+std::optional<std::array<SmallVector<int64_t>, 3>>
+getCollapsedInsertStridedSliceShape(vector::InsertStridedSliceOp insertOp);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
-                                ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
-                                         const ConversionTarget &,
-                                         RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
-                                                   const ConversionTarget &,
-                                                   RewritePatternSet &patterns);
-
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..359b2ba091967 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5726,13 +5726,21 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
 
-  // No-op shape cast.
-  if (getSource().getType() == resultType)
-    return getSource();
 
-  // shape_cast(shape_cast(x)) -> shape_cast(x)
-  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
-    setOperand(precedingShapeCast.getSource());
+  // y = shape_cast(shape_cast(shape_cast(x))) 
+  //    -> shape_cast(x) # if x and y different types
+  //    -> x             # if x and y same type
+  // Value newSource  = getSource();
+  ShapeCastOp parent = *this;
+  while (auto precedingShapeCast = parent.getSource().getDefiningOp<ShapeCastOp>()) {
+    parent = precedingShapeCast;
+  }
+
+  if (parent.getSource().getType() == resultType)
+    return parent.getSource();
+
+  if (parent != *this){
+    setOperand(parent.getSource());
     return getResult();
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..fae452d8e5dc9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Operation.h"
@@ -45,6 +46,408 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
   return rewriter.notifyMatchFailure(loc, "unsupported attr type");
 }
 
+/// Convert an array of attributes into a vector of integers.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+  if (!attrs)
+    return failure();
+  SmallVector<int64_t> ints;
+  ints.reserve(attrs.size());
+  for (auto attr : attrs) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+      ints.push_back(intAttr.getInt());
+    } else {
+      return failure();
+    }
+  }
+  return ints;
+}
+
+/// Convert OpFoldResults into a vector of integers, failing when an
+/// OpFoldResult is not an Attribute (unless the dimension in `shape` is 1, in
+/// which case the offset is 0, irrespective). Ensure that the returned vector
+/// is of the same rank as `shape` by appending zeros.
+static FailureOr<SmallVector<int64_t>>
+getIntegerOffsetsFromFoldResults(ArrayRef<OpFoldResult> offsetFoldResults,
+                                 ArrayRef<int64_t> shape) {
+  assert(shape.size() >= offsetFoldResults.size() &&
+         "offsets assumed not be be higher rank than shape");
+  unsigned deltaRank = shape.size() - offsetFoldResults.size();
+  SmallVector<int64_t> offsets;
+  offsets.reserve(offsetFoldResults.size());
+  for (auto [offsetFoldResult, dimSize] :
+       llvm::zip(offsetFoldResults, shape.drop_back(deltaRank))) {
+    if (dimSize == 1) {
+      offsets.push_back(0);
+    } else if (auto offsetAttr = dyn_cast<Attribute>(offsetFoldResult)) {
+      offsets.push_back(cast<IntegerAttr>(offsetAttr).getInt());
+    } else {
+      return failure();
+    }
+  }
+  offsets.resize(shape.size(), 0);
+  return offsets;
+}
+
+/// If `ndIndex` is the index in the n-dimensional array of shape `shape`, get
+/// the corresponding index into the flattened array.
+static int64_t getIndexInFlattened(ArrayRef<int64_t> ndIndex,
+                                   ArrayRef<int64_t> shape) {
+  assert(ndIndex.size() == shape.size() &&
+         "ndIndex and shape assumed to have the same size");
+  int64_t index = 0;
+  int64_t stride = 1;
+  for (int i = shape.size() - 1; i >= 0; --i) {
+    index += ndIndex[i] * stride;
+    stride *= shape[i];
+  }
+  return index;
+}
+
+/// Return true if `op` is an insert, extract, insert_strided_slice, or
+/// extract_strided_slice operation that operates on scalable vectors.
+/// Otherwise return false.
+static bool isScalableExtractOrInsertOrStrided(Operation *op) {
+  return TypeSwitch<Operation *, bool>(op)
+      .Case<vector::ExtractStridedSliceOp>(
+          [&](vector::ExtractStridedSliceOp extractOp) {
+            return extractOp.getType().isScalable();
+          })
+      .Case<vector::InsertStridedSliceOp>(
+          [&](vector::InsertStridedSliceOp insertOp) {
+            return insertOp.getType().isScalable();
+          })
+      .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
+        return insertOp.getType().isScalable();
+      })
+      .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
+        return extractOp.getSourceVectorType().isScalable();
+      })
+      .Default([&](auto) { return false; });
+}
+
+SmallVector<int64_t>
+vector::getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                        ArrayRef<int64_t> large,
+                                        ArrayRef<int64_t> offsets) {
+
+  // Example of alignment between, `large`, `small` and `offsets`:
+  //    large  =  4, 5, 6, 7, 8
+  //    small  =     1, 6, 7, 8
+  //  offsets  =  2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+  assert((large.size() >= small.size()) &&
+         "rank of 'large' cannot be lower than rank of 'small'");
+  assert((large.size() >= offsets.size()) &&
+         "rank of 'large' cannot be lower than the number of offsets");
+  unsigned delta = large.size() - small.size();
+  unsigned nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) -> int64_t {
+    return i >= delta ? small[i - delta] : 1;
+  };
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  // Using 2 vectors of indices, at each iteration populate the updated set of
+  // indices based on the old set of indices, and the size of the small vector
+  // in the current iteration.
+  SmallVector<int64_t> indices{0};
+  int64_t stride = 1;
+  for (int i = large.size() - 1; i >= 0; --i) {
+    int64_t currentSize = indices.size();
+    int64_t smallSize = getSmall(i);
+    int64_t nextSize = currentSize * smallSize;
+    SmallVector<int64_t> nextIndices(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int j = 0; j < smallSize; ++j) {
+      for (int k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    indices = std::move(nextIndices);
+  }
+  return indices;
+}
+
+void vector::initializeForVectorLinearize(TypeConverter &typeConverter) {
+
+  auto convertType = [](Type type) -> std::optional<Type> {
+    VectorType vectorType = dyn_cast<VectorType>(type);
+
+    if (!vectorType || !vector::isLinearizableVector(vectorType))
+      return type;
+
+    VectorType linearizedType =
+        VectorType::get(vectorType.getNumElements(),
+                        vectorType.getElementType(), vectorType.isScalable());
+
+    return linearizedType;
+  };
+  typeConverter.addConversion(convertType);
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1) {
+      return nullptr;
+    }
+    Value value = inputs.front();
+    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType())) {
+      return nullptr;
+    }
+    return builder.create<vector::ShapeCastOp>(loc, type, value);
+  };
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+}
+
+void vector::populateForFullVectorLinearize(
+    const TypeConverter &typeConverter, ConversionTarget &target,
+    RewritePatternSet &patterns, InsertExtractLinearizePreference preference) {
+
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
+        // Only ops in the vector dialect, ConstantLike ops, and Vectorizable
+        // ops may be linearized. Use the op name: getDialect() can be null.
+        StringLiteral vectorDialect =
+            vector::VectorDialect::getDialectNamespace();
+        StringRef opDialect = op->getName().getDialectNamespace();
+        bool supported = (opDialect == vectorDialect) ||
+                         op->hasTrait<OpTrait::ConstantLike>() ||
+                         op->hasTrait<OpTrait::Vectorizable>();
+        if (!supported)
+          return true;
+
+        // As type legalization is done with vector.shape_cast, shape_cast
+        // itself cannot be linearized (doing so would create new shape_casts to
+        // linearize ad infinitum).
+        if (isa<vector::ShapeCastOp>(op))
+          return true;
+
+        // The operations extract_strided_slice, extract, insert_strided_slice,
+        // and insert are linearized to rank-1 operations that do not fully
+        // support scalable vectors, so it is not generally possible to
+        // linearize these ops if they operate on scalable vectors.
+        if (isScalableExtractOrInsertOrStrided(op))
+          return true;
+
+        // This will return true if, for all operand and result types `t`,
+        // convertType(t) = t. This is true if there are no rank>=2 vectors.
+        return typeConverter.isLegal(op);
+      });
+
+  VectorLinearizePatterns linearizePatterns;
+
+  if (preference == InsertExtractLinearizePreference::Shuffle) {
+    // Mark extract_strided_slice, insert_strided_slice, extract with source
+    // rank > 1, and insert with result rank > 1 as illegal, as they must be
+    // converted to shuffle or rank-1 extract/insert.
+    //
+    // Note that the order of the calls to `markUnknownOpDynamicallyLegal`
+    // is important: the legality rule added here takes precedence over the
+    // generic one preceding it which marked these ops as legal.
+    target.markUnknownOpDynamicallyLegal(
+        [](Operation *op) -> std::optional<bool> {
+          bool isStrided =
+              isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp>(
+                  op);
+
+          bool isHighRankExtractOrInsert = [&]() {
+            if (auto extractOp = dyn_cast<vector::ExtractOp>(op)) {
+              return extractOp.getSourceVectorType().getRank() > 1;
+            }
+            if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+              return insertOp.getType().getRank() > 1;
+            }
+            return false;
+          }();
+
+          bool isScalable = isScalableExtractOrInsertOrStrided(op);
+
+          if ((isStrided || isHighRankExtractOrInsert) && !isScalable) {
+            return false;
+          }
+          return std::nullopt;
+        });
+
+    // Ensure that the benefit of patterns targeting shuffle is higher than
+    // the benefit of patterns targeting rank-1 strided slice operations. This
+    // will ensure that patterns for converting to rank-1 shuffle are run first.
+    linearizePatterns
+        .incrementBenefit(
+            LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)
+        .incrementBenefit(
+            LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)
+        .incrementBenefit(LinearizePattern::VectorExtractToRankOneShuffle)
+        .incrementBenefit(LinearizePattern::VectorInsertToRankOneShuffle);
+
+  } else if (preference == InsertExtractLinearizePreference::Strided) {
+    linearizePatterns
+        .incrementBenefit(LinearizePattern::RankReduceInsertStridedSlice)
+        .incrementBenefit(LinearizePattern::RankReduceExtractStridedSlice)
+        .incrementBenefit(LinearizePattern::VectorInsertToRankOneStrided)
+        .incrementBenefit(LinearizePattern::VectorExtractToRankOneStrided);
+  } else {
+    assert(false && "unsupported InsertExtractLinearizePreference");
+  }
+  linearizePatterns.addToPatternSet(typeConverter, patterns);
+}
+
+/// Get the lowest rank shapes and offsets which represent the same strided
+/// slice as the strided slice described by `small`, `large`, and `offsets`.
+///
+/// Example
+///
+/// %0 = vector.extract_strided_slice %1
+///       {offsets = [0, 0, 0], sizes = [2, 2, 2], strides = [1, 1, 1]} :
+///       vector<4x2x4xf32> to vector<2x2x2xf32>
+///
+/// is equivalent to
+///
+/// [...rank reducing shape casts...]
+/// %0 = vector.extract_strided_slice %1
+///      {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} :
+///      vector<8x4xf32> to vector<4x2xf32>
+/// [...rank increasing shape cast...]
+///
+/// So the output for
+///  (small, large, offsets = [2, 2, 2], [4, 2, 4], [0, 0, 0]) is
+///  (small, large, offsets = [4, 2],    [8, 4],    [0, 0])
+std::array<SmallVector<int64_t>, 3>
+vector::getCollapsedStridedSliceShape(ArrayRef<int64_t> small,
+                                      ArrayRef<int64_t> large,
+                                      ArrayRef<int64_t> offsets) {
+
+  // Element counts of `small` and `large`, accumulated in int64_t (not int).
+  int64_t tSmall = std::accumulate(small.begin(), small.end(), int64_t{1},
+                                   std::multiplies<int64_t>());
+  int64_t tLarge = std::accumulate(large.begin(), large.end(), int64_t{1},
+                                   std::multiplies<int64_t>());
+  assert(tLarge >= tSmall &&
+         "total number of elements in 'small' is larger than in 'large'");
+  assert(large.size() >= small.size() &&
+         "rank of 'small' is larger than  rank of 'large'");
+  assert(offsets.size() <= large.size() &&
+         "rank of large is less than the number of offsets");
+
+  int64_t nOffsets = offsets.size();
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  unsigned delta = large.size() - small.size();
+
+  // The cumulative (product of dimensions) number of elements from the back
+  // currently visited in the small (large, respectively) vector.
+  int64_t nSmall = 1;
+  int64_t nLarge = 1;
+
+  // The cumulative number (product of dimensions) of elements from the back
+  // currently visited within the current collapse group in the small (large,
+  // respectively) vector.
+  int64_t cSmall = 1;
+  int64_t cLarge = 1;
+
+  SmallVector<int64_t> newSmall, newLarge, newOffsets;
+  if (large.size() == 0)
+    return {newSmall, newLarge, newOffsets};
+
+  // The offset assigned to the current collapse group.
+  int64_t cOff = 0;
+
+  unsigned index = large.size() - 1;
+  while (nLarge < tLarge) {
+    assert(cSmall <= nSmall && nSmall <= tSmall && //
+           cLarge <= nLarge && nLarge <= tLarge &&
+           "confusion in element accumulation");
+    cOff += getOffset(index) * cLarge;
+    if (nSmall < tSmall) {
+      cSmall *= small[index - delta];
+      nSmall *= small[index - delta];
+    }
+    cLarge *= large[index];
+    nLarge *= large[index];
+    if ((nSmall < tSmall) && (large[index] != small[index - delta])) {
+      newSmall.push_back(cSmall);
+      newLarge.push_back(cLarge);
+      newOffsets.push_back(cOff);
+      cSmall = 1;
+      cLarge = 1;
+      cOff = 0;
+    }
+    --index;
+  }
+  newSmall.push_back(cSmall);
+  newLarge.push_back(cLarge);
+  newOffsets.push_back(cOff);
+  std::reverse(newSmall.begin(), newSmall.end());
+  std::reverse(newLarge.begin(), newLarge.end());
+  std::reverse(newOffsets.begin(), newOffsets.end());
+  return {newSmall, newLarge, newOffsets};
+}
+
+// returns small, large, offsets.
+std::optional<std::array<SmallVector<int64_t>, 3>>
+vector::getCollapsedExtractStridedSliceShape(
+    vector::ExtractStridedSliceOp extractOp) {
+
+  if (extractOp.hasNonUnitStrides())
+    return std::nullopt;
+
+  ArrayRef<int64_t> outShape = extractOp.getType().getShape();
+  ArrayRef<int64_t> inShape = extractOp.getSourceVectorType().getShape();
+
+  auto maybeIntOffsets = intsFromArrayAttr(extractOp.getOffsets());
+  if (failed(maybeIntOffsets))
+    return std::nullopt;
+
+  SmallVector<int64_t> offsets = std::move(maybeIntOffsets.value());
+  const auto &[collapsedOutShape, collapsedInShape, collapsedOffsets] =
+      getCollapsedStridedSliceShape(outShape, inShape, offsets);
+
+  bool unchanged = (collapsedInShape.size() == inShape.size()) &&
+                   (collapsedOutShape.size() == outShape.size());
+
+  if (unchanged)
+    return std::nullopt;
+
+  return std::array<SmallVector<int64_t>, 3>{
+      collapsedOutShape, collapsedInShape, collapsedOffsets};
+}
+
+// returns small, large, offsets.
+std::optional<std::array<SmallVector<int64_t>, 3>>
+vector::getCollapsedInsertStridedSliceShape(
+    vector::InsertStridedSliceOp insertOp) {
+
+  if (insertOp.hasNonUnitStrides())
+    return std::nullopt;
+
+  ArrayRef<int64_t> outShape = insertOp.getType().getShape();
+  ArrayRef<int64_t> inShape = insertOp.getSourceVectorType().getShape();
+
+  auto maybeIntOffsets = intsFromArrayAttr(insertOp.getOffsets());
+  if (failed(maybeIntOffsets))
+    return std::nullopt;
+
+  SmallVector<int64_t> offsets = std::move(maybeIntOffsets.value());
+  const auto &[collapsedInShape, collapsedOutShape, collapsedOffsets] =
+      getCollapsedStridedSliceShape(inShape, outShape, offsets);
+
+  bool unchanged = (collapsedInShape.size() == inShape.size()) &&
+                   (collapsedOutShape.size() == outShape.size());
+
+  if (unchanged)
+    return std::nullopt;
+
+  return std::array<SmallVector<int64_t>, 3>{
+      collapsedInShape, collapsedOutShape, collapsedOffsets};
+}
+
 namespace {
 
 struct LinearizeConstantLike final
@@ -52,7 +455,7 @@ struct LinearizeConstantLike final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
   LinearizeConstantLike(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
       : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -94,7 +497,7 @@ struct LinearizeVectorizable final
 
 public:
   LinearizeVectorizable(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
       : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,90 +512,6 @@ struct LinearizeVectorizable final
   }
 };
 
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
-  static_assert(
-      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
-          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
-      "expected vector.extract_strided_slice or vector.insert_strided_slice");
-  ArrayAttr strides = op.getStrides();
-  return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
-static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
-  if (!attrs)
-    return failure();
-  SmallVector<int64_t> ints;
-  ints.reserve(attrs.size());
-  for (auto attr : attrs) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-      ints.push_back(intAttr.getInt());
-    } else {
-      return failure();
-    }
-  }
-  return ints;
-}
-
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
-    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
-    ArrayRef<int64_t> offsets) {
-
-  // Example of alignment between, `large`, `small` and `offsets`:
-  //    large  =  4, 5, 6, 7, 8
-  //    small  =     1, 6, 7, 8
-  //  offsets  =  2, 3, 0
-  //
-  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
-  assert((large.size() >= small.size()) &&
-         "rank of 'large' cannot be lower than rank of 'small'");
-  assert((large.size() >= offsets.size()) &&
-         "rank of 'large' cannot be lower than the number of offsets");
-  unsigned delta = large.size() - small.size();
-  unsigned nOffsets = offsets.size();
-  auto getSmall = [&](int64_t i) -> int64_t {
-    return i >= delta ? small[i - delta] : 1;
-  };
-  auto getOffset = [&](int64_t i) -> int64_t {
-    return i < nOffsets ? offsets[i] : 0;
-  };
-
-  // Using 2 vectors of indices, at each iteration populate the updated set of
-  // indices based on the old set of indices, and the size of the small vector
-  // in the current iteration.
-  SmallVector<int64_t> indices{0};
-  int64_t stride = 1;
-  for (int i = large.size() - 1; i >= 0; --i) {
-    int64_t currentSize = indices.size();
-    int64_t smallSize = getSmall(i);
-    int64_t nextSize = currentSize * smallSize;
-    SmallVector<int64_t> nextIndices(nextSize);
-    int64_t *base = nextIndices.begin();
-    int64_t offset = getOffset(i) * stride;
-    for (int j = 0; j < smallSize; ++j) {
-      for (int k = 0; k < currentSize; ++k) {
-        base[k] = indices[k] + offset;
-      }
-      offset += stride;
-      base += currentSize;
-    }
-    stride *= large[i];
-    indices = std::move(nextIndices);
-  }
-  return indices;
-}
-
 /// This pattern converts a vector.extract_strided_slice operation into a
 /// vector.shuffle operation that has a rank-1 (linearized) operand and result.
 ///
@@ -212,12 +531,12 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices(
 ///
 /// `shuffle_indices_1d` is computed using the offsets and sizes of the original
 /// vector.extract_strided_slice operation.
-struct LinearizeVectorExtractStridedSlice final
-    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+struct VectorExtractStridedSliceToRankOneShuffle final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
-                                     MLIRContext *context,
-                                     PatternBenefit benefit = 1)
+  VectorExtractStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+                                            MLIRContext *context,
+                                            PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -231,7 +550,7 @@ struct LinearizeVectorExtractStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for extract_strided_slice allows non-1 strides).
-    if (!stridesAllOne(extractStridedSliceOp)) {
+    if (extractStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           extractStridedSliceOp,
           "extract_strided_slice with strides != 1 not supported");
@@ -249,7 +568,7 @@ struct LinearizeVectorExtractStridedSlice final
 
     ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
 
-    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+    SmallVector<int64_t> indices = vector::getStridedSliceInsertionIndices(
         outputShape, inputShape, offsets.value());
 
     Value srcVector = adaptor.getVector();
@@ -259,36 +578,24 @@ struct LinearizeVectorExtractStridedSlice final
   }
 };
 
-/// This pattern converts a vector.insert_strided_slice operation into a
-/// vector.shuffle operation that has rank-1 (linearized) operands and result.
-///
-/// For example, the following:
-/// ```
-///  %0 = vector.insert_strided_slice %to_store, %into
-///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
-///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
-/// ```
-///
-/// is converted to
-/// ```
-///  %to_store_1d
-///           = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
-///  %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
-///  %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
-///  %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
-/// ```
-///
-/// where shuffle_indices_1d in this case is
-///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
-///                        ^^^^^^^^^^^^^^
-///                          to_store_1d
-///
-struct LinearizeVectorInsertStridedSlice final
-    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+/// Returns `v` as a rank-1 vector value.
+///
+/// `v` must have a vector type of rank 0 or 1 (asserted). A rank-1 value is
+/// returned as is; a rank-0 value is shape_cast to a single-element rank-1
+/// vector with the same element type.
+static Value asRankOne(ConversionPatternRewriter &rewriter, Value v) {
+  auto vType = dyn_cast<VectorType>(v.getType());
+  assert(vType && "expected vector type");
+  assert(vType.getRank() <= 1 && "expected rank-0 or rank-1 type");
+  if (vType.getRank() == 1)
+    return v;
+
+  // Rank-0 input: wrap the single element in a vector of shape [1].
+  VectorType rankOneType = VectorType::get({1}, vType.getElementType());
+  return rewriter.create<vector::ShapeCastOp>(v.getLoc(), rankOneType, v);
+}
+
+struct VectorInsertStridedSliceToRankOneShuffle final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
-                                    MLIRContext *context,
-                                    PatternBenefit benefit = 1)
+  VectorInsertStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+                                           MLIRContext *context,
+                                           PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -298,7 +605,7 @@ struct LinearizeVectorInsertStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for insert_strided_slice allows non-1 strides).
-    if (!stridesAllOne(insertStridedSliceOp)) {
+    if (insertStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           insertStridedSliceOp,
           "insert_strided_slice with strides != 1 not supported");
@@ -317,7 +624,7 @@ struct LinearizeVectorInsertStridedSlice final
       return rewriter.notifyMatchFailure(insertStridedSliceOp,
                                          "failed to get integer offsets");
     }
-    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+    SmallVector<int64_t> sliceIndices = vector::getStridedSliceInsertionIndices(
         inputShape, outputShape, offsets.value());
 
     SmallVector<int64_t> indices(nOutputElements);
@@ -326,7 +633,7 @@ struct LinearizeVectorInsertStridedSlice final
       indices[sliceIndex] = index + nOutputElements;
     }
 
-    Value flatToStore = adaptor.getValueToStore();
+    Value flatToStore = asRankOne(rewriter, adaptor.getValueToStore());
     Value flatDest = adaptor.getDest();
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
                                                    flatDest.getType(), flatDest,
@@ -350,7 +657,7 @@ struct LinearizeVectorShuffle final
     : public OpConversionPattern<vector::ShuffleOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+                         MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -360,8 +667,8 @@ struct LinearizeVectorShuffle final
         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
     assert(dstType && "vector type destination expected.");
 
-    Value vec1 = adaptor.getV1();
-    Value vec2 = adaptor.getV2();
+    Value vec1 = asRankOne(rewriter, adaptor.getV1());
+    Value vec2 = asRankOne(rewriter, adaptor.getV2());
     int shuffleSliceLen = 1;
     int rank = shuffleOp.getV1().getType().getRank();
 
@@ -404,11 +711,11 @@ struct LinearizeVectorShuffle final
 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
 ///   %out_nd = vector.shape_cast %out_1d
 /// `shuffle_indices_1d` is computed using the position of the original extract.
-struct LinearizeVectorExtract final
+struct VectorExtractToRankOneShuffle final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtract(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+  VectorExtractToRankOneShuffle(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
@@ -436,11 +743,120 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
+    Value v0 = asRankOne(rewriter, adaptor.getVector());
     llvm::SmallVector<int64_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, v0, v0,
+                                                   indices);
+
+    return success();
+  }
+};
+
+/// Convert a vector.extract op with input rank > 1, to an operation with input
+/// of rank 1 and output of rank <= 1. Two lowering cases:
+///
+/// 1) If the result of the vector.extract is a scalar, convert it to a
+///    vector.extract on a rank-1 input which still outputs a scalar.
+///
+/// 2) Otherwise, convert to an extract_strided_slice op on a vector of rank-1.
+///
+/// The pattern does not match when the input is already rank-1, or when
+/// getIntegerOffsetsFromFoldResults fails for the extraction position.
+struct VectorExtractToRankOneStrided final
+    : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+  VectorExtractToRankOneStrided(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType inType = extractOp.getVector().getType();
+    if (inType.getRank() == 1)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "input is already rank-1");
+
+    SmallVector<OpFoldResult> offsets = extractOp.getMixedPosition();
+    auto maybeIntOffsets =
+        getIntegerOffsetsFromFoldResults(offsets, inType.getShape());
+    if (failed(maybeIntOffsets))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "failed to get integer offsets");
+
+    // Single offset into the rank-1 (converted) input that corresponds to the
+    // n-D extraction position.
+    const auto &intOffsets = maybeIntOffsets.value();
+    int64_t globalOffset = getIndexInFlattened(intOffsets, inType.getShape());
+
+    Location loc = extractOp.getLoc();
+    Type outType = extractOp.getType();
+
+    // Case 1 described above: scalar result, extract a single element.
+    if (outType.isIntOrIndexOrFloat()) {
+      Value flattened = rewriter.create<vector::ExtractOp>(
+          loc, adaptor.getVector(), SmallVector<int64_t>{globalOffset});
+      rewriter.replaceOp(extractOp, flattened);
+      return success();
+    }
+
+    // Case 2 described above: vector result, extract a unit-stride rank-1
+    // slice that starts at `globalOffset` and holds every result element.
+    auto vOutType = cast<VectorType>(outType);
+    int64_t numberElementsOut = vOutType.getNumElements();
+    auto strided = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, adaptor.getVector(), SmallVector<int64_t>{globalOffset},
+        SmallVector<int64_t>{numberElementsOut}, SmallVector<int64_t>{1});
+
+    rewriter.replaceOp(extractOp, strided);
+    return success();
+  }
+};
+
+/// Convert vector.insert where the destination is rank > 1. Two cases:
+///
+/// 1) If the source to insert is a scalar, convert to a vector.insert op
+///    where the destination is rank-1.
+///
+/// 2) Otherwise, convert to a vector.insert_strided_slice op into a vector of
+///    rank-1.
+///
+/// The pattern does not match when getIntegerOffsetsFromFoldResults fails for
+/// the insertion position.
+struct VectorInsertToRankOneStrided final
+    : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+  VectorInsertToRankOneStrided(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType largeType = insertOp.getDest().getType();
+    Type smallType = insertOp.getValueToStoreType();
+    SmallVector<OpFoldResult> positions = insertOp.getMixedPosition();
+    auto maybeIntOffsets =
+        getIntegerOffsetsFromFoldResults(positions, largeType.getShape());
+    if (failed(maybeIntOffsets))
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "failed to get integer offsets");
+
+    // Single offset into the rank-1 (converted) destination that corresponds
+    // to the n-D insertion position.
+    const auto &intOffsets = maybeIntOffsets.value();
+    int64_t globalOffset =
+        getIndexInFlattened(intOffsets, largeType.getShape());
+
+    Location loc = insertOp.getLoc();
+
+    // Case 1: the value to store is a scalar. Every non-vector value is
+    // treated as a scalar here, so `index` and signed/unsigned integers take
+    // this path too instead of reaching asRankOne, which asserts on
+    // non-vector values.
+    if (!isa<VectorType>(smallType)) {
+      auto flatOut = rewriter.create<vector::InsertOp>(
+          loc, adaptor.getValueToStore(), adaptor.getDest(),
+          SmallVector<int64_t>{globalOffset});
+      rewriter.replaceOp(insertOp, flatOut);
+      return success();
+    }
 
+    // Case 2: the value to store is a vector. Insert it, as a rank-1 vector,
+    // with unit stride at `globalOffset` of the rank-1 destination.
+    Value v0 = asRankOne(rewriter, adaptor.getValueToStore());
+    auto flatOut = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, v0, adaptor.getDest(), SmallVector<int64_t>{globalOffset},
+        SmallVector<int64_t>{1});
+    rewriter.replaceOp(insertOp, flatOut);
     return success();
   }
 };
@@ -455,11 +871,11 @@ struct LinearizeVectorExtract final
 ///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
 ///   ] %out_nd = vector.shape_cast %out_1d
 /// `shuffle_indices_1d` is computed using the position of the original insert.
-struct LinearizeVectorInsert final
+struct VectorInsertToRankOneShuffle final
     : public OpConversionPattern<vector::InsertOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorInsert(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+  VectorInsertToRankOneShuffle(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
@@ -508,7 +924,8 @@ struct LinearizeVectorInsert final
                                            // [offset+srcNumElements, end)
 
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+        insertOp, dstTy, adaptor.getDest(),
+        asRankOne(rewriter, adaptor.getValueToStore()), indices);
 
     return success();
   }
@@ -526,7 +943,7 @@ struct LinearizeVectorBitCast final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+                         MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
@@ -535,7 +952,7 @@ struct LinearizeVectorBitCast final
     assert(resType && "expected 1-D vector type");
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
-    return mlir::success();
+    return success();
   }
 };
 
@@ -550,7 +967,7 @@ struct LinearizeVectorSplat final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
+                       PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -581,7 +998,7 @@ struct LinearizeVectorCreateMask final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorCreateMask(const TypeConverter &typeConverter,
-                            MLIRContext *context, PatternBenefit benefit = 1)
+                            MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -607,17 +1024,16 @@ struct LinearizeVectorCreateMask final
     // The result of the comparison is then multiplied with
     // the second operand of create_mask to get the 1D mask.
     auto firstOperand = adaptor.getOperands().front();
-    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
-    auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
-    auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
         loc, rewriter.getIndexType(), isNonZero);
     auto secondOperand = adaptor.getOperands().back();
-    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
+    auto maskSize = rewriter.createOrFold<arith::AndIOp>(
         loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
 
-    auto newMask =
-        rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+    auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
     rewriter.replaceOp(createMaskOp, newMask);
     return success();
   }
@@ -625,104 +1041,179 @@ struct LinearizeVectorCreateMask final
 
 } // namespace
 
-/// This method defines the set of operations that are linearizable, and hence
-/// that are considered illegal for the conversion target.
-static bool isLinearizable(Operation *op) {
+/// This pattern converts a vector.extract_strided_slice into a new
+/// vector.extract_strided_slice where the operand and result of the new
+/// vector.extract_strided_slice have ranks that are as low as possible.
+///
+/// If the original vector.extract_strided_slice is a contiguous slice of
+/// a vector, then the new vector.extract_strided_slice will have rank-1
+/// operand and result. Otherwise additional dimensions will remain in the
+/// new operand and result.
+///
+/// The match fails when getCollapsedExtractStridedSliceShape returns
+/// std::nullopt. The result of the new op is shape_cast to the type that the
+/// type converter assigns to the original result type.
+struct RankReduceExtractStridedSlice final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  // Only ops that are in the vector dialect, are ConstantLike, or
-  // are Vectorizable might be linearized currently.
-  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
-  StringRef opDialect = op->getDialect()->getNamespace();
-  bool supported = (opDialect == vectorDialect) ||
-                   op->hasTrait<OpTrait::ConstantLike>() ||
-                   op->hasTrait<OpTrait::Vectorizable>();
-  if (!supported)
-    return false;
+  RankReduceExtractStridedSlice(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
-  return TypeSwitch<Operation *, bool>(op)
-      // As type legalization is done with vector.shape_cast, shape_cast
-      // itself cannot be linearized (will create new shape_casts to linearize
-      // ad infinitum).
-      .Case<vector::ShapeCastOp>([&](auto) { return false; })
-      // The operations
-      // - vector.extract_strided_slice
-      // - vector.extract
-      // - vector.insert_strided_slice
-      // - vector.insert
-      // are linearized to a rank-1 vector.shuffle by the current patterns.
-      // vector.shuffle only supports fixed size vectors, so it is impossible to
-      // use this approach to linearize these ops if they operate on scalable
-      // vectors.
-      .Case<vector::ExtractStridedSliceOp>(
-          [&](vector::ExtractStridedSliceOp extractOp) {
-            return !extractOp.getType().isScalable();
-          })
-      .Case<vector::InsertStridedSliceOp>(
-          [&](vector::InsertStridedSliceOp insertOp) {
-            return !insertOp.getType().isScalable();
-          })
-      .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
-        return !insertOp.getType().isScalable();
-      })
-      .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
-        return !extractOp.getSourceVectorType().isScalable();
-      })
-      .Default([&](auto) { return true; });
-}
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
-                                              ConversionTarget &target) {
+    // std::nullopt here means there is no rank-reduced form to rewrite to.
+    auto maybeCollapsed = getCollapsedExtractStridedSliceShape(extractOp);
+    if (!maybeCollapsed.has_value())
+      return failure();
 
-  auto convertType = [](Type type) -> std::optional<Type> {
-    VectorType vectorType = dyn_cast<VectorType>(type);
-    if (!vectorType || !isLinearizableVector(vectorType))
-      return type;
+    const auto &[collapsedOutShape, collapsedInShape, collapsedOffsets] =
+        maybeCollapsed.value();
 
-    VectorType linearizedType =
-        VectorType::get(vectorType.getNumElements(),
-                        vectorType.getElementType(), vectorType.isScalable());
-    return linearizedType;
-  };
-  typeConverter.addConversion(convertType);
+    // Reshape the (already converted) operand to the collapsed input shape.
+    VectorType collapsedInType =
+        VectorType::get(collapsedInShape, extractOp.getType().getElementType());
 
-  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                            Location loc) -> Value {
-    if (inputs.size() != 1)
-      return nullptr;
+    auto collapsedIn = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), collapsedInType, adaptor.getVector());
 
-    Value value = inputs.front();
-    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
-      return nullptr;
+    // The new op uses the collapsed offsets and sizes, with all strides 1.
+    auto replacement = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), collapsedIn, collapsedOffsets, collapsedOutShape,
+        SmallVector<int64_t>(collapsedOffsets.size(), 1));
 
-    return builder.create<vector::ShapeCastOp>(loc, type, value);
-  };
-  typeConverter.addSourceMaterialization(materializeCast);
-  typeConverter.addTargetMaterialization(materializeCast);
+    // Cast the rank-reduced result to the converted type of the original
+    // result before replacing the op.
+    VectorType flatOutputType =
+        getTypeConverter()->convertType<VectorType>(extractOp.getType());
 
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (!isLinearizable(op))
-          return true;
-        // This will return true if, for all operand and result types `t`,
-        // convertType(t) = t. This is true if there are no rank>=2 vectors.
-        return typeConverter.isLegal(op);
-      });
-}
+    Value out = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), flatOutputType, replacement);
 
-void mlir::vector::populateVectorLinearizeBasePatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
-}
+    rewriter.replaceOp(extractOp, out);
+
+    return success();
+  }
+};
+
+/// Rewrites a vector.insert_strided_slice as a new vector.insert_strided_slice
+/// whose operands use the shapes and offsets computed by
+/// getCollapsedInsertStridedSliceShape.
+///
+/// The match fails when that helper returns std::nullopt. Both operands are
+/// shape_cast to their collapsed shapes, the new op is created with all
+/// strides equal to 1, and its result is shape_cast to the result type of the
+/// original op.
+struct RankReduceInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  RankReduceInsertStridedSlice(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto collapsed = getCollapsedInsertStridedSliceShape(insertOp);
+    if (!collapsed)
+      return failure();
+
+    const auto &[smallShape, largeShape, offsets] = *collapsed;
+
+    Location loc = insertOp.getLoc();
+    Type elementType = insertOp.getType().getElementType();
+
+    // Bring the value to store and the destination to their collapsed shapes.
+    VectorType smallType = VectorType::get(smallShape, elementType);
+    Value small = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, smallType, adaptor.getValueToStore());
+
+    VectorType largeType = VectorType::get(largeShape, elementType);
+    Value large = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, largeType, adaptor.getDest());
+
+    // One unit stride per collapsed offset.
+    auto rankReduced = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, small, large, offsets, SmallVector<int64_t>(offsets.size(), 1));
+
+    // Cast back to the result type of the original op.
+    Value result = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, insertOp.getType(), rankReduced);
+
+    rewriter.replaceOp(insertOp, result);
+    return success();
+  }
+};
 
-void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
-               LinearizeVectorInsertStridedSlice>(typeConverter,
-                                                  patterns.getContext());
+/// Adds the patterns listed below to `patterns`. A pattern is added only if
+/// isEnabled(...) returns true for its LinearizePattern value. Each added
+/// pattern is constructed with `typeConverter`, the context of `patterns`, and
+/// the benefit that getBenefit(...) returns for that same LinearizePattern
+/// value.
+void vector::VectorLinearizePatterns::addToPatternSet(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns) const {
+
+  MLIRContext *context = patterns.getContext();
+
+  if (isEnabled(LinearizePattern::LinearizeConstantLike))
+    patterns.add<LinearizeConstantLike>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeConstantLike));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorizable))
+    patterns.add<LinearizeVectorizable>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorizable));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorBitCast))
+    patterns.add<LinearizeVectorBitCast>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorBitCast));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorCreateMask))
+    patterns.add<LinearizeVectorCreateMask>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorCreateMask));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorShuffle))
+    patterns.add<LinearizeVectorShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorShuffle));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorSplat))
+    patterns.add<LinearizeVectorSplat>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorSplat));
+
+  // ------------------------ //
+  // Extract related patterns //
+  // ------------------------ //
+  if (isEnabled(LinearizePattern::VectorExtractToRankOneShuffle))
+    patterns.add<VectorExtractToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorExtractToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::VectorExtractStridedSliceToRankOneShuffle))
+    patterns.add<VectorExtractStridedSliceToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(
+            LinearizePattern::VectorExtractStridedSliceToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::RankReduceExtractStridedSlice))
+    patterns.add<RankReduceExtractStridedSlice>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::RankReduceExtractStridedSlice));
+
+  if (isEnabled(LinearizePattern::VectorExtractToRankOneStrided))
+    patterns.add<VectorExtractToRankOneStrided>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorExtractToRankOneStrided));
+
+  // ------------------------ //
+  // Insert related patterns  //
+  // ------------------------ //
+  if (isEnabled(LinearizePattern::VectorInsertToRankOneShuffle))
+    patterns.add<VectorInsertToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle))
+    patterns.add<VectorInsertStridedSliceToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::RankReduceInsertStridedSlice))
+    patterns.add<RankReduceInsertStridedSlice>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::RankReduceInsertStridedSlice));
+
+  if (isEnabled(LinearizePattern::VectorInsertToRankOneStrided))
+    patterns.add<VectorInsertToRankOneStrided>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertToRankOneStrided));
 }
diff --git a/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
new file mode 100644
index 0000000000000..a73602263d06e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
@@ -0,0 +1,287 @@
+// Everything becomes a shuffle (except rank-1 insert/extract).
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle | FileCheck %s --check-prefixes=SHUFFLE,ALL
+
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Strided | FileCheck %s --check-prefixes=STRIDED,ALL
+
+
+// **--------------------------------------------------------**
+//              Tests of vector.insert
+// **--------------------------------------------------------**
+
+// vector.insert where the destination is a 1D vector is always unchanged.
+//
+// ALL-LABEL: insert_scalar_to_1D(
+//  ALL-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<4xi8>
+//       ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : i8 into vector<4xi8>
+//       ALL: return %[[IN0]] : vector<4xi8>
+func.func @insert_scalar_to_1D(%arg0 : i8, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2] : i8 into vector<4xi8>
+  return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert of a scalar always becomes an insert of a scalar into a 1-D vector.
+//
+// ALL-LABEL: insert_scalar_to_2D(
+//  ALL-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<3x4xi8>
+//       ALL: %[[SC0:.*]] = vector.shape_cast %[[A1]] : vector<3x4xi8> to vector<12xi8>
+//       ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[SC0]] [9] : i8 into vector<12xi8>
+//       ALL: %[[SC1:.*]] = vector.shape_cast %[[IN0]] : vector<12xi8> to vector<3x4xi8>
+//       ALL: return %[[SC1]] : vector<3x4xi8>
+func.func @insert_scalar_to_2D(%arg0 : i8, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2, 1] : i8 into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. First case: 1D -> 2D.
+//
+//    ALL-LABEL: insert_1D_to_2D(
+//
+//      SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11]
+//
+//      STRIDED: vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
+// STRIDED-SAME: vector<4xi8> into vector<12xi8>
+func.func @insert_1D_to_2D(%arg0 : vector<4xi8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[1] : vector<4xi8> into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+
+// -----
+
+// vector.insert where the source isn't a scalar. Second case: 0D -> 2D.
+//
+//    ALL-LABEL: insert_OD_to_2D(
+//
+//      SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 10, 11] :
+// SHUFFLE-SAME: vector<12xi8>, vector<1xi8>
+//
+//      STRIDED: vector.insert_strided_slice {{.*}} {offsets = [9], strides = [1]}
+// STRIDED-SAME: vector<1xi8> into vector<12xi8>
+func.func @insert_OD_to_2D(%arg0 : vector<i8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2, 1] : vector<i8> into vector<3x4xi8>
+  return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Third case: 0D -> 1D.
+//
+//    ALL-LABEL: insert_OD_to_1D(
+//     ALL-SAME: %[[A0:.*]]: vector<i8>, %[[A1:.*]]: vector<4xi8>
+//          ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : vector<i8> into vector<4xi8>
+//          ALL: return %[[IN0]] : vector<4xi8>
+func.func @insert_OD_to_1D(%arg0 : vector<i8>, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[2] : vector<i8> into vector<4xi8>
+  return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Fourth case: 2D -> 4D.
+//
+//    ALL-LABEL: insert_2D_to_4D(
+//  ALL-COUNT-2: shape_cast
+//
+//      SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] :
+// SHUFFLE-SAME: vector<16xi8>, vector<4xi8>
+//
+//      STRIDED: vector.insert_strided_slice {{.*}} {offsets = [12], strides = [1]}
+// STRIDED-SAME: vector<4xi8> into vector<16xi8>
+func.func @insert_2D_to_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x2x2x2xi8>) -> vector<2x2x2x2xi8>
+{
+  %inserted = vector.insert %arg0, %arg1[1, 1] : vector<2x2xi8> into vector<2x2x2x2xi8>
+  return %inserted : vector<2x2x2x2xi8>
+}
+
+// -----
+
+// **--------------------------------------------------------**
+//              Tests of vector.extract
+// **--------------------------------------------------------**
+
+// vector.extract where the source is a 1D vector is always unchanged.
+//
+// ALL-LABEL: extract_scalar_from_1D(
+//  ALL-SAME: %[[A0:.*]]: vector<4xi8>
+//       ALL: %[[EX0:.*]] = vector.extract %[[A0]][2] : i8 from vector<4xi8>
+//       ALL: return %[[EX0]] : i8
+func.func @extract_scalar_from_1D(%arg0 : vector<4xi8>) -> i8
+{
+  %extracted = vector.extract %arg0[2] : i8 from vector<4xi8>
+  return %extracted : i8
+}
+
+// ALL-LABEL: extract_scalar_from_2D(
+//  ALL-SAME: %[[A0:.*]]: vector<3x4xi8>
+//       ALL: %[[SC0:.*]] = vector.shape_cast %[[A0]] : vector<3x4xi8> to vector<12xi8>
+//       ALL: %[[EX0:.*]] = vector.extract %[[SC0]][9] : i8 from vector<12xi8>
+//       ALL: return %[[EX0]] : i8
+func.func @extract_scalar_from_2D(%arg0 : vector<3x4xi8>) -> i8
+{
+  %extracted = vector.extract %arg0[2, 1] : i8 from vector<3x4xi8>
+  return %extracted : i8
+}
+
+// -----
+
+//    ALL-LABEL: extract_1D_from_2D(
+//
+//      SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [4, 5, 6, 7] : vector<12xi8>, vector<12xi8>
+//
+//      STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [4], sizes = [4], strides = [1]} : vector<12xi8> to vector<4xi8>
+func.func @extract_1D_from_2D(%arg0 : vector<3x4xi8>) -> vector<4xi8>
+{
+  %extracted = vector.extract %arg0[1] : vector<4xi8> from vector<3x4xi8>
+  return %extracted : vector<4xi8>
+}
+
+// -----
+
+//    ALL-LABEL: extract_2D_from_4D(
+//
+//      SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [10, 11] : vector<24xi8>, vector<24xi8>
+//
+//      STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [10], sizes = [2], strides = [1]} : vector<24xi8> to vector<2xi8>
+func.func @extract_2D_from_4D(%arg0 : vector<4x3x2x1xi8>) -> vector<2x1xi8> {
+  %extracted = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+  return %extracted : vector<2x1xi8>
+}
+
+// **--------------------------------------------------------**
+//              Tests of vector.insert_strided_slice
+// **--------------------------------------------------------**
+
+// -----
+
+// ALL-LABEL: insert_strided_slice_1D(
+//
+//   SHUFFLE: shuffle {{.*}} [0, 8, 9, 3, 4, 5, 6, 7]
+//
+//   STRIDED: insert_strided_slice {{.*}} {offsets = [1], strides = [1]}
+func.func @insert_strided_slice_1D(%arg0 : vector<2xi8>, %arg1 : vector<8xi8>) -> vector<8xi8>
+{
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xi8> into vector<8xi8>
+  return %inserted : vector<8xi8>
+}
+
+// -----
+
+//    ALL-LABEL: insert_strided_slice_4D_contiguous(
+//
+//      SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: 52, 53, 120, 121
+// SHUFFLE-SAME: 130, 131, 66, 67
+// SHUFFLE-SAME: vector<120xi8>, vector<12xi8>
+//
+//      STRIDED: vector.insert_strided_slice
+// STRIDED-SAME: {offsets = [54], strides = [1]}
+// STRIDED-SAME: vector<12xi8> into vector<120xi8>
+
+
+func.func @insert_strided_slice_4D_contiguous(%arg0 : vector<1x2x2x3xi8>, %arg1 : vector<5x4x2x3xi8>) -> vector<5x4x2x3xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [2, 1, 0, 0], strides = [1, 1, 1, 1]} : vector<1x2x2x3xi8> into vector<5x4x2x3xi8>
+  return %inserted : vector<5x4x2x3xi8>
+}
+
+// ----- 
+
+// This insert_strided_slice is not contiguous, and so it is always linearized to a 1D vector.shuffle
+
+// ALL-LABEL: insert_strided_slice_4D_noncontiguous(
+//       ALL: vector.shuffle
+//  ALL-SAME: [0, 1, 2, 8, 4, 5, 6, 9] : vector<8xi8>, vector<2xi8>
+
+func.func @insert_strided_slice_4D_noncontiguous(%arg0 : vector<1x2x1x1xi8>, %arg1 : vector<1x2x2x2xi8>) -> vector<1x2x2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 1, 1], strides = [1, 1, 1, 1]} : vector<1x2x1x1xi8> into vector<1x2x2x2xi8>
+  return %inserted : vector<1x2x2x2xi8>
+}
+
+// -----
+
+// **--------------------------------------------------------**
+//              Tests of vector.extract_strided_slice
+// **--------------------------------------------------------**
+
+//    ALL-LABEL: extract_strided_slice_1D(
+// 
+//      SHUFFLE: vector.shuffle {{.*}} [1, 2] 
+// 
+//      STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [1], sizes = [2], strides = [1]} 
+// STRIDED-SAME: vector<8xi8> to vector<2xi8>
+func.func @extract_strided_slice_1D(%arg0 : vector<8xi8>) -> vector<2xi8>
+{
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<8xi8> to vector<2xi8>
+  return %extracted : vector<2xi8>
+}
+
+// ----- 
+
+//    ALL-LABEL: extract_strided_slice_4D_contiguous_1(
+//
+//      SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [3, 4, 5]
+// SHUFFLE-SAME: vector<6xi8>, vector<6xi8>
+// 
+//     STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [3], sizes = [3], strides = [1]}
+// STRIDED-SAME: vector<6xi8> to vector<3xi8>
+func.func @extract_strided_slice_4D_contiguous_1(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x3x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 3, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x3x1xi8>
+  return %extracted : vector<1x1x3x1xi8>
+}
+
+// -----  
+
+//    ALL-LABEL: extract_strided_slice_4D_contiguous_2(
+// 
+//      SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [3, 4]
+// SHUFFLE-SAME: vector<6xi8>, vector<6xi8>
+//
+//      STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [3], sizes = [2], strides = [1]}
+// STRIDED-SAME: vector<6xi8> to vector<2xi8>
+func.func @extract_strided_slice_4D_contiguous_2(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+  return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_noncontiguous(
+//       ALL: vector.shuffle
+//  ALL-SAME: [0, 1, 3, 4]
+//  ALL-SAME: vector<6xi8>, vector<6xi8>
+func.func @extract_strided_slice_4D_noncontiguous(%arg0 : vector<2x1x3x1xi8>) -> vector<2x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0], sizes = [2, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<2x1x2x1xi8>
+  return %extracted : vector<2x1x2x1xi8>
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir
similarity index 99%
rename from mlir/test/Dialect/Vector/linearize.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize.mlir
index 9cbf319ffddb2..b7a5448c7dc22 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize/linearize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle  -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: test_linearize
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -297,7 +297,6 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
 // CHECK-LABEL: test_vector_insert
 // CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
 func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
-
   // CHECK-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
   // CHECK-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
   // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
diff --git a/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
new file mode 100644
index 0000000000000..342936b0f7eed
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-opt %s -split-input-file -test-rank-reduce-strided-slice-ops  -verify-diagnostics | FileCheck %s
+
+
+// **---------------------------------------------**
+//       Tests of vector.extract_strided_slice
+// **---------------------------------------------**
+
+// The 6 elements extracted are contiguous, so this can be expressed as a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @extract_strided_slice_2D_to_1D(
+//  CHECK-SAME:    %[[A:.*]]: vector<5x2xi8>) -> vector<3x2xi8> {
+//       CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<5x2xi8> to vector<10xi8>
+//       CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+//  CHECK-SAME:    {offsets = [2], sizes = [6], strides = [1]} : vector<10xi8> to vector<6xi8>
+//       CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<6xi8> to vector<3x2xi8>
+//       CHECK: return %[[CASTED]] : vector<3x2xi8>
+func.func @extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<3x2xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [3, 2], strides = [1, 1]} : vector<5x2xi8> to vector<3x2xi8>
+  return %extracted : vector<3x2xi8>
+}
+
+// -----
+
+// The 5 elements extracted are not contiguous, so this cannot be expressed as a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @negative_extract_strided_slice_2D_to_1D(
+//  CHECK-SAME:    %[[A:.*]]: vector<5x2xi8>) -> vector<5x1xi8> {
+//       CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[A]]
+//       CHECK: return %[[EXTRACTED]] : vector<5x1xi8>
+func.func @negative_extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<5x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [5, 1], strides = [1, 1]} : vector<5x2xi8> to vector<5x1xi8>
+  return %extracted : vector<5x1xi8>
+}
+
+// -----
+
+// The 2 elements extracted are contiguous, so this can be expressed as a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @extract_strided_slice_4D_leading_ones(
+//  CHECK-SAME:    %[[A:.*]]: vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+//       CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<2x1x3x1xi8> to vector<6xi8>
+//       CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+//  CHECK-SAME:    {offsets = [3], sizes = [2], strides = [1]} : vector<6xi8> to vector<2xi8>
+//       CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<2xi8> to vector<1x1x2x1xi8>
+//       CHECK: return %[[CASTED]] : vector<1x1x2x1xi8>
+
+func.func @extract_strided_slice_4D_leading_ones(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+  return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_4D_becomes_2D(
+//  CHECK-SAME:    %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> {
+//       CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<56x30xi8>
+//       CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+//  CHECK-SAME:    {offsets = [14, 5], sizes = [14, 10], strides = [1, 1]} : vector<56x30xi8> to vector<14x10xi8>
+//       CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<14x10xi8> to vector<2x7x2x5xi8>
+//      CHECK: return %[[CASTED]] : vector<2x7x2x5xi8>
+func.func @extract_strided_slice_4D_becomes_2D(%arg0 : vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [2, 0, 1, 0], sizes = [2, 7, 2, 5], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<2x7x2x5xi8>
+  return %extracted : vector<2x7x2x5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_4D_becomes_3D(
+//  CHECK-SAME:    %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> {
+  //       CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<8x42x5xi8>
+  //       CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+  //  CHECK-SAME:    {offsets = [0, 12, 1], sizes = [8, 12, 2], strides = [1, 1, 1]} : vector<8x42x5xi8> to vector<8x12x2xi8>
+  //       CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<8x12x2xi8> to vector<8x2x6x2xi8>
+  //       CHECK: return %[[CASTED]] : vector<8x2x6x2xi8>
+
+func.func @extract_strided_slice_4D_becomes_3D(%arg0 : vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 2, 0, 1], sizes = [8, 2, 6, 2], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<8x2x6x2xi8>
+  return %extracted : vector<8x2x6x2xi8>
+}
+
+// -----
+
+// **---------------------------------------------**
+//       Tests of vector.insert_strided_slice
+// **---------------------------------------------**
+
+
+// CHECK-LABEL: @negative_insert_strided_slice(
+//  CHECK-SAME:    %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<2x1xi8>) -> vector<2x2xi8> {
+//       CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[B]], %[[A]]
+//       CHECK: return %[[INSERTED]] : vector<2x2xi8>
+func.func @negative_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1xi8>) -> vector<2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi8> into vector<2x2xi8>
+  return %inserted : vector<2x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @positive_insert_strided_slice(
+//  CHECK-SAME:    %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<1x2xi8>) -> vector<2x2xi8> {
+//   CHECK-DAG: %[[SCA:.*]] = vector.shape_cast %[[A]] : vector<2x2xi8> to vector<4xi8>
+//   CHECK-DAG: %[[SCB:.*]] = vector.shape_cast %[[B]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[SCB]], %[[SCA]]
+//  CHECK-SAME:    {offsets = [0], strides = [1]} : vector<2xi8> into vector<4xi8>
+//       CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<4xi8> to vector<2x2xi8>
+//       CHECK: return %[[CASTED]] : vector<2x2xi8>
+
+func.func @positive_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<1x2xi8>) -> vector<2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi8> into vector<2x2xi8>
+  return %inserted : vector<2x2xi8>
+}
+
+// ----- 
+
+func.func @test_extract_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>) -> vector<1x2x1x2xi8> {
+  %0 = vector.extract_strided_slice %arg0
+    {offsets = [1, 0, 1, 0],
+       sizes = [1, 2, 1, 2],
+     strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8>
+  return %0 : vector<1x2x1x2xi8>
+}
+
+// -----
+
+// Equivalent to the above, but now with a vector.insert_strided_slice.
+
+
+func.func @test_insert_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>, %arg1 : vector<1x2x1x2xi8>) -> vector<2x2x2x2xi8> {
+  %0 = vector.insert_strided_slice %arg1, %arg0
+    {offsets = [1, 0, 1, 0],
+     strides = [1, 1, 1, 1]} : vector<1x2x1x2xi8> into vector<2x2x2x2xi8>
+  return %0 : vector<2x2x2x2xi8>
+}
+
+
diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
index e16937029ac0e..1ce069599af43 100644
--- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRVectorTestPasses
   TestVectorTransforms.cpp
+  TestVectorLinearize.cpp
 
   EXCLUDE_FROM_LIBMLIR
   )
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
new file mode 100644
index 0000000000000..21d6356bf04fa
--- /dev/null
+++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
@@ -0,0 +1,248 @@
+//===- TestVectorLinearize.cpp - Test Vector linearization ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <optional>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math//IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  TestVectorLinearize() = default;
+
+  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect, arith::ArithDialect, math::MathDialect>();
+  }
+  Option<InsertExtractLinearizePreference> preference{
+      *this, "preference",
+      llvm::cl::desc("Corresponds 1:1 with InsertExtractLinearizePreference"),
+      llvm::cl::values(
+          clEnumValN(
+              static_cast<int>(InsertExtractLinearizePreference::Strided),
+              "Strided", ""),
+          clEnumValN(
+              static_cast<int>(InsertExtractLinearizePreference::Shuffle),
+              "Shuffle", ""))};
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    initializeForVectorLinearize(converter);
+    populateForFullVectorLinearize(converter, target, patterns, preference);
+
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+struct TestRankReduceStridedSliceOps final
+    : public PassWrapper<TestRankReduceStridedSliceOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRankReduceStridedSliceOps)
+
+  TestRankReduceStridedSliceOps() = default;
+
+  StringRef getArgument() const override {
+    return "test-rank-reduce-strided-slice-ops";
+  }
+  StringRef getDescription() const override {
+    return "Test pass for rank-reducing strided slice ops.";
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+
+    VectorLinearizePatterns()
+        .enableAll(false)
+        .enable(LinearizePattern::RankReduceInsertStridedSlice)
+        .enable(LinearizePattern::RankReduceExtractStridedSlice)
+        .addToPatternSet(typeConverter, patterns);
+
+    typeConverter.addConversion(
+        [](Type t) -> std::optional<Type> { return t; });
+
+    typeConverter.addSourceMaterialization(
+        [](OpBuilder &builder, Type type, ValueRange inputs,
+           Location loc) -> Value { return inputs.front(); });
+
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) -> std::optional<bool> {
+          if (auto insertOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+            return !getCollapsedInsertStridedSliceShape(insertOp).has_value();
+          }
+          if (auto extractOp = dyn_cast<vector::ExtractStridedSliceOp>(op)) {
+            return !getCollapsedExtractStridedSliceShape(extractOp).has_value();
+          }
+          return true;
+        });
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-constrained-vector-linearize";
+  }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "on inner-most dimension's bit width. If the inner-most dimension "
+           "exceeds a threshold, the op is not linearized.";
+  }
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    populateWithBitWidthConstraints(typeConverter, target, patterns,
+                                    targetVectorBitwidth);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+
+private:
+  /// If `type` is VectorType with trailing dimension of (bit) size greater than
+  /// or equal to `targetBitWidth`, its defining op is considered legal.
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                              unsigned targetBitWidth) {
+
+    VectorType vecType = dyn_cast<VectorType>(type);
+
+    // Not linearizable for reasons other than what this function checks.
+    if (!vecType || vecType.getRank() == 0)
+      return false;
+
+    // The width of the type 'index' is unbounded (and therefore potentially
+    // above the target width).
+    if (vecType.getElementType().isIndex())
+      return true;
+
+    unsigned finalDimSize = vecType.getShape().back();
+    unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+    unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+    return trailingVecDimBitWidth >= targetBitWidth;
+  }
+
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                              unsigned targetBitWidth) {
+    // Check on bitwidths.
+    SmallVector<std::pair<Type, unsigned>> toCheck =
+        getTypeBitWidthBoundPairs(op, targetBitWidth);
+    return std::any_of(toCheck.begin(), toCheck.end(),
+                       [&](std::pair<Type, unsigned> typeWidth) {
+                         return isNotLinearizableBecauseLargeInnerDimension(
+                             typeWidth.first, typeWidth.second);
+                       });
+  }
+
+  static void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+                                              ConversionTarget &target,
+                                              RewritePatternSet &patterns,
+                                              unsigned targetBitWidth) {
+
+    initializeForVectorLinearize(typeConverter);
+    populateForFullVectorLinearize(typeConverter, target, patterns,
+                                   InsertExtractLinearizePreference::Shuffle);
+
+    // Extend the set of legal ops to include those with large inner-most
+    // dimensions on selected operands/results.
+    target.markUnknownOpDynamicallyLegal(
+        [=](Operation *op) -> std::optional<bool> {
+          if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+            return true;
+          }
+          return {};
+        });
+  }
+
+  /// Get the set of operand/result types to check for sufficiently
+  /// small inner-most dimension size.
+  static SmallVector<std::pair<Type, unsigned>>
+  getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
+
+    if (auto insertOp = dyn_cast<InsertOp>(op)) {
+      unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                       ? targetBitWidth + 1
+                       : targetBitWidth;
+      return {{insertOp.getValueToStoreType(), w}};
+    }
+
+    auto resultTypes = op->getResultTypes();
+    SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+    resultsWithBitWidth.reserve(resultTypes.size());
+    for (Type type : resultTypes) {
+      resultsWithBitWidth.push_back({type, targetBitWidth});
+    }
+    return resultsWithBitWidth;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+extern void registerTestVectorLinearize() {
+  PassRegistration<TestVectorLinearize>();
+  PassRegistration<TestVectorBitWidthLinearize>();
+  PassRegistration<TestRankReduceStridedSliceOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..5c75d32c22236 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -837,160 +836,6 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-/// Get the set of operand/result types to check for sufficiently
-/// small inner-most dimension size.
-static SmallVector<std::pair<Type, unsigned>>
-getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
-
-  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
-    unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
-                     ? targetBitWidth + 1
-                     : targetBitWidth;
-    return {{insertOp.getValueToStoreType(), w}};
-  }
-
-  auto resultTypes = op->getResultTypes();
-  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
-  resultsWithBitWidth.reserve(resultTypes.size());
-  for (Type type : resultTypes) {
-    resultsWithBitWidth.push_back({type, targetBitWidth});
-  }
-  return resultsWithBitWidth;
-}
-
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Type type,
-                                            unsigned targetBitWidth) {
-
-  VectorType vecType = dyn_cast<VectorType>(type);
-
-  // Not linearizable for reasons other than what this function checks.
-  if (!vecType || vecType.getRank() == 0)
-    return false;
-
-  // The width of the type 'index' is unbounded (and therefore potentially above
-  // the target width).
-  if (vecType.getElementType().isIndex())
-    return true;
-
-  unsigned finalDimSize = vecType.getShape().back();
-  unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
-  unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
-  return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Operation *op,
-                                            unsigned targetBitWidth) {
-  // Check on bitwidths.
-  SmallVector<std::pair<Type, unsigned>> toCheck =
-      getTypeBitWidthBoundPairs(op, targetBitWidth);
-  return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
-    return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
-                                                       typeWidth.second);
-  });
-}
-
-void populateWithBitWidthConstraints(TypeConverter &typeConverter,
-                                     ConversionTarget &target,
-                                     unsigned targetBitWidth) {
-
-  // The general purpose definition of what ops are legal must come first.
-  populateForVectorLinearize(typeConverter, target);
-
-  // Extend the set of legal ops to include those with large inner-most
-  // dimensions on selected operands/results.
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
-          return true;
-        }
-        return {};
-      });
-}
-
-struct TestVectorBitWidthLinearize final
-    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
-
-  TestVectorBitWidthLinearize() = default;
-  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
-      : PassWrapper(pass) {}
-
-  StringRef getArgument() const override {
-    return "test-bit-width-constrained-vector-linearize";
-  }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
-           "in inner-most dimension's bit width.";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  Option<unsigned> targetVectorBitwidth{
-      *this, "target-vector-bitwidth",
-      llvm::cl::desc(
-          "Minimum vector bitwidth to enable the flattening transformation"),
-      llvm::cl::init(std::numeric_limits<unsigned>::max())};
-  void runOnOperation() override {
-    auto *context = &getContext();
-
-    TypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    populateWithBitWidthConstraints(typeConverter, target,
-                                    targetVectorBitwidth);
-
-    vector::populateVectorLinearizeBasePatterns(typeConverter, target,
-                                                patterns);
-
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
-                                                          patterns);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-struct TestVectorLinearize final
-    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
-
-  TestVectorLinearize() = default;
-
-  StringRef getArgument() const override { return "test-vector-linearize"; }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect, arith::ArithDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext &context = getContext();
-    TypeConverter converter;
-    RewritePatternSet patterns(&context);
-    ConversionTarget target(context);
-
-    vector::populateForVectorLinearize(converter, target);
-
-    vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
-                                                          patterns);
-    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-        converter, patterns, target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
 struct TestEliminateVectorMasks
     : public PassWrapper<TestEliminateVectorMasks,
                          OperationPass<func::FuncOp>> {
@@ -1062,10 +907,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
 
-  PassRegistration<TestVectorLinearize>();
-
-  PassRegistration<TestVectorBitWidthLinearize>();
-
   PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2e08ae6f37980..f52f36107e301 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -155,6 +155,7 @@ void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
+void registerTestVectorLinearize();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
@@ -300,6 +301,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestVectorLinearize();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestVulkanRunnerPipeline();
   mlir::test::registerTestWrittenToPass();



More information about the Mlir-commits mailing list