[Mlir-commits] [mlir] [mlir][vector] Refactor vector linearization patterns (PR #142685)
James Newling
llvmlistbot at llvm.org
Tue Jun 3 15:53:18 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/142685
>From 2e6562533eb69815a496189ce6854e6440adcdc4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 15:13:44 -0700
Subject: [PATCH 1/2] nfc changes to linearization
---
.../Vector/Transforms/VectorLinearize.h | 252 +++++++++
.../Vector/Transforms/VectorRewritePatterns.h | 33 --
.../Vector/Transforms/VectorLinearize.cpp | 508 +++++++++---------
.../linearize-subject-to-bitwidth.mlir | 0
.../Vector/{ => linearize}/linearize.mlir | 0
mlir/test/lib/Dialect/Vector/CMakeLists.txt | 1 +
.../Dialect/Vector/TestVectorLinearize.cpp | 185 +++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 159 ------
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
9 files changed, 696 insertions(+), 444 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
rename mlir/test/Dialect/Vector/{ => linearize}/linearize-subject-to-bitwidth.mlir (100%)
rename mlir/test/Dialect/Vector/{ => linearize}/linearize.mlir (100%)
create mode 100644 mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..cd62de640d088
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,252 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` for vector linearization by registering source
+/// and target materializations that use vector.shape_cast to convert values
+/// between their original (possibly rank>1) vector type and the linearized
+/// 1D vector type.
+///
+/// NOTE(review): whether this also registers the rank>1 -> rank-1 type
+/// conversion itself is not visible from this header; confirm against the
+/// implementation.
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// Initialize `conversionTarget` and `patterns` for linearization. Here
+/// linearization means converting a single operation with one or more vector
+/// operands/results of rank>1 into a new single operation whose vector
+/// operands and results are all of rank<=1.
+///
+/// This function marks in `conversionTarget` the set of operations that are
+/// illegal and must therefore be converted to a linearized form. It also
+/// populates `patterns` with the patterns that convert those illegal
+/// operations, together with their priority/benefit.
+///
+/// Note: a user can extend the set of legal operations, for example when
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+///
+/// Further note: linearization uses a dialect conversion design so that
+/// generic structural type conversions can easily be reused to linearize
+/// scf/cf/func operations.
+void populateForFullVectorLinearize(const TypeConverter &,
+ ConversionTarget &conversionTarget,
+ RewritePatternSet &patterns);
+
+/// The set of patterns available for linearization. Each enumerator is an
+/// index used by `VectorLinearizePatterns` to enable/disable a pattern and to
+/// set its benefit.
+enum class LinearizePattern {
+
+ /// This pattern converts a constant (or poison) vector of rank>1 into a
+ /// 1D vector followed by a shape_cast.
+ ///
+ /// BEFORE
+ /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+ ///
+ /// AFTER
+ /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+ /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+ LinearizeConstantLike = 0,
+
+ /// This pattern converts an operation with the Vectorizable trait (for
+ /// example math.sin) on rank>1 vectors into the same operation on rank-1
+ /// vectors, with shape_casts on its operands and results.
+ ///
+ /// BEFORE
+ /// %2 = math.sin %arg0 : vector<2x2xf32>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+ /// %1 = math.sin %0 : vector<4xf32>
+ /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+ LinearizeVectorizable,
+
+ /// This pattern converts a vector.bitcast on rank>1 vectors into a
+ /// vector.bitcast on rank-1 vectors.
+ ///
+ /// BEFORE
+ /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+ ///
+ /// AFTER
+ /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+ /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+ /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+ LinearizeVectorBitCast,
+
+ /// This pattern converts a vector.create_mask into a 1D create_mask followed
+ /// by a shape_cast. It currently only supports 2D masks with a unit outer
+ /// dimension.
+ ///
+ /// BEFORE
+ /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+ ///
+ /// AFTER
+ /// [...]
+ /// %mask_1d = vector.create_mask %size : vector<4xi1>
+ /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+ ///
+ /// where `%size` is computed from `%arg0` and `%arg1`; the intent is that it
+ /// equals `%arg1` when `%arg0 > 0` and 0 otherwise.
+ LinearizeVectorCreateMask,
+
+ /// This pattern converts a ShuffleOp that works on nD (n > 1) vectors into
+ /// a ShuffleOp that works on linearized vectors.
+ ///
+ /// BEFORE
+ /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+ ///
+ /// AFTER
+ /// %v1_1d = vector.shape_cast %v1_3d : [...]
+ /// %v2_1d = vector.shape_cast %v2_3d : [...]
+ /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+ /// %shuffle_3d = vector.shape_cast %shuffle_1d : [...]
+ ///
+ /// where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+ LinearizeVectorShuffle,
+
+ /// This pattern converts a vector.splat producing a rank>1 vector into a
+ /// splat of a rank-1 vector followed by a shape_cast.
+ ///
+ /// BEFORE
+ /// %1 = vector.splat %value : vector<4x4xf32>
+ ///
+ /// AFTER
+ /// %0 = vector.splat %value : vector<16xf32>
+ /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+ LinearizeVectorSplat,
+
+ /// This pattern converts a vector.extract into a vector.shuffle whose
+ /// operand is a rank-1 (linearized) vector.
+ ///
+ /// BEFORE
+ /// %extract = vector.extract %src [ position ]
+ ///
+ /// AFTER
+ /// %src_1d = vector.shape_cast %src : [...]
+ /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+ /// %out_nd = vector.shape_cast %out_1d : [...]
+ ///
+ /// where `shuffle_indices` is computed from `position` of the original
+ /// extract.
+ VectorExtractToRankOneShuffle,
+
+ /// This pattern converts a vector.extract_strided_slice operation into a
+ /// vector.shuffle operation that has a rank-1 (linearized) operand and
+ /// result.
+ ///
+ /// BEFORE
+ /// %out_nd = vector.extract_strided_slice %source_nd
+ /// { offsets = [..], strides = [..], sizes = [..] }
+ ///
+ /// AFTER
+ /// %source_1d = vector.shape_cast %source_nd [...]
+ /// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+ /// %out_nd = vector.shape_cast %out_1d [...]
+ ///
+ /// where `shuffle_indices_1d` is computed using the offsets and sizes of the
+ /// original vector.extract_strided_slice operation.
+ VectorExtractStridedSliceToRankOneShuffle,
+
+ /// This pattern converts a vector.insert into a vector.shuffle whose
+ /// operands are rank-1 (linearized) vectors.
+ ///
+ /// BEFORE
+ /// %insert = vector.insert %src %dst [ position ]
+ ///
+ /// AFTER
+ /// %src_1d = vector.shape_cast %src : [...]
+ /// %dst_1d = vector.shape_cast %dst : [...]
+ /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+ /// %out_nd = vector.shape_cast %out_1d : [...]
+ ///
+ /// where `shuffle_indices` is computed from `position`.
+ VectorInsertToRankOneShuffle,
+
+ /// This pattern converts a vector.insert_strided_slice operation into a
+ /// vector.shuffle operation that has rank-1 (linearized) operands and result.
+ ///
+ /// BEFORE
+ /// %0 = vector.insert_strided_slice %to_store, %into
+ /// {offsets = [1, 0, 0, 0], strides = [1, 1]}
+ /// : vector<2x2xi8> into vector<2x1x3x2xi8>
+ /// AFTER
+ /// %to_store_1d
+ /// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+ /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+ /// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+ /// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+ ///
+ /// where shuffle_indices_1d in this case is
+ /// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+ /// Positions 6-9 (the slots written by the insertion) select shuffle
+ /// indices 12-15, i.e. the 4 elements of to_store_1d; every other position
+ /// keeps the corresponding element of into_1d.
+ VectorInsertStridedSliceToRankOneShuffle,
+
+ /// The number of patterns in this enum. This is not a valid pattern id.
+ N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority (benefit). All setters
+/// return `*this` so that calls can be chained.
+struct VectorLinearizePatterns {
+ /// By default all patterns are enabled and have benefit 1.
+ VectorLinearizePatterns() {
+ enabled.fill(true);
+ benefits.fill(PatternBenefit(1));
+ }
+
+ /// Add the patterns enabled for the conversion to `patterns`.
+ void addToPatternSet(const TypeConverter &,
+ RewritePatternSet &patterns) const;
+
+ /// Enable pattern `id`, or disable it when `e` is false.
+ VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+ enabled[toIndex(id)] = e;
+ return *this;
+ }
+
+ /// Enable all patterns, or disable all of them when `e` is false.
+ VectorLinearizePatterns &enableAll(bool e = true) {
+ enabled.fill(e);
+ return *this;
+ }
+
+ /// Return true if pattern `id` is enabled.
+ bool isEnabled(LinearizePattern id) const { return enabled[toIndex(id)]; }
+
+ /// Return the benefit with which pattern `id` is added.
+ PatternBenefit getBenefit(LinearizePattern id) const {
+ return benefits[toIndex(id)];
+ }
+
+ /// Set the benefit of pattern `id` to `benefit`.
+ VectorLinearizePatterns &setBenefit(LinearizePattern id,
+ PatternBenefit benefit) {
+ benefits[toIndex(id)] = benefit;
+ return *this;
+ }
+
+ /// Increase the benefit of pattern `id` by `inc`. The resulting benefit must
+ /// still be representable by PatternBenefit.
+ VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+ unsigned inc = 1) {
+ benefits[toIndex(id)] = getBenefit(id).getBenefit() + inc;
+ return *this;
+ }
+
+private:
+ /// Number of valid pattern ids (the `N` enumerator is excluded).
+ static constexpr unsigned kNumPatterns =
+ static_cast<unsigned>(LinearizePattern::N);
+
+ std::array<bool, kNumPatterns> enabled;
+ std::array<PatternBenefit, kNumPatterns> benefits;
+
+ /// Convert `id` to an index into `enabled`/`benefits`, asserting that it
+ /// names a real pattern (so `LinearizePattern::N` is rejected).
+ static unsigned toIndex(LinearizePattern id) {
+ unsigned idInt = static_cast<unsigned>(id);
+ assert(idInt < kNumPatterns && "invalid linearization pattern id");
+ return idInt;
+ }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// `small` is aligned with the trailing dimensions of `large` (missing leading
+/// dimensions of `small` are treated as 1), and `offsets` is aligned with the
+/// leading dimensions of `large` (missing trailing offsets are treated as 0).
+/// The rank of `large` must be at least the rank of `small` and at least the
+/// number of offsets.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
- ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
- const ConversionTarget &,
- RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
- const ConversionTarget &,
- RewritePatternSet &patterns);
-
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..2367e6c99a5f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
@@ -52,7 +53,7 @@ struct LinearizeConstantLike final
using OpTraitConversionPattern::OpTraitConversionPattern;
LinearizeConstantLike(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -94,7 +95,7 @@ struct LinearizeVectorizable final
public:
LinearizeVectorizable(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,17 +110,7 @@ struct LinearizeVectorizable final
}
};
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
- static_assert(
- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
- "expected vector.extract_strided_slice or vector.insert_strided_slice");
- ArrayAttr strides = op.getStrides();
- return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
+/// Convert an array of attributes into a vector of integers.
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
if (!attrs)
return failure();
@@ -135,89 +126,12 @@ static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
return ints;
}
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
- ArrayRef<int64_t> small, ArrayRef<int64_t> large,
- ArrayRef<int64_t> offsets) {
-
- // Example of alignment between, `large`, `small` and `offsets`:
- // large = 4, 5, 6, 7, 8
- // small = 1, 6, 7, 8
- // offsets = 2, 3, 0
- //
- // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
- assert((large.size() >= small.size()) &&
- "rank of 'large' cannot be lower than rank of 'small'");
- assert((large.size() >= offsets.size()) &&
- "rank of 'large' cannot be lower than the number of offsets");
- unsigned delta = large.size() - small.size();
- unsigned nOffsets = offsets.size();
- auto getSmall = [&](int64_t i) -> int64_t {
- return i >= delta ? small[i - delta] : 1;
- };
- auto getOffset = [&](int64_t i) -> int64_t {
- return i < nOffsets ? offsets[i] : 0;
- };
-
- // Using 2 vectors of indices, at each iteration populate the updated set of
- // indices based on the old set of indices, and the size of the small vector
- // in the current iteration.
- SmallVector<int64_t> indices{0};
- int64_t stride = 1;
- for (int i = large.size() - 1; i >= 0; --i) {
- int64_t currentSize = indices.size();
- int64_t smallSize = getSmall(i);
- int64_t nextSize = currentSize * smallSize;
- SmallVector<int64_t> nextIndices(nextSize);
- int64_t *base = nextIndices.begin();
- int64_t offset = getOffset(i) * stride;
- for (int j = 0; j < smallSize; ++j) {
- for (int k = 0; k < currentSize; ++k) {
- base[k] = indices[k] + offset;
- }
- offset += stride;
- base += currentSize;
- }
- stride *= large[i];
- indices = std::move(nextIndices);
- }
- return indices;
-}
-
-/// This pattern converts a vector.extract_strided_slice operation into a
-/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
-///
-/// For example, the following:
-///
-/// ```
-/// vector.extract_strided_slice %source
-/// { offsets = [..], strides = [..], sizes = [..] }
-/// ```
-///
-/// is converted to :
-/// ```
-/// %source_1d = vector.shape_cast %source
-/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-/// ```
-///
-/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
-/// vector.extract_strided_slice operation.
-struct LinearizeVectorExtractStridedSlice final
- : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+struct VectorExtractStridedSliceToRankOneShuffle final
+ : public OpConversionPattern<vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
- MLIRContext *context,
- PatternBenefit benefit = 1)
+ VectorExtractStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -231,7 +145,7 @@ struct LinearizeVectorExtractStridedSlice final
// Expect a legalization failure if the strides are not all 1 (if ever the
// verifier for extract_strided_slice allows non-1 strides).
- if (!stridesAllOne(extractStridedSliceOp)) {
+ if (extractStridedSliceOp.hasNonUnitStrides()) {
return rewriter.notifyMatchFailure(
extractStridedSliceOp,
"extract_strided_slice with strides != 1 not supported");
@@ -249,7 +163,7 @@ struct LinearizeVectorExtractStridedSlice final
ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
- SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+ SmallVector<int64_t> indices = vector::getStridedSliceInsertionIndices(
outputShape, inputShape, offsets.value());
Value srcVector = adaptor.getVector();
@@ -259,36 +173,24 @@ struct LinearizeVectorExtractStridedSlice final
}
};
-/// This pattern converts a vector.insert_strided_slice operation into a
-/// vector.shuffle operation that has rank-1 (linearized) operands and result.
-///
-/// For example, the following:
-/// ```
-/// %0 = vector.insert_strided_slice %to_store, %into
-/// {offsets = [1, 0, 0, 0], strides = [1, 1]}
-/// : vector<2x2xi8> into vector<2x1x3x2xi8>
-/// ```
-///
-/// is converted to
-/// ```
-/// %to_store_1d
-/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
-/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
-/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
-/// ```
-///
-/// where shuffle_indices_1d in this case is
-/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
-/// ^^^^^^^^^^^^^^
-/// to_store_1d
-///
-struct LinearizeVectorInsertStridedSlice final
- : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+/// Return `v` as a rank-1 vector value. `v` must be a vector of rank 0 or 1.
+/// A rank-1 vector is returned unchanged; a rank-0 vector (which holds a
+/// single element) is reshaped with a vector.shape_cast to vector<1xT>.
+static Value asRankOne(ConversionPatternRewriter &rewriter, Value v) {
+ auto srcTy = dyn_cast<VectorType>(v.getType());
+ assert(srcTy && "expected vector type");
+ assert(srcTy.getRank() <= 1 && "expected rank-0 or rank-1 type");
+ if (srcTy.getRank() == 1)
+ return v;
+ auto oneElementTy = VectorType::get({1}, srcTy.getElementType());
+ Value reshaped =
+ rewriter.create<vector::ShapeCastOp>(v.getLoc(), oneElementTy, v);
+ return reshaped;
+}
+
+struct VectorInsertStridedSliceToRankOneShuffle final
+ : public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
- MLIRContext *context,
- PatternBenefit benefit = 1)
+ VectorInsertStridedSliceToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -298,7 +200,7 @@ struct LinearizeVectorInsertStridedSlice final
// Expect a legalization failure if the strides are not all 1 (if ever the
// verifier for insert_strided_slice allows non-1 strides).
- if (!stridesAllOne(insertStridedSliceOp)) {
+ if (insertStridedSliceOp.hasNonUnitStrides()) {
return rewriter.notifyMatchFailure(
insertStridedSliceOp,
"insert_strided_slice with strides != 1 not supported");
@@ -317,7 +219,7 @@ struct LinearizeVectorInsertStridedSlice final
return rewriter.notifyMatchFailure(insertStridedSliceOp,
"failed to get integer offsets");
}
- SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+ SmallVector<int64_t> sliceIndices = vector::getStridedSliceInsertionIndices(
inputShape, outputShape, offsets.value());
SmallVector<int64_t> indices(nOutputElements);
@@ -326,7 +228,7 @@ struct LinearizeVectorInsertStridedSlice final
indices[sliceIndex] = index + nOutputElements;
}
- Value flatToStore = adaptor.getValueToStore();
+ Value flatToStore = asRankOne(rewriter, adaptor.getValueToStore());
Value flatDest = adaptor.getDest();
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
flatDest.getType(), flatDest,
@@ -335,22 +237,11 @@ struct LinearizeVectorInsertStridedSlice final
}
};
-/// This pattern converts the ShuffleOp that works on nD (n > 1)
-/// vectors to a ShuffleOp that works on linearized vectors.
-/// Following,
-/// vector.shuffle %v1, %v2 [ shuffle_indices ]
-/// is converted to :
-/// %v1_1d = vector.shape_cast %v1
-/// %v2_1d = vector.shape_cast %v2
-/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
-/// of the original shuffle operation.
struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -360,8 +251,8 @@ struct LinearizeVectorShuffle final
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
assert(dstType && "vector type destination expected.");
- Value vec1 = adaptor.getV1();
- Value vec2 = adaptor.getV2();
+ Value vec1 = asRankOne(rewriter, adaptor.getV1());
+ Value vec2 = asRankOne(rewriter, adaptor.getV2());
int shuffleSliceLen = 1;
int rank = shuffleOp.getV1().getType().getRank();
@@ -395,20 +286,11 @@ struct LinearizeVectorShuffle final
}
};
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-/// vector.extract %source [ position ]
-/// is converted to :
-/// %source_1d = vector.shape_cast %source
-/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
-struct LinearizeVectorExtract final
+struct VectorExtractToRankOneShuffle final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtract(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ VectorExtractToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
@@ -436,30 +318,21 @@ struct LinearizeVectorExtract final
linearizedOffset += offsets[i] * size;
}
+ Value v0 = asRankOne(rewriter, adaptor.getVector());
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
- rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
+ rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, v0, v0,
+ indices);
return success();
}
};
-/// This pattern converts the InsertOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-/// vector.insert %source %destination [ position ]
-/// is converted to :
-/// %source_1d = vector.shape_cast %source
-/// %destination_1d = vector.shape_cast %destination
-/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
-/// ] %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original insert.
-struct LinearizeVectorInsert final
+struct VectorInsertToRankOneShuffle final
: public OpConversionPattern<vector::InsertOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorInsert(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ VectorInsertToRankOneShuffle(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
@@ -508,25 +381,18 @@ struct LinearizeVectorInsert final
// [offset+srcNumElements, end)
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+ insertOp, dstTy, adaptor.getDest(),
+ asRankOne(rewriter, adaptor.getValueToStore()), indices);
return success();
}
};
-/// This pattern converts the BitCastOp that works on nD (n > 1)
-/// vectors to a BitCastOp that works on linearized vectors.
-/// Following,
-/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
-/// is converted to :
-/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
-/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
-/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorBitCast(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
@@ -535,22 +401,16 @@ struct LinearizeVectorBitCast final
assert(resType && "expected 1-D vector type");
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
- return mlir::success();
+ return success();
}
};
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-/// vector.splat %value : vector<4x4xf32>
-/// is converted to:
-/// %out_1d = vector.splat %value : vector<16xf32>
-/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
struct LinearizeVectorSplat final
: public OpConversionPattern<vector::SplatOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
+ PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -581,7 +441,7 @@ struct LinearizeVectorCreateMask final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1)
+ MLIRContext *context, PatternBenefit benefit)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
@@ -607,17 +467,16 @@ struct LinearizeVectorCreateMask final
// The result of the comparison is then multiplied with
// the second operand of create_mask to get the 1D mask.
auto firstOperand = adaptor.getOperands().front();
- auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
- auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
- loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
- auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+ auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
loc, rewriter.getIndexType(), isNonZero);
auto secondOperand = adaptor.getOperands().back();
- auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
+ auto maskSize = rewriter.createOrFold<arith::AndIOp>(
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
- auto newMask =
- rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+ auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
rewriter.replaceOp(createMaskOp, newMask);
return success();
}
@@ -625,104 +484,249 @@ struct LinearizeVectorCreateMask final
} // namespace
-/// This method defines the set of operations that are linearizable, and hence
-/// that are considered illegal for the conversion target.
-static bool isLinearizable(Operation *op) {
-
- // Only ops that are in the vector dialect, are ConstantLike, or
- // are Vectorizable might be linearized currently.
- StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
- StringRef opDialect = op->getDialect()->getNamespace();
- bool supported = (opDialect == vectorDialect) ||
- op->hasTrait<OpTrait::ConstantLike>() ||
- op->hasTrait<OpTrait::Vectorizable>();
- if (!supported)
- return false;
-
+/// Return true if `op` is an insert, extract, insert_strided_slice, or
+/// extract_strided_slice operation that operates on scalable vectors.
+/// Otherwise return false.
+///
+/// Such ops are excluded from linearization: the patterns for them produce
+/// rank-1 ops (e.g. vector.shuffle) that do not (fully) support scalable
+/// vectors. For vector.extract the source type is inspected; for the other
+/// three ops the result type is.
+static bool isScalableExtractOrInsertOrStrided(Operation *op) {
return TypeSwitch<Operation *, bool>(op)
- // As type legalization is done with vector.shape_cast, shape_cast
- // itself cannot be linearized (will create new shape_casts to linearize
- // ad infinitum).
- .Case<vector::ShapeCastOp>([&](auto) { return false; })
- // The operations
- // - vector.extract_strided_slice
- // - vector.extract
- // - vector.insert_strided_slice
- // - vector.insert
- // are linearized to a rank-1 vector.shuffle by the current patterns.
- // vector.shuffle only supports fixed size vectors, so it is impossible to
- // use this approach to linearize these ops if they operate on scalable
- // vectors.
.Case<vector::ExtractStridedSliceOp>(
[&](vector::ExtractStridedSliceOp extractOp) {
- return !extractOp.getType().isScalable();
+ return extractOp.getType().isScalable();
})
.Case<vector::InsertStridedSliceOp>(
[&](vector::InsertStridedSliceOp insertOp) {
- return !insertOp.getType().isScalable();
+ return insertOp.getType().isScalable();
})
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
- return !insertOp.getType().isScalable();
+ return insertOp.getType().isScalable();
})
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
- return !extractOp.getSourceVectorType().isScalable();
+ return extractOp.getSourceVectorType().isScalable();
})
- .Default([&](auto) { return true; });
+ .Default([&](auto) { return false; });
}
-void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
- ConversionTarget &target) {
+/// Enumerate, in row-major order, the linear indices of `large` that are
+/// written when a vector of shape `small` is inserted at position `offsets`.
+SmallVector<int64_t>
+vector::getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+ ArrayRef<int64_t> large,
+ ArrayRef<int64_t> offsets) {
+
+ // Example of alignment between `large`, `small` and `offsets`:
+ // large = 4, 5, 6, 7, 8
+ // small = 1, 6, 7, 8
+ // offsets = 2, 3, 0
+ //
+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+ assert((large.size() >= small.size()) &&
+ "rank of 'large' cannot be lower than rank of 'small'");
+ assert((large.size() >= offsets.size()) &&
+ "rank of 'large' cannot be lower than the number of offsets");
+
+ // Keep all dimension arithmetic signed: the dimension index below counts
+ // down to -1, and mixing it with unsigned sizes would need implicit
+ // narrowing/sign conversions.
+ const int64_t rank = static_cast<int64_t>(large.size());
+ const int64_t delta = rank - static_cast<int64_t>(small.size());
+ const int64_t nOffsets = static_cast<int64_t>(offsets.size());
+ auto getSmall = [&](int64_t i) -> int64_t {
+ return i >= delta ? small[i - delta] : 1;
+ };
+ auto getOffset = [&](int64_t i) -> int64_t {
+ return i < nOffsets ? offsets[i] : 0;
+ };
+
+ // Process dimensions from innermost to outermost. `indices` holds the
+ // linear positions covered by the dimensions processed so far; each step
+ // replicates them once per element of `small` in dimension `i`, shifted by
+ // that element's linear offset within `large`.
+ SmallVector<int64_t> indices{0};
+ int64_t stride = 1;
+ for (int64_t i = rank - 1; i >= 0; --i) {
+ const int64_t currentSize = indices.size();
+ const int64_t smallSize = getSmall(i);
+ SmallVector<int64_t> nextIndices(currentSize * smallSize);
+ int64_t *base = nextIndices.begin();
+ int64_t offset = getOffset(i) * stride;
+ for (int64_t j = 0; j < smallSize; ++j) {
+ for (int64_t k = 0; k < currentSize; ++k)
+ base[k] = indices[k] + offset;
+ offset += stride;
+ base += currentSize;
+ }
+ stride *= large[i];
+ indices = std::move(nextIndices);
+ }
+ return indices;
+}
+
+/// Register on `typeConverter` a conversion from each linearizable vector type
+/// to the 1-D vector type with the same element count, element type and
+/// scalability, plus source/target materializations that bridge the two with
+/// vector.shape_cast.
+void vector::initializeForVectorLinearize(TypeConverter &typeConverter) {
auto convertType = [](Type type) -> std::optional<Type> {
VectorType vectorType = dyn_cast<VectorType>(type);
- if (!vectorType || !isLinearizableVector(vectorType))
+
+ if (!vectorType || !vector::isLinearizableVector(vectorType))
return type;
VectorType linearizedType =
VectorType::get(vectorType.getNumElements(),
vectorType.getElementType(), vectorType.isScalable());
+
return linearizedType;
};
typeConverter.addConversion(convertType);
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
+ // Only single-value casts between two vector types are materialized.
if (inputs.size() != 1)
return nullptr;

Value value = inputs.front();
if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
return nullptr;

return builder.create<vector::ShapeCastOp>(loc, type, value);
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+}
+
+void vector::populateForFullVectorLinearize(const TypeConverter &typeConverter,
+ ConversionTarget &target,
+ RewritePatternSet &patterns) {
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if (!isLinearizable(op))
+ // Only ops that are in the vector dialect, are ConstantLike, or
+ // are Vectorizable might be linearized currently. The namespace is read
+ // from the op name rather than via `op->getDialect()`, which is null for
+ // ops whose dialect is not loaded (e.g. unregistered ops).
+ StringLiteral vectorDialect =
+ vector::VectorDialect::getDialectNamespace();
+ StringRef opDialect = op->getName().getDialectNamespace();
+ bool supported = (opDialect == vectorDialect) ||
+ op->hasTrait<OpTrait::ConstantLike>() ||
+ op->hasTrait<OpTrait::Vectorizable>();
+ if (!supported)
+ return true;
+
+ // As type legalization is done with vector.shape_cast, shape_cast
+ // itself cannot be linearized (doing so would create new shape_casts to
+ // linearize ad infinitum).
+ if (isa<vector::ShapeCastOp>(op))
+ return true;
+
+ // The operations extract_strided_slice, extract, insert_strided_slice,
+ // and insert are linearized to rank-1 operations that do not fully
+ // support scalable vectors, so it is not generally possible to
+ // linearize these ops if they operate on scalable vectors.
+ if (isScalableExtractOrInsertOrStrided(op))
return true;
+
// This will return true if, for all operand and result types `t`,
// convertType(t) = t. This is true if there are no rank>=2 vectors.
return typeConverter.isLegal(op);
});
-}
-void mlir::vector::populateVectorLinearizeBasePatterns(
- const TypeConverter &typeConverter, const ConversionTarget &target,
- RewritePatternSet &patterns) {
- patterns
- .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
- typeConverter, patterns.getContext());
+ VectorLinearizePatterns linearizePatterns;
+
+ // Mark extract_strided_slice, insert_strided_slice, extract with source
+ // rank > 1, and insert with result rank > 1 as illegal, as they must be
+ // converted to shuffle or rank-1 extract/insert.
+ //
+ // Note that the order of the calls to `markUnknownOpDynamicallyLegal`
+ // is important: the legality rule added here takes precedence over the
+ // generic one preceding it which marked these ops as legal.
+ target.markUnknownOpDynamicallyLegal(
+ [](Operation *op) -> std::optional<bool> {
+ bool isStrided =
+ isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp>(
+ op);
+
+ bool isHighRankExtractOrInsert = [&]() {
+ if (auto extractOp = dyn_cast<vector::ExtractOp>(op))
+ return extractOp.getSourceVectorType().getRank() > 1;
+ if (auto insertOp = dyn_cast<vector::InsertOp>(op))
+ return insertOp.getType().getRank() > 1;
+ return false;
+ }();
+
+ bool isScalable = isScalableExtractOrInsertOrStrided(op);
+
+ if ((isStrided || isHighRankExtractOrInsert) && !isScalable)
+ return false;
+ return std::nullopt;
+ });
+
+ // Ensure that the benefit of patterns targeting shuffle is higher than
+ // the benefit of patterns targeting rank-1 strided slice operations. This
+ // will ensure that patterns for converting to rank-1 shuffle are run first.
+ linearizePatterns
+ .incrementBenefit(
+ LinearizePattern::VectorExtractStridedSliceToRankOneShuffle)
+ .incrementBenefit(
+ LinearizePattern::VectorInsertStridedSliceToRankOneShuffle)
+ .incrementBenefit(LinearizePattern::VectorExtractToRankOneShuffle)
+ .incrementBenefit(LinearizePattern::VectorInsertToRankOneShuffle);
+
+ linearizePatterns.addToPatternSet(typeConverter, patterns);
}
-void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- const TypeConverter &typeConverter, const ConversionTarget &target,
- RewritePatternSet &patterns) {
- patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
- LinearizeVectorInsertStridedSlice>(typeConverter,
- patterns.getContext());
+/// Add every pattern that is enabled in this configuration (`isEnabled`) to
+/// `patterns`, each with the benefit configured for it (`getBenefit`).
+void vector::VectorLinearizePatterns::addToPatternSet(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns) const {
+
+ MLIRContext *context = patterns.getContext();
+
+ // Patterns for ops other than extract/insert and their strided variants.
+ if (isEnabled(LinearizePattern::LinearizeConstantLike))
+ patterns.add<LinearizeConstantLike>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeConstantLike));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorizable))
+ patterns.add<LinearizeVectorizable>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorizable));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorBitCast))
+ patterns.add<LinearizeVectorBitCast>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorBitCast));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorCreateMask))
+ patterns.add<LinearizeVectorCreateMask>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorCreateMask));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorShuffle))
+ patterns.add<LinearizeVectorShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorShuffle));
+
+ if (isEnabled(LinearizePattern::LinearizeVectorSplat))
+ patterns.add<LinearizeVectorSplat>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::LinearizeVectorSplat));
+
+ // ------------------------ //
+ // Extract related patterns //
+ // ------------------------ //
+ if (isEnabled(LinearizePattern::VectorExtractToRankOneShuffle))
+ patterns.add<VectorExtractToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorExtractToRankOneShuffle));
+
+ if (isEnabled(LinearizePattern::VectorExtractStridedSliceToRankOneShuffle))
+ patterns.add<VectorExtractStridedSliceToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(
+ LinearizePattern::VectorExtractStridedSliceToRankOneShuffle));
+
+ // ------------------------ //
+ // Insert related patterns //
+ // ------------------------ //
+ if (isEnabled(LinearizePattern::VectorInsertToRankOneShuffle))
+ patterns.add<VectorInsertToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorInsertToRankOneShuffle));
+
+ if (isEnabled(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle))
+ patterns.add<VectorInsertStridedSliceToRankOneShuffle>(
+ typeConverter, context,
+ getBenefit(LinearizePattern::VectorInsertStridedSliceToRankOneShuffle));
}
diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/linearize.mlir
rename to mlir/test/Dialect/Vector/linearize/linearize.mlir
diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
index e16937029ac0e..1ce069599af43 100644
--- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRVectorTestPasses
TestVectorTransforms.cpp
+ TestVectorLinearize.cpp
EXCLUDE_FROM_LIBMLIR
)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
new file mode 100644
index 0000000000000..67179c9f98e9b
--- /dev/null
+++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
@@ -0,0 +1,185 @@
+//===- TestVectorLinearize.cpp - Test Vector linearization ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+#include <limits>
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Test pass that linearizes all linearizable rank>1 vectors using the full
+/// set of vector linearization patterns.
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  TestVectorLinearize() = default;
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect, arith::ArithDialect, math::MathDialect>();
+  }
+
+  /// Run full linearization. The SCF structural type conversions are added
+  /// as well, so that scf ops whose types the converter changes are also
+  /// converted.
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    initializeForVectorLinearize(converter);
+    populateForFullVectorLinearize(converter, target, patterns);
+
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+/// Test pass like `TestVectorLinearize`, except that ops whose selected
+/// operand/result vector types have a large inner-most dimension (in bits)
+/// are left unlinearized.
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-constrained-vector-linearize";
+  }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "on inner-most dimension's bit width. If the inner-most dimension "
+           "exceeds a threshold, the op is not linearized.";
+  }
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The create_mask linearization pattern creates arith ops, so the arith
+    // dialect must be loaded even when the input IR contains no arith ops.
+    registry.insert<VectorDialect, arith::ArithDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    populateWithBitWidthConstraints(typeConverter, target, patterns,
+                                    targetVectorBitwidth);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+
+private:
+  /// If `type` is VectorType with trailing dimension of (bit) size greater than
+  /// or equal to `targetBitWidth`, its defining op is considered legal.
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                              unsigned targetBitWidth) {
+    VectorType vecType = dyn_cast<VectorType>(type);
+
+    // Not linearizable for reasons other than what this function checks.
+    if (!vecType || vecType.getRank() == 0)
+      return false;
+
+    // The width of the type 'index' is unbounded (and therefore potentially
+    // above the target width).
+    if (vecType.getElementType().isIndex())
+      return true;
+
+    unsigned finalDimSize = vecType.getShape().back();
+    unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+    unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+    return trailingVecDimBitWidth >= targetBitWidth;
+  }
+
+  /// Return true if any (type, bit-width bound) pair selected for `op` by
+  /// `getTypeBitWidthBoundPairs` has a large inner-most dimension.
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                              unsigned targetBitWidth) {
+    SmallVector<std::pair<Type, unsigned>> toCheck =
+        getTypeBitWidthBoundPairs(op, targetBitWidth);
+    return llvm::any_of(toCheck, [](std::pair<Type, unsigned> typeWidth) {
+      return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
+                                                         typeWidth.second);
+    });
+  }
+
+  /// Set up `typeConverter`, `target` and `patterns` for full linearization,
+  /// then additionally mark as legal every op with a large inner-most
+  /// dimension on its selected operands/results.
+  static void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+                                              ConversionTarget &target,
+                                              RewritePatternSet &patterns,
+                                              unsigned targetBitWidth) {
+    // The general purpose legality rules must be registered first: rules
+    // added later take precedence over earlier ones.
+    initializeForVectorLinearize(typeConverter);
+    populateForFullVectorLinearize(typeConverter, target, patterns);
+
+    // Extend the set of legal ops to include those with large inner-most
+    // dimensions on selected operands/results.
+    target.markUnknownOpDynamicallyLegal(
+        [=](Operation *op) -> std::optional<bool> {
+          if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth))
+            return true;
+          return std::nullopt;
+        });
+  }
+
+  /// Get the set of operand/result types to check for sufficiently
+  /// small inner-most dimension size.
+  static SmallVector<std::pair<Type, unsigned>>
+  getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
+    // For vector.insert only the inserted value is checked, against a bound
+    // one larger than `targetBitWidth` (unless that would overflow).
+    if (auto insertOp = dyn_cast<InsertOp>(op)) {
+      unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                       ? targetBitWidth + 1
+                       : targetBitWidth;
+      return {{insertOp.getValueToStoreType(), w}};
+    }
+
+    auto resultTypes = op->getResultTypes();
+    SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+    resultsWithBitWidth.reserve(resultTypes.size());
+    for (Type type : resultTypes)
+      resultsWithBitWidth.push_back({type, targetBitWidth});
+    return resultsWithBitWidth;
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+/// Make the `test-vector-linearize` and
+/// `test-bit-width-constrained-vector-linearize` passes available to the
+/// pass registry.
+void registerTestVectorLinearize() {
+  PassRegistration<TestVectorLinearize>();
+  PassRegistration<TestVectorBitWidthLinearize>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..5c75d32c22236 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -837,160 +836,6 @@ struct TestVectorEmulateMaskedLoadStore final
}
};
-/// Get the set of operand/result types to check for sufficiently
-/// small inner-most dimension size.
-static SmallVector<std::pair<Type, unsigned>>
-getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
-
- if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
- unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
- ? targetBitWidth + 1
- : targetBitWidth;
- return {{insertOp.getValueToStoreType(), w}};
- }
-
- auto resultTypes = op->getResultTypes();
- SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
- resultsWithBitWidth.reserve(resultTypes.size());
- for (Type type : resultTypes) {
- resultsWithBitWidth.push_back({type, targetBitWidth});
- }
- return resultsWithBitWidth;
-}
-
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Type type,
- unsigned targetBitWidth) {
-
- VectorType vecType = dyn_cast<VectorType>(type);
-
- // Not linearizable for reasons other than what this function checks.
- if (!vecType || vecType.getRank() == 0)
- return false;
-
- // The width of the type 'index' is unbounded (and therefore potentially above
- // the target width).
- if (vecType.getElementType().isIndex())
- return true;
-
- unsigned finalDimSize = vecType.getShape().back();
- unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
- unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
- return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Operation *op,
- unsigned targetBitWidth) {
- // Check on bitwidths.
- SmallVector<std::pair<Type, unsigned>> toCheck =
- getTypeBitWidthBoundPairs(op, targetBitWidth);
- return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
- return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
- typeWidth.second);
- });
-}
-
-void populateWithBitWidthConstraints(TypeConverter &typeConverter,
- ConversionTarget &target,
- unsigned targetBitWidth) {
-
- // The general purpose definition of what ops are legal must come first.
- populateForVectorLinearize(typeConverter, target);
-
- // Extend the set of legal ops to include those with large inner-most
- // dimensions on selected operands/results.
- target.markUnknownOpDynamicallyLegal(
- [=](Operation *op) -> std::optional<bool> {
- if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
- return true;
- }
- return {};
- });
-}
-
-struct TestVectorBitWidthLinearize final
- : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
-
- TestVectorBitWidthLinearize() = default;
- TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
- : PassWrapper(pass) {}
-
- StringRef getArgument() const override {
- return "test-bit-width-constrained-vector-linearize";
- }
- StringRef getDescription() const override {
- return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
- "in inner-most dimension's bit width.";
- }
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect>();
- }
-
- Option<unsigned> targetVectorBitwidth{
- *this, "target-vector-bitwidth",
- llvm::cl::desc(
- "Minimum vector bitwidth to enable the flattening transformation"),
- llvm::cl::init(std::numeric_limits<unsigned>::max())};
- void runOnOperation() override {
- auto *context = &getContext();
-
- TypeConverter typeConverter;
- RewritePatternSet patterns(context);
- ConversionTarget target(*context);
-
- populateWithBitWidthConstraints(typeConverter, target,
- targetVectorBitwidth);
-
- vector::populateVectorLinearizeBasePatterns(typeConverter, target,
- patterns);
-
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
- patterns);
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
-};
-
-struct TestVectorLinearize final
- : public PassWrapper<TestVectorLinearize, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
-
- TestVectorLinearize() = default;
-
- StringRef getArgument() const override { return "test-vector-linearize"; }
- StringRef getDescription() const override {
- return "Linearizes ND vectors for N >= 2 into 1D vectors";
- }
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect, arith::ArithDialect>();
- }
-
- void runOnOperation() override {
- MLIRContext &context = getContext();
- TypeConverter converter;
- RewritePatternSet patterns(&context);
- ConversionTarget target(context);
-
- vector::populateForVectorLinearize(converter, target);
-
- vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
- patterns);
- mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
- converter, patterns, target);
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
-};
-
struct TestEliminateVectorMasks
: public PassWrapper<TestEliminateVectorMasks,
OperationPass<func::FuncOp>> {
@@ -1062,10 +907,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorEmulateMaskedLoadStore>();
- PassRegistration<TestVectorLinearize>();
-
- PassRegistration<TestVectorBitWidthLinearize>();
-
PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2e08ae6f37980..f52f36107e301 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -155,6 +155,7 @@ void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
void registerTestPassStateExtensionCommunication();
void registerTestVectorLowerings();
+void registerTestVectorLinearize();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestVulkanRunnerPipeline();
void registerTestWrittenToPass();
@@ -300,6 +301,7 @@ void registerTestPasses() {
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestPassStateExtensionCommunication();
mlir::test::registerTestVectorLowerings();
+ mlir::test::registerTestVectorLinearize();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestVulkanRunnerPipeline();
mlir::test::registerTestWrittenToPass();
>From 68183789137dad8c832fb9b889875d07e21d869a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 3 Jun 2025 15:53:55 -0700
Subject: [PATCH 2/2] typos
---
.../Vector/Transforms/VectorLinearize.h | 28 +++++++++----------
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
index cd62de640d088..af5ce2103f774 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -16,27 +16,27 @@
namespace mlir {
namespace vector {
-/// Initialize `typeConverter` with source and target materialization logic
-/// using shape_casts to/from 1D vectors.
+/// Initialize `typeConverter` with source and target materializations that
+/// use shape_casts to/from 1D vectors.
void initializeForVectorLinearize(TypeConverter &typeConverter);
-/// Initialize `conversionTarget`, and `patterns` for linearization. Here
+/// Initialize `conversionTarget` and `patterns` for linearization. Here
/// linearization means converting a single operation with 1+ vector
/// operand/result of rank>1, into a new single operation whose vector operands
-/// and results are all of rank<=1.
+/// and results are all rank<=1.
///
-/// This function initializes `conversionTarget` with the set of operations that
-/// are illegal and consequently must be converted to a linearized form. It
-/// also populates the set of patterns that can be run to convert illegal
-/// operations, and what priority/benefit they have.
+/// This function initializes `conversionTarget` with a definition of which
+/// operations are illegal and consequently must be converted to a linearized
+/// (legal) form. It also populates `patterns` with the patterns that will be
+/// run to convert illegal operations, and sets the priority/benefit of each
+/// pattern.
///
-/// Note: the set of legal operations can be extended by a user if, for example,
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
+/// Note: the set of legal operations can be extended by a user by adding
+/// additional legality rules to `conversionTarget`.
///
/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
+/// linearization is to enable reuse of generic structural type conversions for
+/// linearizing scf/cf/func operations.
void populateForFullVectorLinearize(const TypeConverter &,
ConversionTarget &conversionTarget,
RewritePatternSet &patterns);
@@ -233,7 +233,7 @@ struct VectorLinearizePatterns {
};
/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
+/// at position `offsets`: this function enumerates all the indices in `large`
/// that are written to. The enumeration is with row-major ordering.
///
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
More information about the Mlir-commits
mailing list