[Mlir-commits] [mlir] [vector][linearize] Refactor code to push target bit width out of core code (PR #136581)

James Newling llvmlistbot at llvm.org
Mon Apr 21 09:57:57 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/136581

Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results of operations. 

In https://github.com/llvm/llvm-project/pull/83314 an option to ignore (legalize) operations with large inner-most dimensions was added. This PR is a step towards making that option live outside of upstream MLIR. The motivation is to reduce non-core functionality (I would like to use this pass, but would prefer not to deal with `targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that the user(s) of `targetVectorBitWidth` move `legalBecauseOfBitwidth` into their own code bases, after which it can be removed from upstream.

The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out of the patterns and into the conversion target, which the end user can control outside of core MLIR.

>From b523a5a654c9889e30f3f822af1cc30257aeb107 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 09:32:22 -0700
Subject: [PATCH] factorize out the logic about bounded bitwidth

---
 .../Vector/Transforms/VectorRewritePatterns.h |  30 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 264 +++++++++---------
 mlir/test/Dialect/Vector/linearize.mlir       |  13 +
 .../Dialect/Vector/TestVectorTransforms.cpp   |  17 +-
 4 files changed, 170 insertions(+), 154 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..d9a0791cdea33 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,18 +392,24 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+/// Populate `typeConverter` and `conversionTarget` with the definition of
+/// legal types and operations, for the specific case where vectors whose
+/// trailing dimension has a bit width of at least `targetBitWidth` are legal.
+void populateVectorLinearizeBitWidthTargetAndConverter(
+    TypeConverter &typeConverter, ConversionTarget &conversionTarget,
+    unsigned targetBitWidth);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
+/// converting ConstantLike, Vectorizable, and vector::BitCast.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+                                         RewritePatternSet &patterns,
+                                         const ConversionTarget &);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+                                                   RewritePatternSet &patterns,
+                                                   const ConversionTarget &);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..3a80ce815b766 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
+#include <limits>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
-  auto resultTypes = op->getResultTypes();
-  for (auto resType : resultTypes) {
-    VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
-      return false;
-    // There are no dimension to fold if it is a 0-D vector.
-    if (vecType.getRank() == 0)
-      return false;
-    unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
-    if (trailingVecDimBitWidth >= targetBitWidth)
-      return false;
-  }
-  return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
-  VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
-    return false;
-  // There are no dimension to fold if it is a 0-D vector.
-  if (vecType.getRank() == 0)
-    return false;
-  unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
-  return trailingVecDimBitWidth <= targetBitWidth;
-}
-
 static FailureOr<Attribute>
 linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                    VectorType resType, Attribute value) {
+
   if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
     if (resType.isScalable() && !isa<SplatElementsAttr>(value))
       return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
 }
 
 namespace {
+
 struct LinearizeConstantLike final
     : OpTraitConversionPattern<OpTrait::ConstantLike> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
-  LinearizeConstantLike(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeConstantLike(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -100,10 +69,6 @@ struct LinearizeConstantLike final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
 
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     StringAttr attrName = rewriter.getStringAttr("value");
     Attribute value = op->getAttr(attrName);
     if (!value)
@@ -124,9 +89,6 @@ struct LinearizeConstantLike final
     rewriter.replaceOp(op, newOp);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 struct LinearizeVectorizable final
@@ -134,18 +96,12 @@ struct LinearizeVectorizable final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
 public:
-  LinearizeVectorizable(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorizable(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
     if (failed(newOp))
@@ -154,9 +110,6 @@ struct LinearizeVectorizable final
     rewriter.replaceOp(op, (*newOp)->getResults());
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +126,10 @@ struct LinearizeVectorizable final
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+                                     MLIRContext *context,
+                                     PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +140,6 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +216,6 @@ struct LinearizeVectorExtractStridedSlice final
         extractOp, dstType, srcVector, srcVector, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +236,7 @@ struct LinearizeVectorShuffle final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -302,13 +246,12 @@ struct LinearizeVectorShuffle final
     assert(dstType && "vector type destination expected.");
     // The assert is used because vector.shuffle does not support scalable
     // vectors.
-    assert(!(shuffleOp.getV1VectorType().isScalable() ||
-             shuffleOp.getV2VectorType().isScalable() ||
-             dstType.isScalable()) &&
-           "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+    bool scalable = shuffleOp.getV1VectorType().isScalable() ||
+                    shuffleOp.getV2VectorType().isScalable() ||
+                    dstType.isScalable();
+    if (scalable)
+      return rewriter.notifyMatchFailure(shuffleOp,
+                                         "scalable vectors are not supported.");
 
     Value vec1 = adaptor.getV1();
     Value vec2 = adaptor.getV2();
@@ -343,9 +286,6 @@ struct LinearizeVectorShuffle final
                                                    vec2, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,8 +304,7 @@ struct LinearizeVectorExtract final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -378,9 +317,6 @@ struct LinearizeVectorExtract final
         cast<VectorType>(dstTy).isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     // Dynamic position is not supported.
     if (extractOp.hasDynamicPosition())
@@ -405,9 +341,6 @@ struct LinearizeVectorExtract final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,8 +360,7 @@ struct LinearizeVectorInsert final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -439,11 +371,6 @@ struct LinearizeVectorInsert final
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalable vectors are not supported.");
 
-    if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
-                                         targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          insertOp, "Can't flatten since targetBitWidth < OpSize");
-
     // dynamic position is not supported
     if (insertOp.hasDynamicPosition())
       return rewriter.notifyMatchFailure(insertOp,
@@ -471,11 +398,11 @@ struct LinearizeVectorInsert final
     }
 
     llvm::SmallVector<int64_t, 2> indices(dstSize);
-    auto origValsUntil = indices.begin();
+    auto *origValsUntil = indices.begin();
     std::advance(origValsUntil, linearizedOffset);
     std::iota(indices.begin(), origValsUntil,
               0); // original values that remain [0, offset)
-    auto newValsUntil = origValsUntil;
+    auto *newValsUntil = origValsUntil;
     std::advance(newValsUntil, srcSize);
     std::iota(origValsUntil, newValsUntil,
               dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +415,6 @@ struct LinearizeVectorInsert final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,8 +432,7 @@ struct LinearizeVectorBitCast final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type.");
 
-    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
     return mlir::success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 } // namespace
 
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal. 
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+
+  VectorType vecType = dyn_cast<VectorType>(type);
+
+  if (!vecType)
+    return true;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  unsigned finalDimSize =
+      vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+
+  unsigned trailingVecDimBitWidth =
+      finalDimSize * vecType.getElementTypeBitWidth();
+
+  return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                 ? targetBitWidth + 1
+                 : targetBitWidth;
+    return {{insertOp.getValueToStoreType(), w}};
+  }
+  auto resultTypes = op->getResultTypes();
+  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+  resultsWithBitWidth.reserve(resultTypes.size());
+  for (Type type : resultTypes) {
+    resultsWithBitWidth.push_back({type, targetBitWidth});
+  }
+  return resultsWithBitWidth;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+  bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+                           op->hasTrait<OpTrait::Vectorizable>() ||
+                           isa<vector::BitCastOp>(op);
+
+  if (scalableSupported)
+    return false;
+
+  // Check if any of the results is a scalable vector type.
+  auto types = op->getResultTypes();
+  bool containsScalableResult =
+      std::any_of(types.begin(), types.end(), [](Type type) {
+        auto vecType = dyn_cast<VectorType>(type);
+        return vecType && vecType.isScalable();
+      });
+
+  return containsScalableResult;
+}
+
+static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+
+  // Only ops that are in the vector dialect, are ConstantLike, or
+  // are Vectorizable might be linearized currently, so legalize the others.
+  bool opIsVectorDialect = op->getDialect()->getNamespace() ==
+                           vector::VectorDialect::getDialectNamespace();
+  if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
+      !op->hasTrait<OpTrait::Vectorizable>())
+    return true;
+
+  // Some ops will not be linearized if they have scalable vector results.
+  if (legalBecauseScalable(op))
+    return true;
+
+  // Check on bitwidths.
+  auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
+  return std::any_of(typesToCheck.begin(), typesToCheck.end(),
+                     [&](std::pair<Type, unsigned> typeWidth) {
+                       return legalBecauseOfBitwidth(typeWidth.first,
+                                                     typeWidth.second);
+                     });
+}
+
+void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
+    TypeConverter &typeConverter, ConversionTarget &target,
+    unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -550,40 +552,34 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
         !isa<VectorType>(type))
       return nullptr;
-
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
+
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
-             op->hasTrait<OpTrait::ConstantLike>() ||
-             op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? typeConverter.isLegal(op)
-                      : true);
-        }
-        return std::nullopt;
+        bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
+        if (isDynamicallyLegal)
+          return true;
+
+        bool shapeUnchanged = typeConverter.isLegal(op);
+        return shapeUnchanged;
       });
+}
 
+void mlir::vector::populateVectorLinearizeBasePatterns(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    const ConversionTarget &target) {
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
-  target.addDynamicallyLegalOp<vector::ShuffleOp>(
-      [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
-                   ? (typeConverter.isLegal(shuffleOp) &&
-                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
-                              .getRank() == 1)
-                   : true;
-      });
+    const ConversionTarget &target) {
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..76eb93e98599e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -171,6 +171,7 @@ func.func @test_0d_vector() -> vector<f32> {
 }
 
 // -----
+
 // ALL-LABEL: test_extract_strided_slice_1
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
 func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
@@ -193,6 +194,8 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
   return %0 : vector<2x2xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
 func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
@@ -205,6 +208,7 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
 }
 
 // -----
+
 // ALL-LABEL: test_extract_strided_slice_2
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
 func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -228,6 +232,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
 }
 
 // -----
+
 // ALL-LABEL: test_vector_shuffle
 // ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -252,6 +257,7 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 }
 
 // -----
+
 // ALL-LABEL: test_vector_extract
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -273,6 +279,8 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   return %0 : vector<8x2xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_vector_extract_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
 func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
@@ -283,7 +291,9 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
   // ALL: return %[[RES]] : vector<8x[2]xf32>
   return %0 : vector<8x[2]xf32>
 }
+
 // -----
+
 // ALL-LABEL: test_vector_insert
 // ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
 func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -312,6 +322,8 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
   return %0 : vector<2x8x4xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_vector_insert_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
 func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
@@ -385,6 +397,7 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
 }
 
 // -----
+
 // ALL-LABEL: test_vector_bitcast
 // ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
 func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..7d40a416e4128 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,17 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
-#include <type_traits>
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -866,10 +862,15 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+    vector::populateVectorLinearizeBitWidthTargetAndConverter(
+        typeConverter, target, targetVectorBitwidth);
+
+    vector::populateVectorLinearizeBasePatterns(typeConverter, patterns,
+                                                target);
+
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
+                                                          patterns, target);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();



More information about the Mlir-commits mailing list