[Mlir-commits] [mlir] [mlir][vector] Refactor vector linearization patterns (PR #142685)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 4 10:56:57 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

<details>
<summary>Changes</summary>

This PR separates out the vector linearization API and testing from other vector rewrite patterns. There is no functional change (although the API changes). 

#### API change:

There are currently two groups of linearization patterns: `populateVectorLinearizeBasePatterns` and `populateVectorLinearizeShuffleLikeOpsPatterns`. I would like to add more linearization patterns (draft PR https://github.com/llvm/llvm-project/pull/142672), but I don't want to add a third group because I don't see any obvious grouping. I think it is less opinionated to allow any subset of the patterns to be used, and this PR introduces that. With this PR there is an API that adds all patterns to the `RewritePatternSet` (`populateForFullVectorLinearize`), but a user can also bypass this API and mix and match whichever patterns they want (as well as control the pattern benefits).

#### Test change:

The file `TestVectorTransforms.cpp` was getting large (~1,000 lines), so I split it up (SPIR-V is an example of a dialect that has multiple test utility .cpp files like this).


---

Patch is 72.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142685.diff


9 Files Affected:

- (added) mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h (+251) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (-33) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+487-399) 
- (renamed) mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir () 
- (renamed) mlir/test/Dialect/Vector/linearize/linearize.mlir () 
- (modified) mlir/test/lib/Dialect/Vector/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp (+185) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (-159) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
new file mode 100644
index 0000000000000..de6a441249695
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorLinearize.h
@@ -0,0 +1,251 @@
+//===- VectorLinearize.h - Vector linearization patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORLINEARIZE_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace vector {
+
+/// Initialize `typeConverter` with source and target materializations that
+/// use shape_cast for converting to and from 1D (linearized) vectors.
+void initializeForVectorLinearize(TypeConverter &typeConverter);
+
+/// Initialize `conversionTarget` and `patterns` for linearization. Here
+/// linearization means converting a single operation with 1+ vector
+/// operand/result of rank>1, into a new single operation whose vector operands
+/// and results are all rank<=1.
+///
+/// This function initializes `conversionTarget` with a definition of which
+/// operations are illegal and consequently must be converted to a linearized
+/// (legal) form. It also populates `patterns` with patterns that will be run to
+/// convert illegal operations, and sets the priority (benefit) of each pattern.
+///
+/// Note: the set of legal operations can be extended by a user by adding
+/// additional legality rules to `conversionTarget`.
+///
+/// Further note: the choice to use a dialect conversion design for
+/// linearization is to enable reusing generic structural type conversions for
+/// linearizing scf/cf/func operations.
+void populateForFullVectorLinearize(const TypeConverter &,
+                                    ConversionTarget &conversionTarget,
+                                    RewritePatternSet &patterns);
+
+/// The set of patterns available for linearization.
+enum class LinearizePattern {
+
+  /// This pattern converts a constant (or poison) vector of rank>1 into a
+  /// 1D vector, followed by a shape_cast.
+  ///
+  /// BEFORE
+  /// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  /// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+  LinearizeConstantLike = 0,
+
+  /// BEFORE
+  /// %2 = math.sin %arg0 : vector<2x2xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+  /// %1 = math.sin %0 : vector<4xf32>
+  /// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+  LinearizeVectorizable,
+
+  /// BEFORE
+  /// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  ///
+  /// AFTER
+  /// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+  /// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+  /// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+  LinearizeVectorBitCast,
+
+  /// BEFORE
+  /// %mask_2d = vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+  ///
+  /// AFTER
+  /// [...]
+  /// %mask_1d = vector.create_mask %mul : vector<4xi1>
+  /// %mask_2d = vector.shape_cast %mask_1d : vector<4xi1> to vector<1x4xi1>
+  ///
+  /// where `%mul` is a function of `%arg0` and `%arg1`.
+  ///
+  /// This pattern currently only supports 2D masks with a unit outer
+  /// dimension.
+  LinearizeVectorCreateMask,
+
+  /// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to
+  /// one that works on linearized vectors.
+  ///
+  /// BEFORE
+  /// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+  ///
+  /// AFTER
+  /// %v1_1d = vector.shape_cast %v1_3d : [...]
+  /// %v2_1d = vector.shape_cast %v2_3d : [...]
+  /// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+  /// %shuffle_3d = vector.shape_cast %shuffle_1d :  [...]
+  ///
+  /// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+  LinearizeVectorShuffle,
+
+  /// BEFORE
+  /// %1 = vector.splat %value : vector<4x4xf32>
+  ///
+  /// AFTER
+  /// %0 = vector.splat %value : vector<16xf32>
+  /// %1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>
+  LinearizeVectorSplat,
+
+  /// This pattern converts a vector.extract_strided_slice operation into a
+  /// vector.shuffle operation that has rank-1 (linearized) operand and
+  /// result.
+  ///
+  /// BEFORE
+  /// %out_nd = vector.extract_strided_slice %source_nd
+  ///         { offsets = [..], strides = [..], sizes = [..] }
+  ///
+  /// AFTER
+  /// %source_1d = vector.shape_cast %source_nd [...]
+  /// %out_1d    = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+  /// %out_nd    = vector.shape_cast %out_1d [...]
+  ///
+  /// `shuffle_indices_1d` is computed using the offsets and sizes of the
+  /// original vector.extract_strided_slice operation.
+  VectorExtractStridedSliceToRankOneShuffle,
+
+  /// BEFORE
+  /// %extract = vector.extract %src [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position` of original extract.
+  VectorExtractToRankOneShuffle,
+
+  /// This pattern converts a vector.insert_strided_slice operation into a
+  /// vector.shuffle operation that has rank-1 (linearized) operands and result.
+  ///
+  /// BEFORE
+  /// %0 = vector.insert_strided_slice %to_store, %into
+  ///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
+  ///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
+  /// AFTER
+  /// %to_store_1d
+  ///          = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+  /// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+  /// %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+  /// %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+  ///
+  /// where shuffle_indices_1d in this case is
+  ///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+  ///                        ^^^^^^^^^^^^^^
+  ///                          to_store_1d
+  VectorInsertStridedSliceToRankOneShuffle,
+
+  /// BEFORE
+  /// %insert = vector.insert %src, %dst [ position ]
+  ///
+  /// AFTER
+  /// %src_1d = vector.shape_cast %src : [...]
+  /// %dst_1d = vector.shape_cast %dst : [...]
+  /// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ]
+  /// %out_nd = vector.shape_cast %out_1d : [...]
+  ///
+  /// `shuffle_indices` is computed from `position`.
+  VectorInsertToRankOneShuffle,
+
+  /// The number of patterns in this enum.
+  N
+};
+
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority.
+struct VectorLinearizePatterns {
+
+public:
+  /// By default all patterns are enabled and have benefit 1.
+  VectorLinearizePatterns() {
+    enabled.fill(true);
+    benefits.fill(PatternBenefit(1));
+  }
+
+  /// Add the patterns enabled for the conversion to `patterns`.
+  void addToPatternSet(const TypeConverter &,
+                       RewritePatternSet &patterns) const;
+
+  VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+    enabled[static_cast<unsigned>(id)] = e;
+    return *this;
+  }
+
+  VectorLinearizePatterns &enableAll(bool e = true) {
+    enabled.fill(e);
+    return *this;
+  }
+
+  bool isEnabled(LinearizePattern id) const {
+    return enabled[static_cast<unsigned>(id)];
+  }
+
+  PatternBenefit getBenefit(LinearizePattern id) const {
+    return benefits[static_cast<unsigned>(id)];
+  }
+
+  VectorLinearizePatterns &setBenefit(LinearizePattern id,
+                                      PatternBenefit benefit) {
+    getBenefitRef(id) = benefit;
+    return *this;
+  }
+
+  VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+                                            unsigned inc = 1) {
+    getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+    return *this;
+  }
+
+private:
+  std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+  std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+      benefits;
+
+  PatternBenefit &getBenefitRef(LinearizePattern id) {
+    unsigned idInt = static_cast<unsigned>(id);
+    assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+           "invalid linearization pattern id");
+    return benefits[idInt];
+  }
+};
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+SmallVector<int64_t> getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                                     ArrayRef<int64_t> large,
+                                                     ArrayRef<int64_t> offsets);
+
+} // namespace vector
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..6954cb7172129 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,39 +406,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
-///
-/// Definition: here 'linearization' means converting a single operation with
-/// 1+ vector operand/result of rank>1, into a new single operation whose
-/// vector operands and results are all of rank<=1.
-///
-/// This function registers (1) which operations are legal, and hence should not
-/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
-/// materialze the conversion (with shape_cast)
-///
-/// Note: the set of legal operations can be extended by a user if for example
-/// certain rank>1 vectors are considered valid, by adding additional
-/// dynamically legal ops to `conversionTarget`.
-///
-/// Further note: the choice to use a dialect conversion design for
-/// linearization is to make it easy to reuse generic structural type
-/// conversions for linearizing scf/cf/func operations
-void populateForVectorLinearize(TypeConverter &typeConverter,
-                                ConversionTarget &conversionTarget);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
-/// contains patterns for converting ConstantLike, Vectorizable, and
-/// vector::BitCast ops.
-void populateVectorLinearizeBasePatterns(const TypeConverter &,
-                                         const ConversionTarget &,
-                                         RewritePatternSet &patterns);
-
-/// Populates `patterns` for linearizing ND (N >= 2) vector operations
-/// to 1D vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
-                                                   const ConversionTarget &,
-                                                   RewritePatternSet &patterns);
-
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..0c11c9b5c8740 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,9 +10,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Vector/Transforms/VectorLinearize.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Operation.h"
@@ -47,12 +48,21 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
 
 namespace {
 
+/// This pattern converts a constant (or poison) vector of rank>1 into a
+/// 1D vector, followed by a shape_cast.
+///
+/// BEFORE
+/// %1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+///
+/// AFTER
+/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
 struct LinearizeConstantLike final
     : OpTraitConversionPattern<OpTrait::ConstantLike> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
   LinearizeConstantLike(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
       : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -88,13 +98,20 @@ struct LinearizeConstantLike final
   }
 };
 
+/// BEFORE
+/// %2 = math.sin %arg0 : vector<2x2xf32>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+/// %1 = math.sin %0 : vector<4xf32>
+/// %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
 public:
   LinearizeVectorizable(const TypeConverter &typeConverter,
-                        MLIRContext *context, PatternBenefit benefit = 1)
+                        MLIRContext *context, PatternBenefit benefit)
       : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -109,17 +126,178 @@ struct LinearizeVectorizable final
   }
 };
 
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
-  static_assert(
-      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
-          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
-      "expected vector.extract_strided_slice or vector.insert_strided_slice");
-  ArrayAttr strides = op.getStrides();
-  return llvm::all_of(strides, isOneInteger);
-}
+/// BEFORE
+/// %2 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+///
+/// AFTER
+/// %0 = vector.shape_cast %arg0 : vector<4x4xf32> to vector<16xf32>
+/// %1 = vector.bitcast %0 : vector<16xf32> to vector<32xf16>
+/// %2 = vector.shape_cast %1 : vector<32xf16> to vector<4x8xf16>
+struct LinearizeVectorBitCast final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorBitCast(const TypeConverter &typeConverter,
+                         MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resType = getTypeConverter()->convertType(castOp.getType());
+    assert(resType && "expected 1-D vector type");
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+                                                   adaptor.getSource());
+    return success();
+  }
+};
+
+/// This pattern converts the vector.create_mask to work on a linearized vector.
+/// It currently supports only 2D masks with a unit outer dimension.
+///
+/// BEFORE
+///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+///
+/// AFTER
+///   %zero = arith.constant 0 : index
+///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
+///   %index = arith.index_cast %cmpi : i1 to index
+///   %mul = arith.andi %index, %arg1 : index
+///   %mask = vector.create_mask %mul : vector<4xi1>
+///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = createMaskOp.getLoc();
+    VectorType srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // Compare the first operand with 0. If it is greater than 0, the single
+    // (unit) row of the mask is active; otherwise the whole mask is false.
+    // The i1 comparison result is sign-extended to index (0 or all ones) and
+    // ANDed with the second operand to get the size of the 1D mask.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto isNonZero = rewriter.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isNonZeroIndex = rewriter.createOrFold<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), isNonZero);
+    auto secondOperand = adaptor.getOperands().back();
+    auto maskSize = rewriter.createOrFold<arith::AndIOp>(
+        loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
+
+    auto newMask = rewriter.create<vector::CreateMaskOp>(loc, dstTy, maskSize);
+    rewriter.replaceOp(createMaskOp, newMask);
+    return success();
+  }
+};
+
+/// This pattern converts a vector.shuffle that works on nD (n > 1) vectors to
+/// one that works on linearized vectors.
+///
+/// BEFORE
+/// %shuffle_3d = vector.shuffle %v1_3d, %v2_3d [ shuffle_indices ]
+///
+/// AFTER
+/// %v1_1d = vector.shape_cast %v1_3d : [...]
+/// %v2_1d = vector.shape_cast %v2_3d : [...]
+/// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
+/// %shuffle_3d = vector.shape_cast %shuffle_1d :  [...]
+///
+/// Where `shuffle_indices_1d` is computed by expanding `shuffle_indices`.
+struct LinearizeVectorShuffle final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorShuffle(const TypeConverter &typeConverter,
+                         MLIRContext *context, PatternBenefit benefit)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType dstType =
+        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
+    assert(dstType && "vector type destination expected.");
+
+    Value vec1 = adaptor.getV1();
+    Value vec2 = adaptor.getV2();
+    int shuffleSliceLen = 1;
+    int rank = shuffleOp.getV1().getType().getRank();
 
-/// Convert an array of attributes into a vector of integers, if possible.
+    // If rank > 1, we need to do the shuffle in the granu...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/142685

