[Mlir-commits] [mlir] [Will be divided up] Vector linearization via strided ops (PR #142672)
James Newling
llvmlistbot at llvm.org
Tue Jun 3 14:07:03 PDT 2025
https://github.com/newling created https://github.com/llvm/llvm-project/pull/142672
This is the end point of a few changes I have made to vector linearization. I plan to split this into multiple PRs, but am posting the combined change here as a reference.
Changes/improvements:
- Make all patterns optional
- Option to lower to 1D strided ops (although 1D shuffle remains the default); see the usage sketch below
- Support for scalar insert/extract
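For reference, a rough sketch of how the new entry points compose inside a pass. The surrounding `op` / `signalPassFailure` boilerplate here is illustrative only; `initializeForVectorLinearize` and `populateForFullVectorLinearize` are the functions added by this PR:

  // Linearize all vector ops under `op`, preferring 1D strided ops over
  // 1D shuffles for insert/extract lowering.
  TypeConverter typeConverter;
  vector::initializeForVectorLinearize(typeConverter);
  ConversionTarget target(*op->getContext());
  RewritePatternSet patterns(op->getContext());
  vector::populateForFullVectorLinearize(
      typeConverter, target, patterns,
      vector::InsertExtractLinearizePreference::Strided);
  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    signalPassFailure();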
>From 6ad62b41cdc7350cc28593424c3d41e5733ed1ef Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 13:56:03 -0700
Subject: [PATCH] squash commits
---
.../Vector/Transforms/VectorLinearize.h | 354 +++++++
.../Vector/Transforms/VectorRewritePatterns.h | 33 -
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 +-
.../Vector/Transforms/VectorLinearize.cpp | 967 +++++++++++++-----
.../linearize-insert-extract-preference.mlir | 287 ++++++
.../linearize-subject-to-bitwidth.mlir | 0
.../Vector/{ => linearize}/linearize.mlir | 3 +-
.../linearize/rank-reduce-strided-ops.mlir | 135 +++
mlir/test/lib/Dialect/Vector/CMakeLists.txt | 1 +
.../Dialect/Vector/TestVectorLinearize.cpp | 248 +++++
.../Dialect/Vector/TestVectorTransforms.cpp | 159 ---
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
12 files changed, 1771 insertions(+), 438 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
create mode 100644 mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
rename mlir/test/Dialect/Vector/{ => linearize}/linearize-subject-to-bitwidth.mlir (100%)
rename mlir/test/Dialect/Vector/{ => linearize}/linearize.mlir (99%)
create mode 100644 mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
create mode 100644 mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..6fc5bb13f9b07
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,354 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` with source and target materialization logic
+/// using shape_casts to/from 1D vectors.
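+/// For example, a value of type vector<2x3xf32> is converted to a value of
+/// type vector<6xf32>, with vector.shape_cast ops materializing the
+/// conversions at type boundaries.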
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// This enum controls the patterns used for linearization of insert,
+/// insert_strided_slice, extract, and extract_strided_slice operations.
+enum class InsertExtractLinearizePreference {
+
+ /// The lowerings are
+ /// insert, insert_strided_slice -> 1D shuffle
+ /// extract, extract_strided_slice -> 1D shuffle
+ ///
+  /// Even insert_strided_slice and extract_strided_slice ops that are already
+  /// 1D are converted to 1D shuffles. Insert and extract ops on scalar
+  /// elements are not converted to 1D shuffles.
+ Shuffle = 0,
+
+ /// The preferred lowerings are
+ /// insert, insert_strided_slice -> 1D insert_strided_slice
+ /// extract, extract_strided_slice -> 1D extract_strided_slice
+ ///
+ /// When these lowerings are not possible because the slices are not
+ /// contiguous, 1D shuffles are used.
+ Strided
+};
+
+/// Initialize `conversionTarget` and `patterns` for linearization. Here,
+/// linearization means converting a single operation with 1+ vector
+/// operand/result of rank>1 into a new single operation whose vector operands
+/// and results are all of rank<=1.
+///
+/// This function initializes `conversionTarget` with the set of operations that
+/// are illegal and consequently must be converted to a linearized form. It
+/// also populates the set of patterns that can be run to convert illegal
+/// operations, and what priority/benefit they have. The patterns and legality
+/// rules depend on `preference`, which controls the benefit associated with
+/// the patterns based on whether 1D shuffles or 1D strided ops are preferred.
+///
+/// Note: a user can extend the set of legal operations, for example when
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: the choice of a dialect conversion design for linearization
+/// is to make it easy to reuse generic structural type conversions for
+/// linearizing scf/cf/func operations.
+void populateForFullVectorLinearize(
+ const TypeConverter &, ConversionTarget &conversionTarget,
+ RewritePatternSet &patterns,
+ InsertExtractLinearizePreference preference =
+ InsertExtractLinearizePreference::Shuffle);
+
+enum class LinearizePattern {
+
+ /// BEFORE
+ /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+ ///
+ /// AFTER
+ /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+ /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+ LinearizeConstantLike = 0,
+
+ /// BEFORE
+ /// %2 = math.sin %arg0 : vector<2x2xf32>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+ /// %1 = math.sin %0 : vector<4xf32>
+ /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+ LinearizeVectorizable,
+
+ /// BEFORE
+ /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+ /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+ /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+ LinearizeVectorBitCast,
+
+ /// This pattern currently only supports 2D masks with a unit outer
+ /// dimension.
+ ///
+ /// BEFORE
+ /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+ ///
+ /// AFTER
+ /// [...]
+  ///  %mask_1d = vector.create_mask %mul : vector<4xi1>
+  ///  %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+ ///
+ /// where `%mul` is a function of `%arg0` and `%arg1`.
+ LinearizeVectorCreateMask,
+
+ /// BEFORE
+ /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+ ///
+ /// AFTER
+ /// %v1_1d = vector.shape_cast %v1_3d : [...]
+ /// %v2_1d = vector.shape_cast %v2_3d : [...]
+ /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+ /// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+ ///
+ /// Where `shuffle_indices_1d` are computed by expanding `shuffle_indices`.
+ LinearizeVectorShuffle,
+
+ /// BEFORE
+ /// %1 = vector.splat %value : vector<4x4xf32>
+ ///
+ /// AFTER
+ /// %0 = vector.splat %value : vector<16xf32>
+ /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+ LinearizeVectorSplat,
+
+ /// Reduce the rank of a vector.extract_strided_slice to the lowest rank
+ /// possible. For extract_strided_slice ops that slice contiguous elements,
+  /// the reduced-rank op is 1D; otherwise it is higher dimensional.
+ ///
+ /// BEFORE
+ /// %2 = vector.extract_strided_slice %arg0 {
+ /// offsets = [1, 0, 1, 0],
+ /// sizes = [1, 2, 1, 2],
+ /// strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<2x2x2x2xi8> to vector<4x4xi8>
+ /// %1 = vector.extract_strided_slice %0 {
+ /// offsets = [2, 2],
+ /// sizes = [2, 2],
+ /// strides = [1, 1]} : vector<4x4xi8> to vector<2x2xi8>
+ /// %2 = vector.shape_cast %1 : vector<2x2xi8> to vector<1x2x1x2xi8>
+ RankReduceExtractStridedSlice,
+
+  /// Similar to RankReduceExtractStridedSlice, but here both vector operands
+  /// have their ranks reduced.
+ ///
+ /// BEFORE
+ /// %3 = vector.insert_strided_slice %arg1, %arg0 {[...]}
+ /// vector<1x2x1x2xi8> into vector<2x2x2x2xi8>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<2x2x2x2xi8> to vector<4x4xi8>
+ /// %1 = vector.shape_cast %arg1 : vector<1x2x1x2xi8> to vector<2x2xi8>
+ /// %2 = vector.insert_strided_slice %1, %0 {[...]}
+ /// %3 = vector.shape_cast %2 : vector<4x4xi8> to vector<2x2x2x2xi8>
+ RankReduceInsertStridedSlice,
+
+ /// BEFORE
+ /// %extract = vector.extract %src [ position ]
+ ///
+ /// AFTER
+ /// %src_1d = vector.shape_cast %src : [...]
+  ///  %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+ /// %out_nd = vector.shape_cast %out_1d : [...]
+ ///
+ /// `shuffle_indices` is computed from `position`.
+ VectorExtractToRankOneShuffle,
+
+ /// BEFORE
+ /// %out_nd = vector.extract_strided_slice %source_nd
+ /// { offsets = [..], strides = [..], sizes = [..] }
+ ///
+ /// AFTER
+ /// %source_1d = vector.shape_cast %source_nd [...]
+ /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+ /// %out_nd = vector.shape_cast %out_1d [...]
+ ///
+ /// `shuffle_indices_1d` is computed using the offsets and sizes of the
+ /// original vector.extract_strided_slice operation.
+ VectorExtractStridedSliceToRankOneShuffle,
+
+ /// BEFORE
+ /// %1 = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<4x3x2x1xi8> to vector<24xi8>
+ /// %1 = vector.extract_strided_slice %0 {offsets = [10], sizes = [2] [...]
+ /// %2 = vector.shape_cast %1 : vector<2xi8> to vector<2x1xi8>
+ VectorExtractToRankOneStrided,
+
+ /// BEFORE
+ /// %insert = vector.insert %src %dst [ position ]
+ ///
+ /// AFTER
+ /// %src_1d = vector.shape_cast %src : [...]
+ /// %dst_1d = vector.shape_cast %dst : [...]
+ /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+ /// %out_nd = vector.shape_cast %out_1d : [...]
+ ///
+ /// `shuffle_indices` is computed from `position`.
+ VectorInsertToRankOneShuffle,
+
+ /// This pattern converts a vector.insert_strided_slice operation into a
+ /// vector.shuffle operation that has rank-1 (linearized) operands and result.
+ ///
+ /// BEFORE
+ /// %0 = vector.insert_strided_slice %to_store, %into
+ /// {offsets = [1, 0, 0, 0], strides = [1, 1]}
+ /// : vector<2x2xi8> into vector<2x1x3x2xi8>
+ /// AFTER
+ /// %to_store_1d
+ /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+ /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+ /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+ /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+ ///
+ /// where shuffle_indices_1d in this case is
+ /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+ /// ^^^^^^^^^^^^^^
+ /// to_store_1d
+ VectorInsertStridedSliceToRankOneShuffle,
+
+  /// Similar to VectorExtractToRankOneStrided, but for vector.insert.
+ VectorInsertToRankOneStrided,
+
+ /// The number of patterns in this enum.
+ N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority.
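+///
+/// Example usage (a sketch, assuming `typeConverter` and `patterns` already
+/// exist):
+///
+///   VectorLinearizePatterns()
+///       .enableAll(false)
+///       .enable(LinearizePattern::LinearizeConstantLike)
+///       .setBenefit(LinearizePattern::LinearizeConstantLike, 2)
+///       .addToPatternSet(typeConverter, patterns);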
+struct VectorLinearizePatterns {
+
+public:
+ /// By default all patterns are enabled and have benefit 1.
+ VectorLinearizePatterns() {
+ enabled.fill(true);
+ benefits.fill(PatternBenefit(1));
+ }
+
+ /// Add the patterns enabled for the conversion to `patterns`.
+ void addToPatternSet(const TypeConverter &,
+ RewritePatternSet &patterns) const;
+
+ VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+ enabled[static_cast<unsigned>(id)] = e;
+ return *this;
+ }
+
+ VectorLinearizePatterns &enableAll(bool e = true) {
+ enabled.fill(e);
+ return *this;
+ }
+
+ bool isEnabled(LinearizePattern id) const {
+ return enabled[static_cast<unsigned>(id)];
+ }
+
+ PatternBenefit getBenefit(LinearizePattern id) const {
+ return benefits[static_cast<unsigned>(id)];
+ }
+
+ VectorLinearizePatterns &setBenefit(LinearizePattern id,
+ PatternBenefit benefit) {
+ getBenefitRef(id) = benefit;
+ return *this;
+ }
+
+ VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+ unsigned inc = 1) {
+    getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+ return *this;
+ }
+
+private:
+ std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+ std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+ benefits;
+
+ PatternBenefit &getBenefitRef(LinearizePattern id) {
+ unsigned idInt = static_cast<unsigned>(id);
+ assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+ "invalid linearization pattern id");
+ return benefits[idInt];
+ }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets);
+
+/// Return the strided slice with the lowest rank that is equivalent to the
+/// strided slice of `small` from `large`, starting at `offsets`. The result is
+/// a tuple of three vectors:
+///
+/// 0) The shape of the new small vector.
+/// 1) The shape of the new large vector.
+/// 2) The offsets of the new large vector.
+///
+/// Example 1 (contiguous slices can always be represented in 1-D).
+///
+/// Input:
+/// small = (1, 3, 4)
+/// large = (3, 3, 4)
+/// offset = (2, 0, 0)
+///
+/// Output:
+/// small = (12)
+/// large = (36)
+/// offset = (24)
+///
+/// Example 2 (a non-contiguous slice)
+///
+/// Input:
+/// small = (2, 2, 1, 2)
+/// large = (2, 2, 2, 2, 2)
+/// offset = (1, 1, 0, 1)
+///
+/// Output:
+/// small = (4, 2)
+/// large = (8, 4)
+/// offset = (6, 2)
+std::array<SmallVector<int64_t>, 3>
+getCollapsedStridedSliceShape(ArrayRef<int64_t> small, ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets);
+
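+/// Variant of getCollapsedStridedSliceShape for vector.extract_strided_slice.
+/// Returns the collapsed (small, large, offsets) triple, or std::nullopt if
+/// `extractOp` has non-unit strides or non-constant offsets, or if collapsing
+/// does not reduce the rank.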
+std::optional<std::array<SmallVector<int64_t>, 3>>
+getCollapsedExtractStridedSliceShape(vector::ExtractStridedSliceOp extractOp);
+
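+/// Variant of getCollapsedStridedSliceShape for vector.insert_strided_slice.
+/// Returns the collapsed (small, large, offsets) triple, or std::nullopt if
+/// `insertOp` has non-unit strides or non-constant offsets, or if collapsing
+/// does not reduce the rank.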
+std::optional<std::array<SmallVector<int64_t>, 3>>
+getCollapsedInsertStridedSliceShape(vector::InsertStridedSliceOp insertOp);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
- ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
- const ConversionTarget &,
- RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
- const ConversionTarget &,
- RewritePatternSet &patterns);
-
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..359b2ba091967 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5726,13 +5726,21 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
- // No-op shape cast.
- if (getSource().getType() == resultType)
- return getSource();
- // shape_cast(shape_cast(x)) -> shape_cast(x)
- if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
- setOperand(precedingShapeCast.getSource());
+  // y = shape_cast(shape_cast(shape_cast(x)))
+  //   -> shape_cast(x)   # if x and y have different types
+  //   -> x               # if x and y have the same type
+  ShapeCastOp parent = *this;
+  while (auto precedingShapeCast =
+             parent.getSource().getDefiningOp<ShapeCastOp>())
+    parent = precedingShapeCast;
+
+  if (parent.getSource().getType() == resultType)
+    return parent.getSource();
+
+  if (parent != *this) {
+    setOperand(parent.getSource());
return getResult();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..fae452d8e5dc9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
@@ -45,6 +46,408 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
}
+/// Convert an array of attributes into a vector of integers.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+ if (!attrs)
+ return failure();
+ SmallVector<int64_t> ints;
+ ints.reserve(attrs.size());
+ for (auto attr : attrs) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ ints.push_back(intAttr.getInt());
+ } else {
+ return failure();
+ }
+ }
+ return ints;
+}
+
+/// Convert OpFoldResults into a vector of integers, failing when an
+/// OpFoldResult is not an Attribute, unless the corresponding dimension in
+/// `shape` is 1, in which case the offset is 0 regardless. Ensure that the
+/// returned vector has the same rank as `shape` by appending zeros.
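+///
+/// Example: with shape (1, 8, 2) and offsets (%dynamic, 3), the result is
+/// (0, 3, 0): the dynamic offset lands in a unit dimension and so must be 0,
+/// and a trailing 0 is appended to match the rank of `shape`.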
+static FailureOr<SmallVector<int64_t>>
+getIntegerOffsetsFromFoldResults(ArrayRef<OpFoldResult> offsetFoldResults,
+ ArrayRef<int64_t> shape) {
+  assert(shape.size() >= offsetFoldResults.size() &&
+         "offsets assumed not to be of higher rank than shape");
+ unsigned deltaRank = shape.size() - offsetFoldResults.size();
+ SmallVector<int64_t> offsets;
+ offsets.reserve(offsetFoldResults.size());
+ for (auto [offsetFoldResult, dimSize] :
+ llvm::zip(offsetFoldResults, shape.drop_back(deltaRank))) {
+ if (dimSize == 1) {
+ offsets.push_back(0);
+ } else if (auto offsetAttr = dyn_cast<Attribute>(offsetFoldResult)) {
+ offsets.push_back(cast<IntegerAttr>(offsetAttr).getInt());
+ } else {
+ return failure();
+ }
+ }
+ offsets.resize(shape.size(), 0);
+ return offsets;
+}
+
+/// If `ndIndex` is the index in the n-dimensional array of shape `shape`, get
+/// the corresponding index into the flattened array.
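+/// Example: ndIndex = (1, 2) with row-major shape = (3, 4) gives 1*4 + 2 = 6.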
+static int64_t getIndexInFlattened(ArrayRef<int64_t> ndIndex,
+ ArrayRef<int64_t> shape) {
+ assert(ndIndex.size() == shape.size() &&
+ "ndIndex and shape assumed to have the same size");
+ int64_t index = 0;
+ int64_t stride = 1;
+ for (int i = shape.size() - 1; i >= 0; --i) {
+ index += ndIndex[i] * stride;
+ stride *= shape[i];
+ }
+ return index;
+}
+
+/// Return true if `op` is an insert, extract, insert_strided_slice, or
+/// extract_strided_slice operation that operates on scalable vectors.
+/// Otherwise return false.
+static bool isScalableExtractOrInsertOrStrided(Operation *op) {
+ return TypeSwitch<Operation *, bool>(op)
+ .Case<vector::ExtractStridedSliceOp>(
+ [&](vector::ExtractStridedSliceOp extractOp) {
+ return extractOp.getType().isScalable();
+ })
+ .Case<vector::InsertStridedSliceOp>(
+ [&](vector::InsertStridedSliceOp insertOp) {
+ return insertOp.getType().isScalable();
+ })
+ .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
+ return insertOp.getType().isScalable();
+ })
+ .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
+ return extractOp.getSourceVectorType().isScalable();
+ })
+ .Default([&](auto) { return false; });
+}
+
+SmallVector<int64_t>
+vector::getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets) {
+
+  // Example of alignment between `large`, `small`, and `offsets`:
+ // large = 4, 5, 6, 7, 8
+ // small = 1, 6, 7, 8
+ // offsets = 2, 3, 0
+ //
+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+ assert((large.size() >= small.size()) &&
+ "rank of 'large' cannot be lower than rank of 'small'");
+ assert((large.size() >= offsets.size()) &&
+ "rank of 'large' cannot be lower than the number of offsets");
+ unsigned delta = large.size() - small.size();
+ unsigned nOffsets = offsets.size();
+ auto getSmall = [&](int64_t i) -> int64_t {
+ return i >= delta ? small[i - delta] : 1;
+ };
+ auto getOffset = [&](int64_t i) -> int64_t {
+ return i < nOffsets ? offsets[i] : 0;
+ };
+
+ // Using 2 vectors of indices, at each iteration populate the updated set of
+ // indices based on the old set of indices, and the size of the small vector
+ // in the current iteration.
+ SmallVector<int64_t> indices{0};
+ int64_t stride = 1;
+ for (int i = large.size() - 1; i >= 0; --i) {
+ int64_t currentSize = indices.size();
+ int64_t smallSize = getSmall(i);
+ int64_t nextSize = currentSize * smallSize;
+ SmallVector<int64_t> nextIndices(nextSize);
+ int64_t *base = nextIndices.begin();
+ int64_t offset = getOffset(i) * stride;
+ for (int j = 0; j < smallSize; ++j) {
+ for (int k = 0; k < currentSize; ++k) {
+ base[k] = indices[k] + offset;
+ }
+ offset += stride;
+ base += currentSize;
+ }
+ stride *= large[i];
+ indices = std::move(nextIndices);
+ }
+ return indices;
+}
+
+void vector::initializeForVectorLinearize(TypeConverter &typeConverter) {
+
+ auto convertType = [](Type type) -> std::optional<Type> {
+ VectorType vectorType = dyn_cast<VectorType>(type);
+
+ if (!vectorType || !vector::isLinearizableVector(vectorType))
+ return type;
+
+ VectorType linearizedType =
+ VectorType::get(vectorType.getNumElements(),
+ vectorType.getElementType(), vectorType.isScalable());
+
+ return linearizedType;
+ };
+ typeConverter.addConversion(convertType);
+
+ auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1) {
+ return nullptr;
+ }
+ Value value = inputs.front();
+ if (!isa<VectorType>(type) || !isa<VectorType>(value.getType())) {
+ return nullptr;
+ }
+ return builder.create<vector::ShapeCastOp>(loc, type, value);
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+}
+
+void vector::populateForFullVectorLinearize(
+ const TypeConverter &typeConverter, ConversionTarget &target,
+ RewritePatternSet &patterns, InsertExtractLinearizePreference preference) {
+
+ target.markUnknownOpDynamicallyLegal(
+ [=](Operation *op) -> std::optional<bool> {
+ // Only ops that are in the vector dialect, are ConstantLike, or
+ // are Vectorizable might be linearized currently.
+ StringLiteral vectorDialect =
+ vector::VectorDialect::getDialectNamespace();
+ StringRef opDialect = op->getDialect()->getNamespace();
+ bool supported = (opDialect == vectorDialect) ||
+ op->hasTrait<OpTrait::ConstantLike>() ||
+ op->hasTrait<OpTrait::Vectorizable>();
+ if (!supported)
+ return true;
+
+ // As type legalization is done with vector.shape_cast, shape_cast
+ // itself cannot be linearized (doing so would create new shape_casts to
+ // linearize ad infinitum).
+ if (isa<vector::ShapeCastOp>(op))
+ return true;
+
+ // The operations extract_strided_slice, extract, insert_strided_slice,
+        // and insert are linearized to rank-1 operations that do not fully
+ // support scalable vectors, so it is not generally possible to
+ // linearize these ops if they operate on scalable vectors.
+ if (isScalableExtractOrInsertOrStrided(op))
+ return true;
+
+ // This will return true if, for all operand and result types `t`,
+ // convertType(t) = t. This is true if there are no rank>=2 vectors.
+ return typeConverter.isLegal(op);
+ });
+
+ VectorLinearizePatterns linearizePatterns;
+
+ if (preference == InsertExtractLinearizePreference::Shuffle) {
+ // Mark extract_strided_slice, insert_strided_slice, extract with source
+ // rank > 1, and insert with result rank > 1 as illegal, as they must be
+ // converted to shuffle or rank-1 extract/insert.
+ //
+ // Note that the order of the calls to `markUnknownOpDynamicallyLegal`
+ // is important: the legality rule added here takes precedence over the
+ // generic one preceding it which marked these ops as legal.
+ target.markUnknownOpDynamicallyLegal(
+ [](Operation *op) -> std::optional<bool> {
+ bool isStrided =
+ isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp>(
+ op);
+
+ bool isHighRankExtractOrInsert = [&]() {
+ if (auto extractOp = dyn_cast<vector::ExtractOp>(op)) {
+ return extractOp.getSourceVectorType().getRank() > 1;
+ }
+ if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+ return insertOp.getType().getRank() > 1;
+ }
+ return false;
+ }();
+
+ bool isScalable = isScalableExtractOrInsertOrStrided(op);
+
+ if ((isStrided || isHighRankExtractOrInsert) && !isScalable) {
+ return false;
+ }
+ return std::nullopt;
+ });
+
+    // Ensure that the benefit of patterns targeting shuffle is higher than
+    // the benefit of patterns targeting rank-1 strided slice operations. This
+    // ensures that patterns converting to rank-1 shuffles are run first.
+ linearizePatterns
+ .incrementBenefit(
+ LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)
+ .incrementBenefit(
+ LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)
+ .incrementBenefit(LinearizePattern::VectorExtractToRankOneShuffle)
+ .incrementBenefit(LinearizePattern::VectorInsertToRankOneShuffle);
+
+ } else if (preference == InsertExtractLinearizePreference::Strided) {
+ linearizePatterns
+ .incrementBenefit(LinearizePattern::RankReduceInsertStridedSlice)
+ .incrementBenefit(LinearizePattern::RankReduceExtractStridedSlice)
+ .incrementBenefit(LinearizePattern::VectorInsertToRankOneStrided)
+ .incrementBenefit(LinearizePattern::VectorExtractToRankOneStrided);
+ } else {
+    llvm_unreachable("unsupported InsertExtractLinearizePreference");
+ }
+ linearizePatterns.addToPatternSet(typeConverter, patterns);
+}
+
+/// Get the lowest rank shapes and offsets which represent the same strided
+/// slice as the strided slice described by `small`, `large`, and `offsets`.
+///
+/// Example
+///
+/// %0 = vector.extract_strided_slice %1
+///        {offsets = [0, 0, 0], sizes = [2, 2, 2], strides = [1, 1, 1]} :
+/// vector<4x2x4xf32> to vector<2x2x2xf32>
+///
+/// is equivalent to
+///
+/// [...rank reducing shape casts...]
+/// %0 = vector.extract_strided_slice %1
+/// {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} :
+/// vector<8x4xf32> to vector<4x2xf32>
+/// [...rank increasing shape cast...]
+///
+/// So the output for
+/// (small, large, offsets = [2, 2, 2], [4, 2, 4], [0, 0, 0]) is
+/// (small, large, offsets = [4, 2], [8, 4], [0, 0])
+std::array<SmallVector<int64_t>, 3>
+vector::getCollapsedStridedSliceShape(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets) {
+
+ // The total number of elements in the small (large, respectively) vector.
+ int64_t tSmall = std::accumulate(small.begin(), small.end(), 1,
+ std::multiplies<int64_t>());
+ int64_t tLarge = std::accumulate(large.begin(), large.end(), 1,
+ std::multiplies<int64_t>());
+  assert(tLarge >= tSmall &&
+         "expected 'small' to not have more elements than 'large'");
+  assert(large.size() >= small.size() &&
+         "expected rank of 'small' to not exceed rank of 'large'");
+  assert(offsets.size() <= large.size() &&
+         "expected the number of offsets to not exceed rank of 'large'");
+
+ int64_t nOffsets = offsets.size();
+ auto getOffset = [&](int64_t i) -> int64_t {
+ return i < nOffsets ? offsets[i] : 0;
+ };
+
+ unsigned delta = large.size() - small.size();
+
+ // The cumulative (product of dimensions) number of elements from the back
+ // currently visited in the small (large, respectively) vector.
+ int64_t nSmall = 1;
+ int64_t nLarge = 1;
+
+ // The cumulative number (product of dimensions) of elements from the back
+ // currently visited within the current collapse group in the small (large,
+ // respectively) vector.
+ int64_t cSmall = 1;
+ int64_t cLarge = 1;
+
+ SmallVector<int64_t> newSmall, newLarge, newOffsets;
+  if (large.empty())
+ return {newSmall, newLarge, newOffsets};
+
+ // The offset assigned to the current collapse group.
+ int64_t cOff = 0;
+
+ unsigned index = large.size() - 1;
+ while (nLarge < tLarge) {
+    assert(cSmall <= nSmall && nSmall <= tSmall && //
+           cLarge <= nLarge && nLarge <= tLarge &&
+           "element accumulation invariants violated");
+ cOff += getOffset(index) * cLarge;
+ if (nSmall < tSmall) {
+ cSmall *= small[index - delta];
+ nSmall *= small[index - delta];
+ }
+ cLarge *= large[index];
+ nLarge *= large[index];
+ if ((nSmall < tSmall) && (large[index] != small[index - delta])) {
+ newSmall.push_back(cSmall);
+ newLarge.push_back(cLarge);
+ newOffsets.push_back(cOff);
+ cSmall = 1;
+ cLarge = 1;
+ cOff = 0;
+ }
+ --index;
+ }
+ newSmall.push_back(cSmall);
+ newLarge.push_back(cLarge);
+ newOffsets.push_back(cOff);
+ std::reverse(newSmall.begin(), newSmall.end());
+ std::reverse(newLarge.begin(), newLarge.end());
+ std::reverse(newOffsets.begin(), newOffsets.end());
+ return {newSmall, newLarge, newOffsets};
+}
+
+// Returns (small, large, offsets).
+std::optional<std::array<SmallVector<int64_t>, 3>>
+vector::getCollapsedExtractStridedSliceShape(
+ vector::ExtractStridedSliceOp extractOp) {
+
+ if (extractOp.hasNonUnitStrides())
+ return std::nullopt;
+
+ ArrayRef<int64_t> outShape = extractOp.getType().getShape();
+ ArrayRef<int64_t> inShape = extractOp.getSourceVectorType().getShape();
+
+ auto maybeIntOffsets = intsFromArrayAttr(extractOp.getOffsets());
+ if (failed(maybeIntOffsets))
+ return std::nullopt;
+
+ SmallVector<int64_t> offsets = std::move(maybeIntOffsets.value());
+ const auto &[collapsedOutShape, collapsedInShape, collapsedOffsets] =
+ getCollapsedStridedSliceShape(outShape, inShape, offsets);
+
+ bool unchanged = (collapsedInShape.size() == inShape.size()) &&
+ (collapsedOutShape.size() == outShape.size());
+
+ if (unchanged)
+ return std::nullopt;
+
+ return std::array<SmallVector<int64_t>, 3>{
+ collapsedOutShape, collapsedInShape, collapsedOffsets};
+}
+
+// Returns (small, large, offsets).
+std::optional<std::array<SmallVector<int64_t>, 3>>
+vector::getCollapsedInsertStridedSliceShape(
+ vector::InsertStridedSliceOp insertOp) {
+
+ if (insertOp.hasNonUnitStrides())
+ return std::nullopt;
+
+ ArrayRef<int64_t> outShape = insertOp.getType().getShape();
+ ArrayRef<int64_t> inShape = insertOp.getSourceVectorType().getShape();
+
+ auto maybeIntOffsets = intsFromArrayAttr(insertOp.getOffsets());
+ if (failed(maybeIntOffsets))
+ return std::nullopt;
+
+ SmallVector<int64_t> offsets = std::move(maybeIntOffsets.value());
+ const auto &[collapsedInShape, collapsedOutShape, collapsedOffsets] =
+ getCollapsedStridedSliceShape(inShape, outShape, offsets);
+
+ bool unchanged = (collapsedInShape.size() == inShape.size()) &&
+ (collapsedOutShape.size() == outShape.size());
+
+ if (unchanged)
+ return std::nullopt;
+
+ return std::array<SmallVector<int64_t>, 3>{
+ collapsedInShape, collapsedOutShape, collapsedOffsets};
+}
+
namespace {
struct LinearizeConstantLike final
@@ -52,7 +455,7 @@ struct LinearizeConstantLike final
using OpTraitConversionPattern::OpTraitConversionPattern;
LinearizeConstantLike(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -94,7 +497,7 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,90 +512,6 @@ struct LinearizeVectorizable final
}
};
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
- static_assert(
- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
- "expected vector.extract_strided_slice or vector.insert_strided_slice");
- ArrayAttr strides = op.getStrides();
- return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
-static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
- if (!attrs)
- return failure();
- SmallVector<int64_t> ints;
- ints.reserve(attrs.size());
- for (auto attr : attrs) {
- if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
- ints.push_back(intAttr.getInt());
- } else {
- return failure();
- }
- }
- return ints;
-}
-
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
- ArrayRef<int64_t> small, ArrayRef<int64_t> large,
- ArrayRef<int64_t> offsets) {
-
- // Example of alignment between, `large`, `small` and `offsets`:
- // large = 4, 5, 6, 7, 8
- // small = 1, 6, 7, 8
- // offsets = 2, 3, 0
- //
- // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
- assert((large.size() >= small.size()) &&
- "rank of 'large' cannot be lower than rank of 'small'");
- assert((large.size() >= offsets.size()) &&
- "rank of 'large' cannot be lower than the number of offsets");
- unsigned delta = large.size() - small.size();
- unsigned nOffsets = offsets.size();
- auto getSmall = [&](int64_t i) -> int64_t {
- return i >= delta ? small[i - delta] : 1;
- };
- auto getOffset = [&](int64_t i) -> int64_t {
- return i < nOffsets ? offsets[i] : 0;
- };
-
- // Using 2 vectors of indices, at each iteration populate the updated set of
- // indices based on the old set of indices, and the size of the small vector
- // in the current iteration.
- SmallVector<int64_t> indices{0};
- int64_t stride = 1;
- for (int i = large.size() - 1; i >= 0; --i) {
- int64_t currentSize = indices.size();
- int64_t smallSize = getSmall(i);
- int64_t nextSize = currentSize * smallSize;
- SmallVector<int64_t> nextIndices(nextSize);
- int64_t *base = nextIndices.begin();
- int64_t offset = getOffset(i) * stride;
- for (int j = 0; j < smallSize; ++j) {
- for (int k = 0; k < currentSize; ++k) {
- base[k] = indices[k] + offset;
- }
- offset += stride;
- base += currentSize;
- }
- stride *= large[i];
- indices = std::move(nextIndices);
- }
- return indices;
-}
-
/// This pattern converts a vector.extract_strided_slice operation into a
/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
///
@@ -212,12 +531,12 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices(
///
/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
/// vector.extract_strided_slice operation.
-struct LinearizeVectorExtractStridedSlice final
- : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+struct VectorExtractStridedSliceToRankOneShuffle final
+ : public OpConversionPattern<vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
- MLIRContext *context,
- PatternBenefit benefit = 1)
+ VectorExtractStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -231,7 +550,7 @@ struct LinearizeVectorExtractStridedSlice final
// Expect a legalization failure if the strides are not all 1 (if ever the
// verifier for extract_strided_slice allows non-1 strides).
- if (!stridesAllOne(extractStridedSliceOp)) {
+ if (extractStridedSliceOp.hasNonUnitStrides()) {
return rewriter.notifyMatchFailure(
extractStridedSliceOp,
"extract_strided_slice with strides != 1 not supported");
@@ -249,7 +568,7 @@ struct LinearizeVectorExtractStridedSlice final
ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
- SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+ SmallVector<int64_t> indices = vector::getStridedSliceInsertionIndices(
outputShape, inputShape, offsets.value());
Value srcVector = adaptor.getVector();
@@ -259,36 +578,24 @@ struct LinearizeVectorExtractStridedSlice final
}
};
-/// This pattern converts a vector.insert_strided_slice operation into a
-/// vector.shuffle operation that has rank-1 (linearized) operands and result.
-///
-/// For example, the following:
-/// ```
-/// %0 = vector.insert_strided_slice %to_store, %into
-/// {offsets = [1, 0, 0, 0], strides = [1, 1]}
-/// : vector<2x2xi8> into vector<2x1x3x2xi8>
-/// ```
-///
-/// is converted to
-/// ```
-/// %to_store_1d
-/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
-/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
-/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
-/// ```
-///
-/// where shuffle_indices_1d in this case is
-/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
-/// ^^^^^^^^^^^^^^
-/// to_store_1d
-///
-struct LinearizeVectorInsertStridedSlice final
- : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
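+/// Return `v` unchanged if it is already of rank 1, otherwise (rank 0)
+/// shape_cast it to a rank-1 vector with a single element.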
+static Value asRankOne(ConversionPatternRewriter &rewriter, Value v) {
+ auto vType = dyn_cast<VectorType>(v.getType());
+ assert(vType && "expected vector type");
+ assert(vType.getRank() <= 1 && "expected rank-0 or rank-1 type");
+ if (vType.getRank() == 1)
+ return v;
+ // Convert rank-0 vector to rank-1 vector.
+ v = rewriter.create<vector::ShapeCastOp>(
+ v.getLoc(), VectorType::get({1}, vType.getElementType()), v);
+ return v;
+}
+
+struct VectorInsertStridedSliceToRankOneShuffle final
+ : public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
- MLIRContext *context,
- PatternBenefit benefit = 1)
+ VectorInsertStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -298,7 +605,7 @@ struct LinearizeVectorInsertStridedSlice final
// Expect a legalization failure if the strides are not all 1 (if ever the
// verifier for insert_strided_slice allows non-1 strides).
- if (!stridesAllOne(insertStridedSliceOp)) {
+ if (insertStridedSliceOp.hasNonUnitStrides()) {
return rewriter.notifyMatchFailure(
insertStridedSliceOp,
"insert_strided_slice with strides != 1 not supported");
@@ -317,7 +624,7 @@ struct LinearizeVectorInsertStridedSlice final
return rewriter.notifyMatchFailure(insertStridedSliceOp,
"failed to get integer offsets");
}
- SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+ SmallVector<int64_t> sliceIndices = vector::getStridedSliceInsertionIndices(
inputShape, outputShape, offsets.value());
SmallVector<int64_t> indices(nOutputElements);
@@ -326,7 +633,7 @@ struct LinearizeVectorInsertStridedSlice final
indices[sliceIndex] = index + nOutputElements;
}
- Value flatToStore = adaptor.getValueToStore();
+ Value flatToStore = asRankOne(rewriter, adaptor.getValueToStore());
Value flatDest = adaptor.getDest();
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
flatDest.getType(), flatDest,
@@ -350,7 +657,7 @@ struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -360,8 +667,8 @@ struct LinearizeVectorShuffle final
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
assert(dstType && "vector type destination expected.");
- Value vec1 = adaptor.getV1();
- Value vec2 = adaptor.getV2();
+ Value vec1 = asRankOne(rewriter, adaptor.getV1());
+ Value vec2 = asRankOne(rewriter, adaptor.getV2());
int shuffleSliceLen = 1;
int rank = shuffleOp.getV1().getType().getRank();
@@ -404,11 +711,11 @@ struct LinearizeVectorShuffle final
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
/// %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original extract.
-struct LinearizeVectorExtract final
+struct VectorExtractToRankOneShuffle final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtract(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ VectorExtractToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
@@ -436,11 +743,120 @@ struct LinearizeVectorExtract final
linearizedOffset += offsets[i] * size;
}
+ Value v0 = asRankOne(rewriter, adaptor.getVector());
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
- rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, v0, v0,
+ indices);
+
+ return success();
+ }
+};
+
+/// Convert a vector.extract op with input rank > 1 to an operation with input
+/// of rank 1 and output of rank <= 1. Two lowering cases:
+///
+/// 1) If the result of the vector.extract is a scalar, convert it to a
+/// vector.extract on a rank-1 input which still outputs a scalar.
+///
+/// 2) Otherwise, convert to an extract_strided_slice op on a vector of rank-1.
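+///
+/// Example of case 1:
+///
+/// BEFORE
+/// %1 = vector.extract %arg0[1, 2] : i8 from vector<4x3xi8>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<4x3xi8> to vector<12xi8>
+/// %1 = vector.extract %0[5] : i8 from vector<12xi8>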
+struct VectorExtractToRankOneStrided final
+ : public OpConversionPattern<vector::ExtractOp> {
+ using OpConversionPattern::OpConversionPattern;
+ VectorExtractToRankOneStrided(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType inType = extractOp.getVector().getType();
+ if (inType.getRank() == 1)
+ return failure();
+
+ SmallVector<OpFoldResult> offsets = extractOp.getMixedPosition();
+ auto maybeIntOffsets =
+ getIntegerOffsetsFromFoldResults(offsets, inType.getShape());
+ if (failed(maybeIntOffsets)) {
+ return failure();
+ }
+ const auto &intOffsets = maybeIntOffsets.value();
+ int64_t globalOffset = getIndexInFlattened(intOffsets, inType.getShape());
+
+ Location loc = extractOp.getLoc();
+
+ Type outType = extractOp.getType();
+
+ // Case 1 described above:
+ if (outType.isIntOrIndexOrFloat()) {
+ Value flattened = rewriter.create<vector::ExtractOp>(
+ loc, adaptor.getVector(), SmallVector<int64_t>{globalOffset});
+ rewriter.replaceOp(extractOp, flattened);
+ return success();
+ }
+
+ VectorType vOutType = dyn_cast<VectorType>(outType);
+ assert(vOutType && "expected vector type for output");
+
+ auto numberElementsOut = vOutType.getNumElements();
+ auto strided = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, adaptor.getVector(), SmallVector<int64_t>{globalOffset},
+ SmallVector<int64_t>{numberElementsOut}, SmallVector<int64_t>{1});
+
+ rewriter.replaceOp(extractOp, strided);
+ return success();
+ }
+};
+
+/// Convert vector.insert where the destination is rank > 1. Two cases:
+///
+/// 1) If the source to insert is a scalar, convert to a vector.insert op
+/// where the destination is rank-1.
+///
+/// 2) Otherwise, convert to a vector.insert_strided_slice op into a vector of
+/// rank-1.
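+///
+/// Example of case 1:
+///
+/// BEFORE
+/// %1 = vector.insert %val, %arg0[1, 2] : i8 into vector<4x3xi8>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<4x3xi8> to vector<12xi8>
+/// %1 = vector.insert %val, %0[5] : i8 into vector<12xi8>
+/// %2 = vector.shape_cast %1 : vector<12xi8> to vector<4x3xi8>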
+struct VectorInsertToRankOneStrided final
+ : public OpConversionPattern<vector::InsertOp> {
+ using OpConversionPattern::OpConversionPattern;
+ VectorInsertToRankOneStrided(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType largeType = insertOp.getDest().getType();
+ Type smallType = insertOp.getValueToStoreType();
+ SmallVector<OpFoldResult> positions = insertOp.getMixedPosition();
+ auto maybeIntOffsets =
+ getIntegerOffsetsFromFoldResults(positions, largeType.getShape());
+ if (failed(maybeIntOffsets)) {
+ return failure();
+ }
+ const auto &intOffsets = maybeIntOffsets.value();
+ int64_t globalOffset =
+ getIndexInFlattened(intOffsets, largeType.getShape());
+
+ Location loc = insertOp.getLoc();
+
+    // Case 1 described above: the value to store is a scalar.
+    if (smallType.isIntOrIndexOrFloat()) {
+ auto flatOut = rewriter.create<vector::InsertOp>(
+ loc, adaptor.getValueToStore(), adaptor.getDest(),
+ SmallVector<int64_t>{globalOffset});
+ rewriter.replaceOp(insertOp, flatOut);
+ return success();
+ }
+    // Case 2 described above: the value to store is a vector.
+ Value v0 = asRankOne(rewriter, adaptor.getValueToStore());
+ auto flatOut = rewriter.create<vector::InsertStridedSliceOp>(
+ insertOp.getLoc(), v0, adaptor.getDest(),
+ SmallVector<int64_t>{globalOffset}, SmallVector<int64_t>{1});
+ rewriter.replaceOp(insertOp, flatOut);
return success();
}
};
@@ -455,11 +871,11 @@ struct LinearizeVectorExtract final
/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
/// ] %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
-struct LinearizeVectorInsert final
+struct VectorInsertToRankOneShuffle final
: public OpConversionPattern<vector::InsertOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorInsert(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ VectorInsertToRankOneShuffle(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
@@ -508,7 +924,8 @@ struct LinearizeVectorInsert final
// [offset+srcNumElements, end)
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+ insertOp, dstTy, adaptor.getDest(),
+ asRankOne(rewriter, adaptor.getValueToStore()), indices);
return success();
}
@@ -526,7 +943,7 @@ struct LinearizeVectorBitCast final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorBitCast(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
@@ -535,7 +952,7 @@ struct LinearizeVectorBitCast final
assert(resType && "expected 1-D vector type");
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
- return mlir::success();
+ return success();
}
};
@@ -550,7 +967,7 @@ struct LinearizeVectorSplat final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
+ PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -581,7 +998,7 @@ struct LinearizeVectorCreateMask final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -607,17 +1024,16 @@ struct LinearizeVectorCreateMask final
// The result of the comparison is then multiplied with
// the second operand of create_mask to get the 1D mask.
auto firstOperand = adaptor.getOperands().front();
- auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
- auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
- auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+ auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
loc, rewriter.getIndexType(), isNonZero);
auto secondOperand = adaptor.getOperands().back();
- auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
+ auto maskSize = rewriter.createOrFold<arith::AndIOp>(
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
- auto newMask =
- rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+ auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
rewriter.replaceOp(createMaskOp, newMask);
return success();
}
@@ -625,104 +1041,179 @@ struct LinearizeVectorCreateMask final
} // namespace
-/// This method defines the set of operations that are linearizable, and hence
-/// that are considered illegal for the conversion target.
-static bool isLinearizable(Operation *op) {
+/// This pattern converts a vector.extract_strided_slice into a new
+/// vector.extract_strided_slice where the operand and result of the new
+/// vector.extract_strided_slice have ranks that are as low as possible.
+///
+/// If the original vector.extract_strided_slice is a contiguous slice of
+/// a vector, then the new vector.extract_strided_slice will have rank-1
+/// operand and result. Otherwise additional dimensions will remain in the
+/// new operand and result.
+struct RankReduceExtractStridedSlice final
+ : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
- // Only ops that are in the vector dialect, are ConstantLike, or
- // are Vectorizable might be linearized currently.
- StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
- StringRef opDialect = op->getDialect()->getNamespace();
- bool supported = (opDialect == vectorDialect) ||
- op->hasTrait<OpTrait::ConstantLike>() ||
- op->hasTrait<OpTrait::Vectorizable>();
- if (!supported)
- return false;
+ RankReduceExtractStridedSlice(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
+ : OpConversionPattern(typeConverter, context, benefit) {}
- return TypeSwitch<Operation *, bool>(op)
- // As type legalization is done with vector.shape_cast, shape_cast
- // itself cannot be linearized (will create new shape_casts to linearize
- // ad infinitum).
- .Case<vector::ShapeCastOp>([&](auto) { return false; })
- // The operations
- // - vector.extract_strided_slice
- // - vector.extract
- // - vector.insert_strided_slice
- // - vector.insert
- // are linearized to a rank-1 vector.shuffle by the current patterns.
- // vector.shuffle only supports fixed size vectors, so it is impossible to
- // use this approach to linearize these ops if they operate on scalable
- // vectors.
- .Case<vector::ExtractStridedSliceOp>(
- [&](vector::ExtractStridedSliceOp extractOp) {
- return !extractOp.getType().isScalable();
- })
- .Case<vector::InsertStridedSliceOp>(
- [&](vector::InsertStridedSliceOp insertOp) {
- return !insertOp.getType().isScalable();
- })
- .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
- return !insertOp.getType().isScalable();
- })
- .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
- return !extractOp.getSourceVectorType().isScalable();
- })
- .Default([&](auto) { return true; });
-}
-void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
-                                              ConversionTarget &target) {
-
-  auto convertType = [](Type type) -> std::optional<Type> {
-    VectorType vectorType = dyn_cast<VectorType>(type);
-    if (!vectorType || !isLinearizableVector(vectorType))
-      return type;
-
-    VectorType linearizedType =
-        VectorType::get(vectorType.getNumElements(),
-                        vectorType.getElementType(), vectorType.isScalable());
-    return linearizedType;
-  };
-  typeConverter.addConversion(convertType);
-
-  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                            Location loc) -> Value {
-    if (inputs.size() != 1)
-      return nullptr;
-
-    Value value = inputs.front();
-    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
-      return nullptr;
-
-    return builder.create<vector::ShapeCastOp>(loc, type, value);
-  };
-  typeConverter.addSourceMaterialization(materializeCast);
-  typeConverter.addTargetMaterialization(materializeCast);
-
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (!isLinearizable(op))
-          return true;
-        // This will return true if, for all operand and result types `t`,
-        // convertType(t) = t. This is true if there are no rank>=2 vectors.
-        return typeConverter.isLegal(op);
-      });
-}
-
-void mlir::vector::populateVectorLinearizeBasePatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
-}
-
-void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
-               LinearizeVectorInsertStridedSlice>(typeConverter,
-                                                  patterns.getContext());
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto maybeCollapsed = getCollapsedExtractStridedSliceShape(extractOp);
+    if (!maybeCollapsed.has_value())
+      return failure();
+
+    const auto &[collapsedOutShape, collapsedInShape, collapsedOffsets] =
+        maybeCollapsed.value();
+
+    VectorType collapsedInType =
+        VectorType::get(collapsedInShape, extractOp.getType().getElementType());
+
+    auto collapsedIn = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), collapsedInType, adaptor.getVector());
+
+    auto replacement = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), collapsedIn, collapsedOffsets, collapsedOutShape,
+        SmallVector<int64_t>(collapsedOffsets.size(), 1));
+
+    VectorType flatOutputType =
+        getTypeConverter()->convertType<VectorType>(extractOp.getType());
+
+    Value out = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), flatOutputType, replacement);
+
+    rewriter.replaceOp(extractOp, out);
+
+    return success();
+  }
+};
+
+struct RankReduceInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  RankReduceInsertStridedSlice(const TypeConverter &typeConverter,
+                               MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto maybeCollapsed = getCollapsedInsertStridedSliceShape(insertOp);
+
+    if (!maybeCollapsed.has_value())
+      return failure();
+
+    const auto &[collapsedInShape, collapsedOutShape, collapsedOffsets] =
+        maybeCollapsed.value();
+
+    VectorType collapsedInType =
+        VectorType::get(collapsedInShape, insertOp.getType().getElementType());
+
+    Value collapsedIn = rewriter.createOrFold<vector::ShapeCastOp>(
+        insertOp.getLoc(), collapsedInType, adaptor.getValueToStore());
+
+    VectorType collapsedOutType =
+        VectorType::get(collapsedOutShape, insertOp.getType().getElementType());
+
+    Value collapsedDst = rewriter.createOrFold<vector::ShapeCastOp>(
+        insertOp.getLoc(), collapsedOutType, adaptor.getDest());
+
+    auto replacement = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), collapsedIn, collapsedDst, collapsedOffsets,
+        SmallVector<int64_t>(collapsedOffsets.size(), 1));
+
+    Value out = rewriter.createOrFold<vector::ShapeCastOp>(
+        insertOp.getLoc(), insertOp.getType(), replacement);
+
+    rewriter.replaceOp(insertOp, out);
+
+    return success();
+  }
+};
+
+void vector::VectorLinearizePatterns::addToPatternSet(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns) const {
+
+ MLIRContext *context = patterns.getContext();
+
+ if (isEnabled(LinearizePattern::LinearizeConstantLike))
+ patterns.add<LinearizeConstantLike>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeConstantLike));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorizable))
+ patterns.add<LinearizeVectorizable>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorizable));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorBitCast))
+ patterns.add<LinearizeVectorBitCast>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorBitCast));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorCreateMask))
+ patterns.add<LinearizeVectorCreateMask>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorCreateMask));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorShuffle))
+ patterns.add<LinearizeVectorShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorShuffle));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorSplat))
+ patterns.add<LinearizeVectorSplat>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorSplat));
+
+ // ------------------------ //
+ // Extract related patterns //
+ // ------------------------ //
+ if (isEnabled(LinearizePattern::VectorExtractToRankOneShuffle))
+ patterns.add<VectorExtractToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorExtractToRankOneShuffle));
+
+ if (isEnabled(LinearizePattern::VectorExtractStridedSliceToRankOneShuffle))
+ patterns.add<VectorExtractStridedSliceToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(
+ LinearizePattern::VectorExtractStridedSliceToRankOneShuffle));
+
+ if (isEnabled(LinearizePattern::RankReduceExtractStridedSlice))
+ patterns.add<RankReduceExtractStridedSlice>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::RankReduceExtractStridedSlice));
+
+ if (isEnabled(LinearizePattern::VectorExtractToRankOneStrided))
+ patterns.add<VectorExtractToRankOneStrided>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorExtractToRankOneStrided));
+
+ // ------------------------ //
+ // Insert related patterns //
+ // ------------------------ //
+ if (isEnabled(LinearizePattern::VectorInsertToRankOneShuffle))
+ patterns.add<VectorInsertToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorInsertToRankOneShuffle));
+
+ if (isEnabled(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle))
+ patterns.add<VectorInsertStridedSliceToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle));
+
+ if (isEnabled(LinearizePattern::RankReduceInsertStridedSlice))
+ patterns.add<RankReduceInsertStridedSlice>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::RankReduceInsertStridedSlice));
+
+ if (isEnabled(LinearizePattern::VectorInsertToRankOneStrided))
+ patterns.add<VectorInsertToRankOneStrided>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorInsertToRankOneStrided));
}
diff --git a/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
new file mode 100644
index 0000000000000..a73602263d06e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/linearize-insert-extract-preference.mlir
@@ -0,0 +1,287 @@
+// Everything becomes a shuffle (except rank-1 insert/extract).
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle | FileCheck %s --check-prefixes=SHUFFLE,ALL
+
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Strided | FileCheck %s --check-prefixes=STRIDED,ALL
+
+
+// **--------------------------------------------------------**
+// Tests of vector.insert
+// **--------------------------------------------------------**
+
+// vector.insert where the destination is a 1D vector is always unchanged.
+//
+// ALL-LABEL: insert_scalar_to_1D(
+// ALL-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<4xi8>
+// ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : i8 into vector<4xi8>
+// ALL: return %[[IN0]] : vector<4xi8>
+func.func @insert_scalar_to_1D(%arg0 : i8, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+ %inserted = vector.insert %arg0, %arg1[2] : i8 into vector<4xi8>
+ return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert of a scalar always becomes an insert of the scalar into a 1D vector.
+//
+// ALL-LABEL: insert_scalar_to_2D(
+// ALL-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<3x4xi8>
+// ALL: %[[SC0:.*]] = vector.shape_cast %[[A1]] : vector<3x4xi8> to vector<12xi8>
+// ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[SC0]] [9] : i8 into vector<12xi8>
+// ALL: %[[SC1:.*]] = vector.shape_cast %[[IN0]] : vector<12xi8> to vector<3x4xi8>
+// ALL: return %[[SC1]] : vector<3x4xi8>
+func.func @insert_scalar_to_2D(%arg0 : i8, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+ %inserted = vector.insert %arg0, %arg1[2, 1] : i8 into vector<3x4xi8>
+ return %inserted : vector<3x4xi8>
+}
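+
+// Note: the destination index [2, 1] in vector<3x4xi8> maps to linear index
+// 2 * 4 + 1 = 9, which is the index used by the rank-1 insert above.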
+
+// -----
+
+// vector.insert where the source isn't a scalar. First case: 1D -> 2D.
+//
+// ALL-LABEL: insert_1D_to_2D(
+//
+// SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11]
+//
+// STRIDED: vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
+// STRIDED-SAME: vector<4xi8> into vector<12xi8>
+func.func @insert_1D_to_2D(%arg0 : vector<4xi8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+ %inserted = vector.insert %arg0, %arg1[1] : vector<4xi8> into vector<3x4xi8>
+ return %inserted : vector<3x4xi8>
+}
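+
+// Note: the flattened destination supplies shuffle indices 0-11 and the
+// source supplies indices 12-15, so row [1] (linear positions 4-7) is
+// replaced by [12, 13, 14, 15] in the shuffle mask; the strided form writes
+// the 4 source elements at offset [4].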
+
+
+// -----
+
+// vector.insert where the source isn't a scalar. Second case: 0D -> 2D.
+//
+// ALL-LABEL: insert_0D_to_2D(
+//
+// SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 10, 11] :
+// SHUFFLE-SAME: vector<12xi8>, vector<1xi8>
+//
+// STRIDED: vector.insert_strided_slice {{.*}} {offsets = [9], strides = [1]}
+// STRIDED-SAME: vector<1xi8> into vector<12xi8>
+func.func @insert_0D_to_2D(%arg0 : vector<i8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8>
+{
+ %inserted = vector.insert %arg0, %arg1[2, 1] : vector<i8> into vector<3x4xi8>
+ return %inserted : vector<3x4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Third case: 0D -> 1D.
+//
+// ALL-LABEL: insert_0D_to_1D(
+// ALL-SAME: %[[A0:.*]]: vector<i8>, %[[A1:.*]]: vector<4xi8>
+// ALL: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : vector<i8> into vector<4xi8>
+// ALL: return %[[IN0]] : vector<4xi8>
+func.func @insert_0D_to_1D(%arg0 : vector<i8>, %arg1 : vector<4xi8>) -> vector<4xi8>
+{
+ %inserted = vector.insert %arg0, %arg1[2] : vector<i8> into vector<4xi8>
+ return %inserted : vector<4xi8>
+}
+
+// -----
+
+// vector.insert where the source isn't a scalar. Fourth case: 2D -> 4D.
+//
+// ALL-LABEL: insert_2D_to_4D(
+// ALL-COUNT-2: shape_cast
+//
+// SHUFFLE: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] :
+// SHUFFLE-SAME: vector<16xi8>, vector<4xi8>
+//
+// STRIDED: vector.insert_strided_slice {{.*}} {offsets = [12], strides = [1]}
+// STRIDED-SAME: vector<4xi8> into vector<16xi8>
+func.func @insert_2D_to_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x2x2x2xi8>) -> vector<2x2x2x2xi8>
+{
+ %inserted = vector.insert %arg0, %arg1[1, 1] : vector<2x2xi8> into vector<2x2x2x2xi8>
+ return %inserted : vector<2x2x2x2xi8>
+}
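+
+// Note: position [1, 1] in vector<2x2x2x2xi8> starts at linear index
+// (1 * 2 + 1) * 4 = 12, so the 4 source elements (shuffle indices 16-19)
+// replace positions 12-15, matching the strided offset [12].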
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.extract
+// **--------------------------------------------------------**
+
+// vector.extract where the source is a 1D vector is always unchanged.
+//
+// ALL-LABEL: extract_scalar_from_1D(
+// ALL-SAME: %[[A0:.*]]: vector<4xi8>
+// ALL: %[[EX0:.*]] = vector.extract %[[A0]][2] : i8 from vector<4xi8>
+// ALL: return %[[EX0]] : i8
+func.func @extract_scalar_from_1D(%arg0 : vector<4xi8>) -> i8
+{
+ %extracted = vector.extract %arg0[2] : i8 from vector<4xi8>
+ return %extracted : i8
+}
+
+// -----
+
+// ALL-LABEL: extract_scalar_from_2D(
+// ALL-SAME: %[[A0:.*]]: vector<3x4xi8>
+// ALL: %[[SC0:.*]] = vector.shape_cast %[[A0]] : vector<3x4xi8> to vector<12xi8>
+// ALL: %[[EX0:.*]] = vector.extract %[[SC0]][9] : i8 from vector<12xi8>
+// ALL: return %[[EX0]] : i8
+func.func @extract_scalar_from_2D(%arg0 : vector<3x4xi8>) -> i8
+{
+ %extracted = vector.extract %arg0[2, 1] : i8 from vector<3x4xi8>
+ return %extracted : i8
+}
+
+// -----
+
+// ALL-LABEL: extract_1D_from_2D(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [4, 5, 6, 7] : vector<12xi8>, vector<12xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [4], sizes = [4], strides = [1]} : vector<12xi8> to vector<4xi8>
+func.func @extract_1D_from_2D(%arg0 : vector<3x4xi8>) -> vector<4xi8>
+{
+ %extracted = vector.extract %arg0[1] : vector<4xi8> from vector<3x4xi8>
+ return %extracted : vector<4xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_2D_from_4D(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [10, 11] : vector<24xi8>, vector<24xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [10], sizes = [2], strides = [1]} : vector<24xi8> to vector<2xi8>
+func.func @extract_2D_from_4D(%arg0 : vector<4x3x2x1xi8>) -> vector<2x1xi8> {
+ %extracted = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8>
+ return %extracted : vector<2x1xi8>
+}
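+
+// Note: position [1, 2] in vector<4x3x2x1xi8> starts at linear index
+// (1 * 3 + 2) * 2 * 1 = 10 and spans 2 * 1 = 2 elements, hence the offsets
+// [10] and sizes [2] above.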
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.insert_strided_slice
+// **--------------------------------------------------------**
+
+// ALL-LABEL: insert_strided_slice_1D(
+//
+// SHUFFLE: shuffle {{.*}} [0, 8, 9, 3, 4, 5, 6, 7]
+//
+// STRIDED: insert_strided_slice {{.*}} {offsets = [1], strides = [1]}
+func.func @insert_strided_slice_1D(%arg0 : vector<2xi8>, %arg1 : vector<8xi8>) -> vector<8xi8>
+{
+ %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xi8> into vector<8xi8>
+ return %inserted : vector<8xi8>
+}
+
+// -----
+
+// ALL-LABEL: insert_strided_slice_4D_contiguous(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: 52, 53, 120, 121
+// SHUFFLE-SAME: 130, 131, 66, 67
+// SHUFFLE-SAME: vector<120xi8>, vector<12xi8>
+//
+// STRIDED: vector.insert_strided_slice
+// STRIDED-SAME: {offsets = [54], strides = [1]}
+// STRIDED-SAME: vector<12xi8> into vector<120xi8>
+
+
+func.func @insert_strided_slice_4D_contiguous(%arg0 : vector<1x2x2x3xi8>, %arg1 : vector<5x4x2x3xi8>) -> vector<5x4x2x3xi8> {
+ %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [2, 1, 0, 0], strides = [1, 1, 1, 1]} : vector<1x2x2x3xi8> into vector<5x4x2x3xi8>
+ return %inserted : vector<5x4x2x3xi8>
+}
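+
+// Note: offsets [2, 1, 0, 0] in vector<5x4x2x3xi8> start at linear index
+// (2 * 4 + 1) * 2 * 3 = 54 and the 1x2x2x3 slice spans 12 contiguous
+// elements, hence the strided offset [54] above.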
+
+// -----
+
+// This insert_strided_slice is not contiguous, so it is always linearized to a rank-1 vector.shuffle.
+
+// ALL-LABEL: insert_strided_slice_4D_noncontiguous(
+// ALL: vector.shuffle
+// ALL-SAME: [0, 1, 2, 8, 4, 5, 6, 9] : vector<8xi8>, vector<2xi8>
+
+func.func @insert_strided_slice_4D_noncontiguous(%arg0 : vector<1x2x1x1xi8>, %arg1 : vector<1x2x2x2xi8>) -> vector<1x2x2x2xi8> {
+ %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 1, 1], strides = [1, 1, 1, 1]} : vector<1x2x1x1xi8> into vector<1x2x2x2xi8>
+ return %inserted : vector<1x2x2x2xi8>
+}
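+
+// Note: the two inserted elements land at linear positions 3 and 7 of the
+// 8-element destination. These are not adjacent, so no rank-1 strided op can
+// express this insert.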
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.extract_strided_slice
+// **--------------------------------------------------------**
+
+// ALL-LABEL: extract_strided_slice_1D(
+//
+// SHUFFLE: vector.shuffle {{.*}} [1, 2]
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [1], sizes = [2], strides = [1]}
+// STRIDED-SAME: vector<8xi8> to vector<2xi8>
+func.func @extract_strided_slice_1D(%arg0 : vector<8xi8>) -> vector<2xi8>
+{
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<8xi8> to vector<2xi8>
+ return %extracted : vector<2xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_contiguous_1(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [3, 4, 5]
+// SHUFFLE-SAME: vector<6xi8>, vector<6xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [3], sizes = [3], strides = [1]}
+// STRIDED-SAME: vector<6xi8> to vector<3xi8>
+func.func @extract_strided_slice_4D_contiguous_1(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x3x1xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 3, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x3x1xi8>
+ return %extracted : vector<1x1x3x1xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_contiguous_2(
+//
+// SHUFFLE: vector.shuffle
+// SHUFFLE-SAME: [3, 4]
+// SHUFFLE-SAME: vector<6xi8>, vector<6xi8>
+//
+// STRIDED: vector.extract_strided_slice
+// STRIDED-SAME: {offsets = [3], sizes = [2], strides = [1]}
+// STRIDED-SAME: vector<6xi8> to vector<2xi8>
+func.func @extract_strided_slice_4D_contiguous_2(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+ return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// ALL-LABEL: extract_strided_slice_4D_noncontiguous(
+// ALL: vector.shuffle
+// ALL-SAME: [0, 1, 3, 4]
+// ALL-SAME: vector<6xi8>, vector<6xi8>
+func.func @extract_strided_slice_4D_noncontiguous(%arg0 : vector<2x1x3x1xi8>) -> vector<2x1x2x1xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0], sizes = [2, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<2x1x2x1xi8>
+ return %extracted : vector<2x1x2x1xi8>
+}
diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir
similarity index 99%
rename from mlir/test/Dialect/Vector/linearize.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize.mlir
index 9cbf319ffddb2..b7a5448c7dc22 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize/linearize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=preference=Shuffle -verify-diagnostics | FileCheck %s
// CHECK-LABEL: test_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -297,7 +297,6 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
// CHECK-LABEL: test_vector_insert
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
-
// CHECK-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
// CHECK-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
diff --git a/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
new file mode 100644
index 0000000000000..342936b0f7eed
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-opt %s -split-input-file -test-rank-reduce-strided-slice-ops -verify-diagnostics | FileCheck %s
+
+
+// **---------------------------------------------**
+// Tests of vector.extract_strided_slice
+// **---------------------------------------------**
+
+// The 6 elements extracted are contiguous, so this can be expressed as a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @extract_strided_slice_2D_to_1D(
+// CHECK-SAME: %[[A:.*]]: vector<5x2xi8>) -> vector<3x2xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<5x2xi8> to vector<10xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [2], sizes = [6], strides = [1]} : vector<10xi8> to vector<6xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<6xi8> to vector<3x2xi8>
+// CHECK: return %[[CASTED]] : vector<3x2xi8>
+func.func @extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<3x2xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [3, 2], strides = [1, 1]} : vector<5x2xi8> to vector<3x2xi8>
+ return %extracted : vector<3x2xi8>
+}
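+
+// Note: offsets [1, 0] of vector<5x2xi8> correspond to linear index
+// 1 * 2 = 2 and the 3x2 slice spans 6 contiguous elements, hence the offsets
+// [2] and sizes [6] above.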
+
+// -----
+
+// The 5 elements extracted are not contiguous, so this cannot be expressed as a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @negative_extract_strided_slice_2D_to_1D(
+// CHECK-SAME: %[[A:.*]]: vector<5x2xi8>) -> vector<5x1xi8> {
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK: return %[[EXTRACTED]] : vector<5x1xi8>
+func.func @negative_extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<5x1xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [5, 1], strides = [1, 1]} : vector<5x2xi8> to vector<5x1xi8>
+ return %extracted : vector<5x1xi8>
+}
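+
+// Note: sizes [5, 1] take one element from each row of vector<5x2xi8>, i.e.
+// linear positions 0, 2, 4, 6 and 8, which a unit-stride rank-1 slice cannot
+// express.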
+
+// -----
+
+// The 2 elements extracted are contiguous, so this can be expressed as a rank-1 vector.extract_strided_slice.
+
+// CHECK-LABEL: @extract_strided_slice_4D_leading_ones(
+// CHECK-SAME: %[[A:.*]]: vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<2x1x3x1xi8> to vector<6xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [3], sizes = [2], strides = [1]} : vector<6xi8> to vector<2xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<2xi8> to vector<1x1x2x1xi8>
+// CHECK: return %[[CASTED]] : vector<1x1x2x1xi8>
+
+func.func @extract_strided_slice_4D_leading_ones(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+ return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_4D_becomes_2D(
+// CHECK-SAME: %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<56x30xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [14, 5], sizes = [14, 10], strides = [1, 1]} : vector<56x30xi8> to vector<14x10xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<14x10xi8> to vector<2x7x2x5xi8>
+// CHECK: return %[[CASTED]] : vector<2x7x2x5xi8>
+func.func @extract_strided_slice_4D_becomes_2D(%arg0 : vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [2, 0, 1, 0], sizes = [2, 7, 2, 5], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<2x7x2x5xi8>
+ return %extracted : vector<2x7x2x5xi8>
+}
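+
+// Note: dims 1 and 3 are extracted in full, so each merges with the dim to
+// its left: the shape 8x7x6x5 collapses to 56x30, offsets [2, 0, 1, 0] to
+// [2 * 7 + 0, 1 * 5 + 0] = [14, 5], and sizes [2, 7, 2, 5] to [14, 10].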
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_4D_becomes_3D(
+// CHECK-SAME: %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<8x42x5xi8>
+// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]]
+// CHECK-SAME: {offsets = [0, 12, 1], sizes = [8, 12, 2], strides = [1, 1, 1]} : vector<8x42x5xi8> to vector<8x12x2xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<8x12x2xi8> to vector<8x2x6x2xi8>
+// CHECK: return %[[CASTED]] : vector<8x2x6x2xi8>
+
+func.func @extract_strided_slice_4D_becomes_3D(%arg0 : vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> {
+ %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 2, 0, 1], sizes = [8, 2, 6, 2], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<8x2x6x2xi8>
+ return %extracted : vector<8x2x6x2xi8>
+}
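+
+// Note: dim 2 is extracted in full, so it merges into dim 1 (7x6 -> 42);
+// dim 0, although full, cannot absorb the partially-extracted dim 1, so the
+// result is rank 3 rather than rank 2.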
+
+// -----
+
+// **---------------------------------------------**
+// Tests of vector.insert_strided_slice
+// **---------------------------------------------**
+
+
+// CHECK-LABEL: @negative_insert_strided_slice(
+// CHECK-SAME: %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<2x1xi8>) -> vector<2x2xi8> {
+// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[B]], %[[A]]
+// CHECK: return %[[INSERTED]] : vector<2x2xi8>
+func.func @negative_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1xi8>) -> vector<2x2xi8> {
+ %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi8> into vector<2x2xi8>
+ return %inserted : vector<2x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @positive_insert_strided_slice(
+// CHECK-SAME: %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<1x2xi8>) -> vector<2x2xi8> {
+// CHECK-DAG: %[[SCA:.*]] = vector.shape_cast %[[A]] : vector<2x2xi8> to vector<4xi8>
+// CHECK-DAG: %[[SCB:.*]] = vector.shape_cast %[[B]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[SCB]], %[[SCA]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi8> into vector<4xi8>
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<4xi8> to vector<2x2xi8>
+// CHECK: return %[[CASTED]] : vector<2x2xi8>
+
+func.func @positive_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<1x2xi8>) -> vector<2x2xi8> {
+ %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi8> into vector<2x2xi8>
+ return %inserted : vector<2x2xi8>
+}
+
+// -----
+
+func.func @test_extract_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>) -> vector<1x2x1x2xi8> {
+ %0 = vector.extract_strided_slice %arg0
+ {offsets = [1, 0, 1, 0],
+ sizes = [1, 2, 1, 2],
+ strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8>
+ return %0 : vector<1x2x1x2xi8>
+}
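+
+// Note: this test and the next have no CHECK lines. Dims 1 and 3 are taken
+// in full, so the slice is expected to collapse to offsets [2, 2] and sizes
+// [2, 2] on vector<4x4xi8>, but that output is currently unverified.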
+
+// -----
+
+// Equivalent to the above, but now with an insert_strided_slice.
+
+func.func @test_insert_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>, %arg1 : vector<1x2x1x2xi8>) -> vector<2x2x2x2xi8> {
+ %0 = vector.insert_strided_slice %arg1, %arg0
+ {offsets = [1, 0, 1, 0],
+ strides = [1, 1, 1, 1]} : vector<1x2x1x2xi8> into vector<2x2x2x2xi8>
+ return %0 : vector<2x2x2x2xi8>
+}
diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
index e16937029ac0e..1ce069599af43 100644
--- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRVectorTestPasses
TestVectorTransforms.cpp
+ TestVectorLinearize.cpp
EXCLUDE_FROM_LIBMLIR
)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
new file mode 100644
index 0000000000000..21d6356bf04fa
--- /dev/null
+++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
@@ -0,0 +1,248 @@
+//===- TestVectorLinearize.cpp - Test Vector linearization ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <optional>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math//IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct TestVectorLinearize final
+ : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+ TestVectorLinearize() = default;
+
+ TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
+ StringRef getArgument() const override { return "test-vector-linearize"; }
+ StringRef getDescription() const override {
+ return "Linearizes ND vectors for N >= 2 into 1D vectors";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<VectorDialect, arith::ArithDialect, math::MathDialect>();
+ }
+  Option<InsertExtractLinearizePreference> preference{
+      *this, "preference",
+      llvm::cl::desc("Preferred lowering for insert/extract ops; corresponds "
+                     "1:1 with InsertExtractLinearizePreference"),
+      llvm::cl::values(
+          clEnumValN(
+              static_cast<int>(InsertExtractLinearizePreference::Strided),
+              "Strided", "Lower to rank-1 strided slice ops"),
+          clEnumValN(
+              static_cast<int>(InsertExtractLinearizePreference::Shuffle),
+              "Shuffle", "Lower to rank-1 shuffle ops"))};
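+  // Exercised from the command line as, for example:
+  //   mlir-opt -test-vector-linearize=preference=Strided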
+
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+ initializeForVectorLinearize(converter);
+ populateForFullVectorLinearize(converter, target, patterns, preference);
+
+ mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+ converter, patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+struct TestRankReduceStridedSliceOps final
+ : public PassWrapper<TestRankReduceStridedSliceOps, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRankReduceStridedSliceOps)
+
+ TestRankReduceStridedSliceOps() = default;
+
+ StringRef getArgument() const override {
+ return "test-rank-reduce-strided-slice-ops";
+ }
+ StringRef getDescription() const override {
+ return "Test pass for rank-reducing strided slice ops.";
+ }
+
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter typeConverter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+
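+    // Enable only the two rank-reducing patterns; all other linearization
+    // patterns remain disabled.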
+ VectorLinearizePatterns()
+ .enableAll(false)
+ .enable(LinearizePattern::RankReduceInsertStridedSlice)
+ .enable(LinearizePattern::RankReduceExtractStridedSlice)
+ .addToPatternSet(typeConverter, patterns);
+
+ typeConverter.addConversion(
+ [](Type t) -> std::optional<Type> { return t; });
+
+ typeConverter.addSourceMaterialization(
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value { return inputs.front(); });
+
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) -> std::optional<bool> {
+ if (auto insertOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+ return !getCollapsedInsertStridedSliceShape(insertOp).has_value();
+ }
+ if (auto extractOp = dyn_cast<vector::ExtractStridedSliceOp>(op)) {
+ return !getCollapsedExtractStridedSliceShape(extractOp).has_value();
+ }
+ return true;
+ });
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+struct TestVectorBitWidthLinearize final
+ : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+ TestVectorBitWidthLinearize() = default;
+ TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+ : PassWrapper(pass) {}
+
+ StringRef getArgument() const override {
+ return "test-bit-width-constrained-vector-linearize";
+ }
+ StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "on the inner-most dimension's bit width. If the inner-most "
+           "dimension's bit width exceeds a threshold, the op is not linearized.";
+ }
+ Option<unsigned> targetVectorBitwidth{
+ *this, "target-vector-bitwidth",
+ llvm::cl::desc(
+ "Minimum vector bitwidth to enable the flattening transformation"),
+ llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter typeConverter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+ populateWithBitWidthConstraints(typeConverter, target, patterns,
+ targetVectorBitwidth);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+
+private:
+ /// If `type` is VectorType with trailing dimension of (bit) size greater than
+ /// or equal to `targetBitWidth`, its defining op is considered legal.
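+  /// For example, with targetBitWidth = 128, vector<4x8xf32> has a trailing
+  /// dimension bit width of 8 * 32 = 256 >= 128, so ops producing it are
+  /// treated as legal (not linearized).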
+ static bool
+ isNotLinearizableBecauseLargeInnerDimension(Type type,
+ unsigned targetBitWidth) {
+
+ VectorType vecType = dyn_cast<VectorType>(type);
+
+ // Not linearizable for reasons other than what this function checks.
+ if (!vecType || vecType.getRank() == 0)
+ return false;
+
+ // The width of the type 'index' is unbounded (and therefore potentially
+ // above the target width).
+ if (vecType.getElementType().isIndex())
+ return true;
+
+ unsigned finalDimSize = vecType.getShape().back();
+ unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+ unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+ return trailingVecDimBitWidth >= targetBitWidth;
+ }
+
+ static bool
+ isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+ unsigned targetBitWidth) {
+ // Check on bitwidths.
+ SmallVector<std::pair<Type, unsigned>> toCheck =
+ getTypeBitWidthBoundPairs(op, targetBitWidth);
+    return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
+      return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
+                                                         typeWidth.second);
+    });
+ }
+
+ static void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+ ConversionTarget &target,
+ RewritePatternSet &patterns,
+ unsigned targetBitWidth) {
+
+ initializeForVectorLinearize(typeConverter);
+ populateForFullVectorLinearize(typeConverter, target, patterns,
+ InsertExtractLinearizePreference::Shuffle);
+
+ // Extend the set of legal ops to include those with large inner-most
+ // dimensions on selected operands/results.
+ target.markUnknownOpDynamicallyLegal(
+ [=](Operation *op) -> std::optional<bool> {
+ if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+ return true;
+ }
+ return {};
+ });
+ }
+
+ /// Get the set of operand/result types to check for sufficiently
+ /// small inner-most dimension size.
+ static SmallVector<std::pair<Type, unsigned>>
+ getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
+
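+    // For vector.insert the bound on the value-to-store is targetBitWidth + 1
+    // (when representable), making the comparison strict: the op is exempted
+    // from linearization only if the stored value's trailing bit width
+    // strictly exceeds targetBitWidth.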
+ if (auto insertOp = dyn_cast<InsertOp>(op)) {
+ unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+ ? targetBitWidth + 1
+ : targetBitWidth;
+ return {{insertOp.getValueToStoreType(), w}};
+ }
+
+ auto resultTypes = op->getResultTypes();
+ SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+ resultsWithBitWidth.reserve(resultTypes.size());
+ for (Type type : resultTypes) {
+ resultsWithBitWidth.push_back({type, targetBitWidth});
+ }
+ return resultsWithBitWidth;
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+extern void registerTestVectorLinearize() {
+ PassRegistration<TestVectorLinearize>();
+ PassRegistration<TestVectorBitWidthLinearize>();
+ PassRegistration<TestRankReduceStridedSliceOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..5c75d32c22236 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -837,160 +836,6 @@ struct TestVectorEmulateMaskedLoadStore final
}
};
-/// Get the set of operand/result types to check for sufficiently
-/// small inner-most dimension size.
-static SmallVector<std::pair<Type, unsigned>>
-getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
-
- if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
- unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
- ? targetBitWidth + 1
- : targetBitWidth;
- return {{insertOp.getValueToStoreType(), w}};
- }
-
- auto resultTypes = op->getResultTypes();
- SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
- resultsWithBitWidth.reserve(resultTypes.size());
- for (Type type : resultTypes) {
- resultsWithBitWidth.push_back({type, targetBitWidth});
- }
- return resultsWithBitWidth;
-}
-
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Type type,
- unsigned targetBitWidth) {
-
- VectorType vecType = dyn_cast<VectorType>(type);
-
- // Not linearizable for reasons other than what this function checks.
- if (!vecType || vecType.getRank() == 0)
- return false;
-
- // The width of the type 'index' is unbounded (and therefore potentially above
- // the target width).
- if (vecType.getElementType().isIndex())
- return true;
-
- unsigned finalDimSize = vecType.getShape().back();
- unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
- unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
- return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Operation *op,
- unsigned targetBitWidth) {
- // Check on bitwidths.
- SmallVector<std::pair<Type, unsigned>> toCheck =
- getTypeBitWidthBoundPairs(op, targetBitWidth);
- return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
- return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
- typeWidth.second);
- });
-}
-
-void populateWithBitWidthConstraints(TypeConverter &typeConverter,
- ConversionTarget &target,
- unsigned targetBitWidth) {
-
- // The general purpose definition of what ops are legal must come first.
- populateForVectorLinearize(typeConverter, target);
-
- // Extend the set of legal ops to include those with large inner-most
- // dimensions on selected operands/results.
- target.markUnknownOpDynamicallyLegal(
- [=](Operation *op) -> std::optional<bool> {
- if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
- return true;
- }
- return {};
- });
-}
-
-struct TestVectorBitWidthLinearize final
- : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
-
- TestVectorBitWidthLinearize() = default;
- TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
- : PassWrapper(pass) {}
-
- StringRef getArgument() const override {
- return "test-bit-width-constrained-vector-linearize";
- }
- StringRef getDescription() const override {
- return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
- "in inner-most dimension's bit width.";
- }
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<vector::VectorDialect>();
- }
-
- Option<unsigned> targetVectorBitwidth{
- *this, "target-vector-bitwidth",
- llvm::cl::desc(
- "Minimum vector bitwidth to enable the flattening transformation"),
- llvm::cl::init(std::numeric_limits<unsigned>::max())};
- void runOnOperation() override {
- auto *context = &getContext();
-
- TypeConverter typeConverter;
- RewritePatternSet patterns(context);
- ConversionTarget target(*context);
-
- populateWithBitWidthConstraints(typeConverter, target,
- targetVectorBitwidth);
-
- vector::populateVectorLinearizeBasePatterns(typeConverter, target,
- patterns);
-
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
- patterns);
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
-};
-
-struct TestVectorLinearize final
- : public PassWrapper<TestVectorLinearize, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
-
- TestVectorLinearize() = default;
-
- StringRef getArgument() const override { return "test-vector-linearize"; }
- StringRef getDescription() const override {
- return "Linearizes ND vectors for N >= 2 into 1D vectors";
- }
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<vector::VectorDialect, arith::ArithDialect>();
- }
-
- void runOnOperation() override {
- MLIRContext &context = getContext();
- TypeConverter converter;
- RewritePatternSet patterns(&context);
- ConversionTarget target(context);
-
- vector::populateForVectorLinearize(converter, target);
-
- vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
- patterns);
- mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
- converter, patterns, target);
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
-};
-
struct TestEliminateVectorMasks
: public PassWrapper<TestEliminateVectorMasks,
OperationPass<func::FuncOp>> {
@@ -1062,10 +907,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorEmulateMaskedLoadStore>();
- PassRegistration<TestVectorLinearize>();
-
- PassRegistration<TestVectorBitWidthLinearize>();
-
PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2e08ae6f37980..f52f36107e301 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -155,6 +155,7 @@ void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
void registerTestPassStateExtensionCommunication();
void registerTestVectorLowerings();
+void registerTestVectorLinearize();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestVulkanRunnerPipeline();
void registerTestWrittenToPass();
@@ -300,6 +301,7 @@ void registerTestPasses() {
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestPassStateExtensionCommunication();
mlir::test::registerTestVectorLowerings();
+ mlir::test::registerTestVectorLinearize();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestVulkanRunnerPipeline();
mlir::test::registerTestWrittenToPass();