[Mlir-commits] [mlir] [mlir][vector] Refactor vector linearization patterns (PR #142685)

James Newling llvmlistbot at llvm.org
Tue Jun 3 15:53:18 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/142685

>From 2e6562533eb69815a496189ce6854e6440adcdc4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 15:13:44 -0700
Subject: [PATCH 1/2] nfc changes to linearization

---
 .../Vector/Transforms/VectorLinearize.h       | 252 +++++++++
 .../Vector/Transforms/VectorRewritePatterns.h |  33 --
 .../Vector/Transforms/VectorLinearize.cpp     | 508 +++++++++---------
 .../linearize-subject-to-bitwidth.mlir        |   0
 .../Vector/{ => linearize}/linearize.mlir     |   0
 mlir/test/lib/Dialect/Vector/CMakeLists.txt   |   1 +
 .../Dialect/Vector/TestVectorLinearize.cpp    | 185 +++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 159 ------
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 9 files changed, 696 insertions(+), 444 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
 rename mlir/test/Dialect/Vector/{ => linearize}/linearize-subject-to-bitwidth.mlir (100%)
 rename mlir/test/Dialect/Vector/{ => linearize}/linearize.mlir (100%)
 create mode 100644 mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..cd62de640d088
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,252 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` with source and target materialization logic
+/// using shape_casts to/from 1D vectors.
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// Initialize `conversionTarget` and `patterns` for linearization. Here
+/// linearization means converting a single operation with 1+ vector
+/// operand/result of rank>1, into a new single operation whose vector operands
+/// and results are all of rank<=1.
+///
+/// This function initializes `conversionTarget` with the set of operations that
+/// are illegal and consequently must be converted to a linearized form. It
+/// also populates the set of patterns that can be run to convert illegal
+/// operations, and what priority/benefit they have.
+///
+/// Note: the set of legal operations can be extended by a user if, for example,
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to make it easy to reuse generic structural type
+/// conversions for linearizing scf/cf/func operations.
+void populateForFullVectorLinearize(const TypeConverter &,
+                                    ConversionTarget &conversionTarget,
+                                    RewritePatternSet &patterns);
+
+/// The set of patterns available for linearization.
+enum class LinearizePattern {
+
+  /// This pattern converts a constant (or poison) vector of rank>1 into a
+  /// 1D vector followed by a shape_cast.
+  ///
+  /// BEFORE
+  /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+  LinearizeConstantLike = 0,
+
+  /// BEFORE
+  /// %2 = math.sin %arg0 : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+  /// %1 = math.sin %0 : vector<4xf32>
+  /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+  LinearizeVectorizable,
+
+  /// BEFORE
+  /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+  /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+  /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+  LinearizeVectorBitCast,
+
+  /// This pattern currently only supports 2D masks with a unit outer
+  /// dimension.
+  ///
+  /// BEFORE
+  /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+  ///
+  /// AFTER
+  /// [...]
+  /// %mask_1d = vector.create_mask %mul : vector<4xi1>
+  /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+  ///
+  /// where `%mul` is a function of `%arg0` and `%arg1`.
+  LinearizeVectorCreateMask,
+
+  /// This pattern converts the ShuffleOp that works on nD (n > 1)
+  /// vectors to a ShuffleOp that works on linearized vectors.
+  ///
+  /// BEFORE
+  /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+  ///
+  /// AFTER
+  /// %v1_1d = vector.shape_cast %v1_3d : [...]
+  /// %v2_1d = vector.shape_cast %v2_3d : [...]
+  /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+  /// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+  ///
+  /// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+  LinearizeVectorShuffle,
+
+  /// BEFORE
+  /// %1 = vector.splat %value : vector<4x4xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.splat %value : vector<16xf32>
+  /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+  LinearizeVectorSplat,
+
+  /// BEFORE
+  /// %extract = vector.extract %src [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position` of original extract.
+  VectorExtractToRankOneShuffle,
+
+  /// This pattern converts a vector.extract_strided_slice operation into a
+  /// vector.shuffle operation that has a rank-1 (linearized) operand and
+  /// result.
+  ///
+  /// BEFORE
+  /// %out_nd = vector.extract_strided_slice %source_nd
+  ///         { offsets = [..], strides = [..], sizes = [..] }
+  ///
+  /// AFTER
+  /// %source_1d = vector.shape_cast %source_nd [...]
+  /// %out_1d    = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+  /// %out_nd    = vector.shape_cast %out_1d [...]
+  ///
+  /// `shuffle_indices_1d` is computed using the offsets and sizes of the
+  /// original vector.extract_strided_slice operation.
+  VectorExtractStridedSliceToRankOneShuffle,
+
+  /// BEFORE
+  /// %insert = vector.insert %src %dst [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %dst_1d = vector.shape_cast %dst : [...]
+  /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position`.
+  VectorInsertToRankOneShuffle,
+
+  /// This pattern converts a vector.insert_strided_slice operation into a
+  /// vector.shuffle operation that has rank-1 (linearized) operands and result.
+  ///
+  /// BEFORE
+  /// %0 = vector.insert_strided_slice %to_store, %into
+  ///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
+  ///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
+  /// AFTER
+  /// %to_store_1d
+  ///          = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+  /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+  /// %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+  /// %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+  ///
+  /// where shuffle_indices_1d in this case is
+  ///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+  ///                        ^^^^^^^^^^^^^^
+  ///                          to_store_1d
+  VectorInsertStridedSliceToRankOneShuffle,
+
+  /// The number of patterns in this enum.
+  N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority.
+struct VectorLinearizePatterns {
+  /// By default all patterns are enabled and have benefit 1.
+  VectorLinearizePatterns() {
+    enabled.fill(true);
+    benefits.fill(PatternBenefit(1));
+  }
+
+  /// Add the patterns enabled for the conversion to `patterns`.
+  void addToPatternSet(const TypeConverter &,
+                       RewritePatternSet &patterns) const;
+
+  /// Enable (or, if `e` is false, disable) the pattern `id`.
+  VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+    enabled[static_cast<unsigned>(id)] = e;
+    return *this;
+  }
+
+  VectorLinearizePatterns &enableAll(bool e = true) {
+    enabled.fill(e);
+    return *this;
+  }
+
+  bool isEnabled(LinearizePattern id) const {
+    return enabled[static_cast<unsigned>(id)];
+  }
+
+  PatternBenefit getBenefit(LinearizePattern id) const {
+    return benefits[static_cast<unsigned>(id)];
+  }
+
+  VectorLinearizePatterns &setBenefit(LinearizePattern id,
+                                      PatternBenefit benefit) {
+    getBenefitRef(id) = benefit;
+    return *this;
+  }
+
+  /// Raise the benefit of pattern `id` by `inc` (1 when `inc` is omitted).
+  VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+                                            unsigned inc = 1) {
+    getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+    return *this;
+  }
+
+private:
+  std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+  std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+      benefits;
+
+  PatternBenefit &getBenefitRef(LinearizePattern id) {
+    unsigned idInt = static_cast<unsigned>(id);
+    assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+           "invalid linearization pattern id");
+    return benefits[idInt];
+  }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                                     ArrayRef<int64_t> large,
+                                                     ArrayRef<int64_t> offsets);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
-                                ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
-                                         const ConversionTarget &,
-                                         RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
-                                                   const ConversionTarget &,
-                                                   RewritePatternSet &patterns);
-
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..2367e6c99a5f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Operation.h"
@@ -52,7 +53,7 @@ struct LinearizeConstantLike final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
   LinearizeConstantLike(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
       : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -94,7 +95,7 @@ struct LinearizeVectorizable final
 
 public:
   LinearizeVectorizable(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
       : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,17 +110,7 @@ struct LinearizeVectorizable final
   }
 };
 
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
-  static_assert(
-      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
-          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
-      "expected vector.extract_strided_slice or vector.insert_strided_slice");
-  ArrayAttr strides = op.getStrides();
-  return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
+/// Convert an array of attributes into a vector of integers.
 static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
   if (!attrs)
     return failure();
@@ -135,89 +126,12 @@ static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
   return ints;
 }
 
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
-    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
-    ArrayRef<int64_t> offsets) {
-
-  // Example of alignment between, `large`, `small` and `offsets`:
-  //    large  =  4, 5, 6, 7, 8
-  //    small  =     1, 6, 7, 8
-  //  offsets  =  2, 3, 0
-  //
-  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
-  assert((large.size() >= small.size()) &&
-         "rank of 'large' cannot be lower than rank of 'small'");
-  assert((large.size() >= offsets.size()) &&
-         "rank of 'large' cannot be lower than the number of offsets");
-  unsigned delta = large.size() - small.size();
-  unsigned nOffsets = offsets.size();
-  auto getSmall = [&](int64_t i) -> int64_t {
-    return i >= delta ? small[i - delta] : 1;
-  };
-  auto getOffset = [&](int64_t i) -> int64_t {
-    return i < nOffsets ? offsets[i] : 0;
-  };
-
-  // Using 2 vectors of indices, at each iteration populate the updated set of
-  // indices based on the old set of indices, and the size of the small vector
-  // in the current iteration.
-  SmallVector<int64_t> indices{0};
-  int64_t stride = 1;
-  for (int i = large.size() - 1; i >= 0; --i) {
-    int64_t currentSize = indices.size();
-    int64_t smallSize = getSmall(i);
-    int64_t nextSize = currentSize * smallSize;
-    SmallVector<int64_t> nextIndices(nextSize);
-    int64_t *base = nextIndices.begin();
-    int64_t offset = getOffset(i) * stride;
-    for (int j = 0; j < smallSize; ++j) {
-      for (int k = 0; k < currentSize; ++k) {
-        base[k] = indices[k] + offset;
-      }
-      offset += stride;
-      base += currentSize;
-    }
-    stride *= large[i];
-    indices = std::move(nextIndices);
-  }
-  return indices;
-}
-
-/// This pattern converts a vector.extract_strided_slice operation into a
-/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
-///
-/// For example, the following:
-///
-/// ```
-///   vector.extract_strided_slice %source
-///         { offsets = [..], strides = [..], sizes = [..] }
-/// ```
-///
-/// is converted to :
-/// ```
-///   %source_1d = vector.shape_cast %source
-///   %out_1d    = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd    = vector.shape_cast %out_1d
-/// ```
-///
-/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
-/// vector.extract_strided_slice operation.
-struct LinearizeVectorExtractStridedSlice final
-    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+struct VectorExtractStridedSliceToRankOneShuffle final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
-                                     MLIRContext *context,
-                                     PatternBenefit benefit = 1)
+  VectorExtractStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+                                            MLIRContext *context,
+                                            PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -231,7 +145,7 @@ struct LinearizeVectorExtractStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for extract_strided_slice allows non-1 strides).
-    if (!stridesAllOne(extractStridedSliceOp)) {
+    if (extractStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           extractStridedSliceOp,
           "extract_strided_slice with strides != 1 not supported");
@@ -249,7 +163,7 @@ struct LinearizeVectorExtractStridedSlice final
 
     ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
 
-    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+    SmallVector<int64_t> indices = vector::getStridedSliceInsertionIndices(
         outputShape, inputShape, offsets.value());
 
     Value srcVector = adaptor.getVector();
@@ -259,36 +173,24 @@ struct LinearizeVectorExtractStridedSlice final
   }
 };
 
-/// This pattern converts a vector.insert_strided_slice operation into a
-/// vector.shuffle operation that has rank-1 (linearized) operands and result.
-///
-/// For example, the following:
-/// ```
-///  %0 = vector.insert_strided_slice %to_store, %into
-///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
-///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
-/// ```
-///
-/// is converted to
-/// ```
-///  %to_store_1d
-///           = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
-///  %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
-///  %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
-///  %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
-/// ```
-///
-/// where shuffle_indices_1d in this case is
-///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
-///                        ^^^^^^^^^^^^^^
-///                          to_store_1d
-///
-struct LinearizeVectorInsertStridedSlice final
-    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+/// Return `v` as a rank-1 vector: a rank-0 `v` is shape_cast to vector<1xT>.
+static Value asRankOne(ConversionPatternRewriter &rewriter, Value v) {
+  auto vecTy = dyn_cast<VectorType>(v.getType());
+  assert(vecTy && "expected vector type");
+  assert(vecTy.getRank() <= 1 && "expected rank-0 or rank-1 type");
+  if (vecTy.getRank() == 1)
+    return v;
+  // Only the rank-0 case reaches here; give it a single unit dimension.
+  auto oneElementTy = VectorType::get({1}, vecTy.getElementType());
+  return rewriter.create<vector::ShapeCastOp>(v.getLoc(), oneElementTy, v);
+}
+
+struct VectorInsertStridedSliceToRankOneShuffle final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
-                                    MLIRContext *context,
-                                    PatternBenefit benefit = 1)
+  VectorInsertStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+                                           MLIRContext *context,
+                                           PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -298,7 +200,7 @@ struct LinearizeVectorInsertStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for insert_strided_slice allows non-1 strides).
-    if (!stridesAllOne(insertStridedSliceOp)) {
+    if (insertStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           insertStridedSliceOp,
           "insert_strided_slice with strides != 1 not supported");
@@ -317,7 +219,7 @@ struct LinearizeVectorInsertStridedSlice final
       return rewriter.notifyMatchFailure(insertStridedSliceOp,
                                          "failed to get integer offsets");
     }
-    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+    SmallVector<int64_t> sliceIndices = vector::getStridedSliceInsertionIndices(
         inputShape, outputShape, offsets.value());
 
     SmallVector<int64_t> indices(nOutputElements);
@@ -326,7 +228,7 @@ struct LinearizeVectorInsertStridedSlice final
       indices[sliceIndex] = index + nOutputElements;
     }
 
-    Value flatToStore = adaptor.getValueToStore();
+    Value flatToStore = asRankOne(rewriter, adaptor.getValueToStore());
     Value flatDest = adaptor.getDest();
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
                                                    flatDest.getType(), flatDest,
@@ -335,22 +237,11 @@ struct LinearizeVectorInsertStridedSlice final
   }
 };
 
-/// This pattern converts the ShuffleOp that works on nD (n > 1)
-/// vectors to a ShuffleOp that works on linearized vectors.
-/// Following,
-///   vector.shuffle %v1, %v2 [ shuffle_indices ]
-/// is converted to :
-///   %v1_1d = vector.shape_cast %v1
-///   %v2_1d = vector.shape_cast %v2
-///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
-/// of the original shuffle operation.
 struct LinearizeVectorShuffle final
     : public OpConversionPattern<vector::ShuffleOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+                         MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -360,8 +251,8 @@ struct LinearizeVectorShuffle final
         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
     assert(dstType && "vector type destination expected.");
 
-    Value vec1 = adaptor.getV1();
-    Value vec2 = adaptor.getV2();
+    Value vec1 = asRankOne(rewriter, adaptor.getV1());
+    Value vec2 = asRankOne(rewriter, adaptor.getV2());
     int shuffleSliceLen = 1;
     int rank = shuffleOp.getV1().getType().getRank();
 
@@ -395,20 +286,11 @@ struct LinearizeVectorShuffle final
   }
 };
 
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-///   vector.extract %source [ position ]
-/// is converted to :
-///   %source_1d = vector.shape_cast %source
-///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
-struct LinearizeVectorExtract final
+struct VectorExtractToRankOneShuffle final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtract(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+  VectorExtractToRankOneShuffle(const TypeConverter &typeConverter,
+                                MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
@@ -436,30 +318,21 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
+    Value v0 = asRankOne(rewriter, adaptor.getVector());
     llvm::SmallVector<int64_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, v0, v0,
+                                                   indices);
 
     return success();
   }
 };
 
-/// This pattern converts the InsertOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-///   vector.insert %source %destination [ position ]
-/// is converted to :
-///   %source_1d = vector.shape_cast %source
-///   %destination_1d = vector.shape_cast %destination
-///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
-///   ] %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original insert.
-struct LinearizeVectorInsert final
+struct VectorInsertToRankOneShuffle final
     : public OpConversionPattern<vector::InsertOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorInsert(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+  VectorInsertToRankOneShuffle(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
@@ -508,25 +381,18 @@ struct LinearizeVectorInsert final
                                            // [offset+srcNumElements, end)
 
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+        insertOp, dstTy, adaptor.getDest(),
+        asRankOne(rewriter, adaptor.getValueToStore()), indices);
 
     return success();
   }
 };
 
-/// This pattern converts the BitCastOp that works on nD (n > 1)
-/// vectors to a BitCastOp that works on linearized vectors.
-/// Following,
-///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
-/// is converted to :
-///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
-///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
-///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
 struct LinearizeVectorBitCast final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(const TypeConverter &typeConverter,
-                         MLIRContext *context, PatternBenefit benefit = 1)
+                         MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
@@ -535,22 +401,16 @@ struct LinearizeVectorBitCast final
     assert(resType && "expected 1-D vector type");
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
-    return mlir::success();
+    return success();
   }
 };
 
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-///   vector.splat %value : vector<4x4xf32>
-/// is converted to:
-///   %out_1d = vector.splat %value : vector<16xf32>
-///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
 struct LinearizeVectorSplat final
     : public OpConversionPattern<vector::SplatOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
+                       PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -581,7 +441,7 @@ struct LinearizeVectorCreateMask final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorCreateMask(const TypeConverter &typeConverter,
-                            MLIRContext *context, PatternBenefit benefit = 1)
+                            MLIRContext *context, PatternBenefit benefit)
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
@@ -607,17 +467,16 @@ struct LinearizeVectorCreateMask final
     // The result of the comparison is then multiplied with
     // the second operand of create_mask to get the 1D mask.
     auto firstOperand = adaptor.getOperands().front();
-    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
-    auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
-    auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
         loc, rewriter.getIndexType(), isNonZero);
     auto secondOperand = adaptor.getOperands().back();
-    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
+    auto maskSize = rewriter.createOrFold<arith::AndIOp>(
         loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
 
-    auto newMask =
-        rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+    auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
     rewriter.replaceOp(createMaskOp, newMask);
     return success();
   }
@@ -625,104 +484,249 @@ struct LinearizeVectorCreateMask final
 
 } // namespace
 
-/// This method defines the set of operations that are linearizable, and hence
-/// that are considered illegal for the conversion target.
-static bool isLinearizable(Operation *op) {
-
-  // Only ops that are in the vector dialect, are ConstantLike, or
-  // are Vectorizable might be linearized currently.
-  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
-  StringRef opDialect = op->getDialect()->getNamespace();
-  bool supported = (opDialect == vectorDialect) ||
-                   op->hasTrait<OpTrait::ConstantLike>() ||
-                   op->hasTrait<OpTrait::Vectorizable>();
-  if (!supported)
-    return false;
-
+/// Return true if `op` is an insert, extract, insert_strided_slice, or
+/// extract_strided_slice operation that operates on scalable vectors.
+/// Otherwise return false.
+static bool isScalableExtractOrInsertOrStrided(Operation *op) {
   return TypeSwitch<Operation *, bool>(op)
-      // As type legalization is done with vector.shape_cast, shape_cast
-      // itself cannot be linearized (will create new shape_casts to linearize
-      // ad infinitum).
-      .Case<vector::ShapeCastOp>([&](auto) { return false; })
-      // The operations
-      // - vector.extract_strided_slice
-      // - vector.extract
-      // - vector.insert_strided_slice
-      // - vector.insert
-      // are linearized to a rank-1 vector.shuffle by the current patterns.
-      // vector.shuffle only supports fixed size vectors, so it is impossible to
-      // use this approach to linearize these ops if they operate on scalable
-      // vectors.
       .Case<vector::ExtractStridedSliceOp>(
           [&](vector::ExtractStridedSliceOp extractOp) {
-            return !extractOp.getType().isScalable();
+            return extractOp.getType().isScalable();
           })
       .Case<vector::InsertStridedSliceOp>(
           [&](vector::InsertStridedSliceOp insertOp) {
-            return !insertOp.getType().isScalable();
+            return insertOp.getType().isScalable();
           })
       .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
-        return !insertOp.getType().isScalable();
+        return insertOp.getType().isScalable();
       })
       .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
-        return !extractOp.getSourceVectorType().isScalable();
+        return extractOp.getSourceVectorType().isScalable();
       })
-      .Default([&](auto) { return true; });
+      .Default([&](auto) { return false; });
 }
 
-void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
-                                              ConversionTarget &target) {
+SmallVector<int64_t>
+vector::getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                        ArrayRef<int64_t> large,
+                                        ArrayRef<int64_t> offsets) {
+
+  // Example of alignment between `large`, `small` and `offsets`:
+  //    large  =  4, 5, 6, 7, 8
+  //    small  =     1, 6, 7, 8
+  //  offsets  =  2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+  assert((large.size() >= small.size()) &&
+         "rank of 'large' cannot be lower than rank of 'small'");
+  assert((large.size() >= offsets.size()) &&
+         "rank of 'large' cannot be lower than the number of offsets");
+  unsigned delta = large.size() - small.size();
+  unsigned nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) -> int64_t {
+    return i >= delta ? small[i - delta] : 1;
+  };
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  // Using 2 vectors of indices, at each iteration populate the updated set of
+  // indices based on the old set of indices, and the size of the small vector
+  // in the current iteration.
+  SmallVector<int64_t> indices{0};
+  int64_t stride = 1;
+  for (int i = large.size() - 1; i >= 0; --i) {
+    int64_t currentSize = indices.size();
+    int64_t smallSize = getSmall(i);
+    int64_t nextSize = currentSize * smallSize;
+    SmallVector<int64_t> nextIndices(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int j = 0; j < smallSize; ++j) {
+      for (int k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    indices = std::move(nextIndices);
+  }
+  return indices;
+}
+
+void vector::initializeForVectorLinearize(TypeConverter &typeConverter) {
 
   auto convertType = [](Type type) -> std::optional<Type> {
     VectorType vectorType = dyn_cast<VectorType>(type);
-    if (!vectorType || !isLinearizableVector(vectorType))
+
+    if (!vectorType || !vector::isLinearizableVector(vectorType))
       return type;
 
     VectorType linearizedType =
         VectorType::get(vectorType.getNumElements(),
                         vectorType.getElementType(), vectorType.isScalable());
+
     return linearizedType;
   };
   typeConverter.addConversion(convertType);
 
   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                             Location loc) -> Value {
-    if (inputs.size() != 1)
+    if (inputs.size() != 1) {
       return nullptr;
-
+    }
     Value value = inputs.front();
-    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
+    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType())) {
       return nullptr;
-
+    }
     return builder.create<vector::ShapeCastOp>(loc, type, value);
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+}
+
+void vector::populateForFullVectorLinearize(const TypeConverter &typeConverter,
+                                            ConversionTarget &target,
+                                            RewritePatternSet &patterns) {
 
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if (!isLinearizable(op))
+        // Only ops that are in the vector dialect, are ConstantLike, or
+        // are Vectorizable might be linearized currently.
+        StringLiteral vectorDialect =
+            vector::VectorDialect::getDialectNamespace();
+        StringRef opDialect = op->getDialect()->getNamespace();
+        bool supported = (opDialect == vectorDialect) ||
+                         op->hasTrait<OpTrait::ConstantLike>() ||
+                         op->hasTrait<OpTrait::Vectorizable>();
+        if (!supported)
+          return true;
+
+        // As type legalization is done with vector.shape_cast, shape_cast
+        // itself cannot be linearized (doing so would create new shape_casts to
+        // linearize ad infinitum).
+        if (isa<vector::ShapeCastOp>(op))
+          return true;
+
+        // The operations extract_strided_slice, extract, insert_strided_slice,
+        // and insert are linearized to rank-1 operations that do not fully
+        // support scalable vectors, so it is not generally possible to
+        // linearize these ops if they operate on scalable vectors.
+        if (isScalableExtractOrInsertOrStrided(op))
           return true;
+
         // This will return true if, for all operand and result types `t`,
         // convertType(t) = t. This is true if there are no rank>=2 vectors.
         return typeConverter.isLegal(op);
       });
-}
 
-void mlir::vector::populateVectorLinearizeBasePatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+  VectorLinearizePatterns linearizePatterns;
+
+  // Mark extract_strided_slice, insert_strided_slice, extract with source
+  // rank > 1, and insert with result rank > 1 as illegal, as they must be
+  // converted to shuffle or rank-1 extract/insert.
+  //
+  // Note that the order of the calls to `markUnknownOpDynamicallyLegal`
+  // is important: the legality rule added here takes precedence over the
+  // generic one preceding it which marked these ops as legal.
+  target.markUnknownOpDynamicallyLegal(
+      [](Operation *op) -> std::optional<bool> {
+        bool isStrided =
+            isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp>(
+                op);
+
+        bool isHighRankExtractOrInsert = [&]() {
+          if (auto extractOp = dyn_cast<vector::ExtractOp>(op)) {
+            return extractOp.getSourceVectorType().getRank() > 1;
+          }
+          if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+            return insertOp.getType().getRank() > 1;
+          }
+          return false;
+        }();
+
+        bool isScalable = isScalableExtractOrInsertOrStrided(op);
+
+        if ((isStrided || isHighRankExtractOrInsert) && !isScalable) {
+          return false;
+        }
+        return std::nullopt;
+      });
+
+  // Ensure that the benefit of patterns targeting shuffle is higher than
+  // the benefit of patterns targeting rank-1 strided slice operations. This
+  // will ensure that patterns for converting to rank-1 shuffle are run first.
+  linearizePatterns
+      .incrementBenefit(
+          LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)
+      .incrementBenefit(
+          LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)
+      .incrementBenefit(LinearizePattern::VectorExtractToRankOneShuffle)
+      .incrementBenefit(LinearizePattern::VectorInsertToRankOneShuffle);
+
+  linearizePatterns.addToPatternSet(typeConverter, patterns);
 }
 
-void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
-               LinearizeVectorInsertStridedSlice>(typeConverter,
-                                                  patterns.getContext());
+void vector::VectorLinearizePatterns::addToPatternSet(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns) const {
+
+  MLIRContext *context = patterns.getContext();
+
+  if (isEnabled(LinearizePattern::LinearizeConstantLike))
+    patterns.add<LinearizeConstantLike>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeConstantLike));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorizable))
+    patterns.add<LinearizeVectorizable>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorizable));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorBitCast))
+    patterns.add<LinearizeVectorBitCast>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorBitCast));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorCreateMask))
+    patterns.add<LinearizeVectorCreateMask>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorCreateMask));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorShuffle))
+    patterns.add<LinearizeVectorShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorShuffle));
+
+  if (isEnabled(LinearizePattern::LinearizeVectorSplat))
+    patterns.add<LinearizeVectorSplat>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::LinearizeVectorSplat));
+
+  // ------------------------ //
+  // Extract related patterns //
+  // ------------------------ //
+  if (isEnabled(LinearizePattern::VectorExtractToRankOneShuffle))
+    patterns.add<VectorExtractToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorExtractToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::VectorExtractStridedSliceToRankOneShuffle))
+    patterns.add<VectorExtractStridedSliceToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(
+            LinearizePattern::VectorExtractStridedSliceToRankOneShuffle));
+
+  // ------------------------ //
+  // Insert related patterns  //
+  // ------------------------ //
+  if (isEnabled(LinearizePattern::VectorInsertToRankOneShuffle))
+    patterns.add<VectorInsertToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertToRankOneShuffle));
+
+  if (isEnabled(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle))
+    patterns.add<VectorInsertStridedSliceToRankOneShuffle>(
+        typeConverter, context,
+        getBenefit(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle));
 }
diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize.mlir
diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
index e16937029ac0e..1ce069599af43 100644
--- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRVectorTestPasses
   TestVectorTransforms.cpp
+  TestVectorLinearize.cpp
 
   EXCLUDE_FROM_LIBMLIR
   )
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
new file mode 100644
index 0000000000000..67179c9f98e9b
--- /dev/null
+++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
@@ -0,0 +1,185 @@
+//===- TestVectorLinearize.cpp - Test Vector linearization ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <optional>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math//IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  TestVectorLinearize() = default;
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect, arith::ArithDialect, math::MathDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    initializeForVectorLinearize(converter);
+    populateForFullVectorLinearize(converter, target, patterns);
+
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-constrained-vector-linearize";
+  }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "on inner-most dimension's bit width. If the inner-most dimension "
+           "exceeds a threshold, the op is not linearized.";
+  }
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    populateWithBitWidthConstraints(typeConverter, target, patterns,
+                                    targetVectorBitwidth);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+
+private:
+  /// If `type` is VectorType with trailing dimension of (bit) size greater than
+  /// or equal to `targetBitWidth`, its defining op is considered legal.
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                              unsigned targetBitWidth) {
+
+    VectorType vecType = dyn_cast<VectorType>(type);
+
+    // Not linearizable for reasons other than what this function checks.
+    if (!vecType || vecType.getRank() == 0)
+      return false;
+
+    // The width of the type 'index' is unbounded (and therefore potentially
+    // above the target width).
+    if (vecType.getElementType().isIndex())
+      return true;
+
+    unsigned finalDimSize = vecType.getShape().back();
+    unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+    unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+    return trailingVecDimBitWidth >= targetBitWidth;
+  }
+
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                              unsigned targetBitWidth) {
+    // Check on bitwidths.
+    SmallVector<std::pair<Type, unsigned>> toCheck =
+        getTypeBitWidthBoundPairs(op, targetBitWidth);
+    return std::any_of(toCheck.begin(), toCheck.end(),
+                       [&](std::pair<Type, unsigned> typeWidth) {
+                         return isNotLinearizableBecauseLargeInnerDimension(
+                             typeWidth.first, typeWidth.second);
+                       });
+  }
+
+  static void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+                                              ConversionTarget &target,
+                                              RewritePatternSet &patterns,
+                                              unsigned targetBitWidth) {
+
+    initializeForVectorLinearize(typeConverter);
+    populateForFullVectorLinearize(typeConverter, target, patterns);
+
+    // Extend the set of legal ops to include those with large inner-most
+    // dimensions on selected operands/results.
+    target.markUnknownOpDynamicallyLegal(
+        [=](Operation *op) -> std::optional<bool> {
+          if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+            return true;
+          }
+          return {};
+        });
+  }
+
+  /// Get the set of operand/result types to check for sufficiently
+  /// small inner-most dimension size.
+  static SmallVector<std::pair<Type, unsigned>>
+  getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
+
+    if (auto insertOp = dyn_cast<InsertOp>(op)) {
+      unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                       ? targetBitWidth + 1
+                       : targetBitWidth;
+      return {{insertOp.getValueToStoreType(), w}};
+    }
+
+    auto resultTypes = op->getResultTypes();
+    SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+    resultsWithBitWidth.reserve(resultTypes.size());
+    for (Type type : resultTypes) {
+      resultsWithBitWidth.push_back({type, targetBitWidth});
+    }
+    return resultsWithBitWidth;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+extern void registerTestVectorLinearize() {
+  PassRegistration<TestVectorLinearize>();
+  PassRegistration<TestVectorBitWidthLinearize>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..5c75d32c22236 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -837,160 +836,6 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-/// Get the set of operand/result types to check for sufficiently
-/// small inner-most dimension size.
-static SmallVector<std::pair<Type, unsigned>>
-getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
-
-  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
-    unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
-                     ? targetBitWidth + 1
-                     : targetBitWidth;
-    return {{insertOp.getValueToStoreType(), w}};
-  }
-
-  auto resultTypes = op->getResultTypes();
-  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
-  resultsWithBitWidth.reserve(resultTypes.size());
-  for (Type type : resultTypes) {
-    resultsWithBitWidth.push_back({type, targetBitWidth});
-  }
-  return resultsWithBitWidth;
-}
-
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Type type,
-                                            unsigned targetBitWidth) {
-
-  VectorType vecType = dyn_cast<VectorType>(type);
-
-  // Not linearizable for reasons other than what this function checks.
-  if (!vecType || vecType.getRank() == 0)
-    return false;
-
-  // The width of the type 'index' is unbounded (and therefore potentially above
-  // the target width).
-  if (vecType.getElementType().isIndex())
-    return true;
-
-  unsigned finalDimSize = vecType.getShape().back();
-  unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
-  unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
-  return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Operation *op,
-                                            unsigned targetBitWidth) {
-  // Check on bitwidths.
-  SmallVector<std::pair<Type, unsigned>> toCheck =
-      getTypeBitWidthBoundPairs(op, targetBitWidth);
-  return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
-    return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
-                                                       typeWidth.second);
-  });
-}
-
-void populateWithBitWidthConstraints(TypeConverter &typeConverter,
-                                     ConversionTarget &target,
-                                     unsigned targetBitWidth) {
-
-  // The general purpose definition of what ops are legal must come first.
-  populateForVectorLinearize(typeConverter, target);
-
-  // Extend the set of legal ops to include those with large inner-most
-  // dimensions on selected operands/results.
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
-          return true;
-        }
-        return {};
-      });
-}
-
-struct TestVectorBitWidthLinearize final
-    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
-
-  TestVectorBitWidthLinearize() = default;
-  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
-      : PassWrapper(pass) {}
-
-  StringRef getArgument() const override {
-    return "test-bit-width-constrained-vector-linearize";
-  }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
-           "in inner-most dimension's bit width.";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  Option<unsigned> targetVectorBitwidth{
-      *this, "target-vector-bitwidth",
-      llvm::cl::desc(
-          "Minimum vector bitwidth to enable the flattening transformation"),
-      llvm::cl::init(std::numeric_limits<unsigned>::max())};
-  void runOnOperation() override {
-    auto *context = &getContext();
-
-    TypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    populateWithBitWidthConstraints(typeConverter, target,
-                                    targetVectorBitwidth);
-
-    vector::populateVectorLinearizeBasePatterns(typeConverter, target,
-                                                patterns);
-
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
-                                                          patterns);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-struct TestVectorLinearize final
-    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
-
-  TestVectorLinearize() = default;
-
-  StringRef getArgument() const override { return "test-vector-linearize"; }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect, arith::ArithDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext &context = getContext();
-    TypeConverter converter;
-    RewritePatternSet patterns(&context);
-    ConversionTarget target(context);
-
-    vector::populateForVectorLinearize(converter, target);
-
-    vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
-                                                          patterns);
-    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-        converter, patterns, target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
 struct TestEliminateVectorMasks
     : public PassWrapper<TestEliminateVectorMasks,
                          OperationPass<func::FuncOp>> {
@@ -1062,10 +907,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
 
-  PassRegistration<TestVectorLinearize>();
-
-  PassRegistration<TestVectorBitWidthLinearize>();
-
   PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2e08ae6f37980..f52f36107e301 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -155,6 +155,7 @@ void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
+void registerTestVectorLinearize();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
@@ -300,6 +301,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectEraseSchedulePass();
   mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestVectorLinearize();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestVulkanRunnerPipeline();
   mlir::test::registerTestWrittenToPass();

>From 68183789137dad8c832fb9b889875d07e21d869a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 15:53:55 -0700
Subject: [PATCH 2/2] typos

---
 .../Vector/Transforms/VectorLinearize.h       | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
index cd62de640d088..af5ce2103f774 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -16,27 +16,27 @@
 namespace mlir {
 namespace vector {
 
-/// Initialize `typeConverter` with source and target materialization logic
-/// using shape_casts to/from 1D vectors.
+/// Initialize `typeConverter` with source and target materializations that
+/// use shape_casts to/from 1D vectors.
 void initializeForVectorLinearize(TypeConverter &typeConverter);
 
-/// Initialize `conversionTarget`, and `patterns` for linearization. Here
+/// Initialize `conversionTarget` and `patterns` for linearization. Here
 /// linearization means converting a single operation with 1+ vector
 /// operand/result of rank>1, into a new single operation whose vector operands
-/// and results are all of rank<=1.
+/// and results are all rank<=1.
 ///
-/// This function initializes `conversionTarget` with the set of operations that
-/// are illegal and consequently must be converted to a linearized form. It
-/// also populates the set of patterns that can be run to convert illegal
-/// operations, and what priority/benefit they have.
+/// This function initializes `conversionTarget` with a definition of which
+/// operations are illegal and consequently must be converted to a linearized
+/// (legal) form. It also populates `patterns` with the patterns that will be
+/// run to convert illegal operations, and sets what priority/benefit they
+/// have.
 ///
-/// Note: the set of legal operations can be extended by a user if, for example,
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
+/// Note: the set of legal operations can be extended by a user by adding
+/// additional legality rules to `conversionTarget`.
 ///
 /// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
+/// linearization is to enable reuse of generic structural type conversions for
+/// linearizing scf/cf/func operations.
 void populateForFullVectorLinearize(const TypeConverter &,
                                     ConversionTarget &conversionTarget,
                                     RewritePatternSet &patterns);
@@ -233,7 +233,7 @@ struct VectorLinearizePatterns {
 };
 
 /// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
+/// at position `offsets`: this function enumerates all the indices in `large`
 /// that are written to. The enumeration is with row-major ordering.
 ///
 /// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2



More information about the Mlir-commits mailing list