[Mlir-commits] [mlir] [vector][linearize] Refactor code to push target bit width out of patterns (PR #136581)

James Newling llvmlistbot at llvm.org
Tue Apr 22 13:26:43 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/136581

>From b523a5a654c9889e30f3f822af1cc30257aeb107 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 09:32:22 -0700
Subject: [PATCH 1/3] factorize out the logic about bounded bitwidth

---
 .../Vector/Transforms/VectorRewritePatterns.h |  30 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 264 +++++++++---------
 mlir/test/Dialect/Vector/linearize.mlir       |  13 +
 .../Dialect/Vector/TestVectorTransforms.cpp   |  17 +-
 4 files changed, 170 insertions(+), 154 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..d9a0791cdea33 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,18 +392,24 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+/// Populate `typeConverter` and `conversionTarget` with the definition of
+/// legal types and operations, for the specific case where vectors with
+/// trailing dimensions of size greater than `targetBitWidth` are legal.
+void populateVectorLinearizeBitWidthTargetAndConverter(
+    TypeConverter &typeConverter, ConversionTarget &conversionTarget,
+    unsigned targetBitWidth);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
+/// converting ConstantLike, Vectorizable, and vector::BitCast.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+                                         RewritePatternSet &patterns,
+                                         const ConversionTarget &);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+                                                   RewritePatternSet &patterns,
+                                                   const ConversionTarget &);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..3a80ce815b766 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
+#include <limits>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
-  auto resultTypes = op->getResultTypes();
-  for (auto resType : resultTypes) {
-    VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
-      return false;
-    // There are no dimension to fold if it is a 0-D vector.
-    if (vecType.getRank() == 0)
-      return false;
-    unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
-    if (trailingVecDimBitWidth >= targetBitWidth)
-      return false;
-  }
-  return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
-  VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
-    return false;
-  // There are no dimension to fold if it is a 0-D vector.
-  if (vecType.getRank() == 0)
-    return false;
-  unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
-  return trailingVecDimBitWidth <= targetBitWidth;
-}
-
 static FailureOr<Attribute>
 linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                    VectorType resType, Attribute value) {
+
   if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
     if (resType.isScalable() && !isa<SplatElementsAttr>(value))
       return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
 }
 
 namespace {
+
 struct LinearizeConstantLike final
     : OpTraitConversionPattern<OpTrait::ConstantLike> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
-  LinearizeConstantLike(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeConstantLike(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -100,10 +69,6 @@ struct LinearizeConstantLike final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
 
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     StringAttr attrName = rewriter.getStringAttr("value");
     Attribute value = op->getAttr(attrName);
     if (!value)
@@ -124,9 +89,6 @@ struct LinearizeConstantLike final
     rewriter.replaceOp(op, newOp);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 struct LinearizeVectorizable final
@@ -134,18 +96,12 @@ struct LinearizeVectorizable final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
 public:
-  LinearizeVectorizable(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorizable(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
     if (failed(newOp))
@@ -154,9 +110,6 @@ struct LinearizeVectorizable final
     rewriter.replaceOp(op, (*newOp)->getResults());
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +126,10 @@ struct LinearizeVectorizable final
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+                                     MLIRContext *context,
+                                     PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +140,6 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +216,6 @@ struct LinearizeVectorExtractStridedSlice final
         extractOp, dstType, srcVector, srcVector, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +236,7 @@ struct LinearizeVectorShuffle final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -302,13 +246,12 @@ struct LinearizeVectorShuffle final
     assert(dstType && "vector type destination expected.");
     // The assert is used because vector.shuffle does not support scalable
     // vectors.
-    assert(!(shuffleOp.getV1VectorType().isScalable() ||
-             shuffleOp.getV2VectorType().isScalable() ||
-             dstType.isScalable()) &&
-           "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+    bool scalable = shuffleOp.getV1VectorType().isScalable() ||
+                    shuffleOp.getV2VectorType().isScalable() ||
+                    dstType.isScalable();
+    if (scalable)
+      return rewriter.notifyMatchFailure(shuffleOp,
+                                         "scalable vectors are not supported.");
 
     Value vec1 = adaptor.getV1();
     Value vec2 = adaptor.getV2();
@@ -343,9 +286,6 @@ struct LinearizeVectorShuffle final
                                                    vec2, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,8 +304,7 @@ struct LinearizeVectorExtract final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -378,9 +317,6 @@ struct LinearizeVectorExtract final
         cast<VectorType>(dstTy).isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     // Dynamic position is not supported.
     if (extractOp.hasDynamicPosition())
@@ -405,9 +341,6 @@ struct LinearizeVectorExtract final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,8 +360,7 @@ struct LinearizeVectorInsert final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -439,11 +371,6 @@ struct LinearizeVectorInsert final
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalable vectors are not supported.");
 
-    if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
-                                         targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          insertOp, "Can't flatten since targetBitWidth < OpSize");
-
     // dynamic position is not supported
     if (insertOp.hasDynamicPosition())
       return rewriter.notifyMatchFailure(insertOp,
@@ -471,11 +398,11 @@ struct LinearizeVectorInsert final
     }
 
     llvm::SmallVector<int64_t, 2> indices(dstSize);
-    auto origValsUntil = indices.begin();
+    auto *origValsUntil = indices.begin();
     std::advance(origValsUntil, linearizedOffset);
     std::iota(indices.begin(), origValsUntil,
               0); // original values that remain [0, offset)
-    auto newValsUntil = origValsUntil;
+    auto *newValsUntil = origValsUntil;
     std::advance(newValsUntil, srcSize);
     std::iota(origValsUntil, newValsUntil,
               dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +415,6 @@ struct LinearizeVectorInsert final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,8 +432,7 @@ struct LinearizeVectorBitCast final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type.");
 
-    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
     return mlir::success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 } // namespace
 
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal. 
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+
+  VectorType vecType = dyn_cast<VectorType>(type);
+
+  if (!vecType)
+    return true;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  unsigned finalDimSize =
+      vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+
+  unsigned trailingVecDimBitWidth =
+      finalDimSize * vecType.getElementTypeBitWidth();
+
+  return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                 ? targetBitWidth + 1
+                 : targetBitWidth;
+    return {{insertOp.getValueToStoreType(), w}};
+  }
+  auto resultTypes = op->getResultTypes();
+  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+  resultsWithBitWidth.reserve(resultTypes.size());
+  for (Type type : resultTypes) {
+    resultsWithBitWidth.push_back({type, targetBitWidth});
+  }
+  return resultsWithBitWidth;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+  bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+                           op->hasTrait<OpTrait::Vectorizable>() ||
+                           isa<vector::BitCastOp>(op);
+
+  if (scalableSupported)
+    return false;
+
+  // Check if any of the results is a scalable vector type.
+  auto types = op->getResultTypes();
+  bool containsScalableResult =
+      std::any_of(types.begin(), types.end(), [](Type type) {
+        auto vecType = dyn_cast<VectorType>(type);
+        return vecType && vecType.isScalable();
+      });
+
+  return containsScalableResult;
+}
+
+static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+
+  // Only ops that are in the vector dialect, are ConstantLike, or
+  // are Vectorizable might be linearized currently, so legalize the others.
+  bool opIsVectorDialect = op->getDialect()->getNamespace() ==
+                           vector::VectorDialect::getDialectNamespace();
+  if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
+      !op->hasTrait<OpTrait::Vectorizable>())
+    return true;
+
+  // Some ops will not be linearized if they have scalable vector results.
+  if (legalBecauseScalable(op))
+    return true;
+
+  // Check on bitwidths.
+  auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
+  return std::any_of(typesToCheck.begin(), typesToCheck.end(),
+                     [&](std::pair<Type, unsigned> typeWidth) {
+                       return legalBecauseOfBitwidth(typeWidth.first,
+                                                     typeWidth.second);
+                     });
+}
+
+void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
+    TypeConverter &typeConverter, ConversionTarget &target,
+    unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -550,40 +552,34 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
         !isa<VectorType>(type))
       return nullptr;
-
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
+
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
-             op->hasTrait<OpTrait::ConstantLike>() ||
-             op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? typeConverter.isLegal(op)
-                      : true);
-        }
-        return std::nullopt;
+        bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
+        if (isDynamicallyLegal)
+          return true;
+
+        bool shapeUnchanged = typeConverter.isLegal(op);
+        return shapeUnchanged;
       });
+}
 
+void mlir::vector::populateVectorLinearizeBasePatterns(
+    const TypeConverter &typeConverter, RewritePatternSet &patterns,
+    const ConversionTarget &target) {
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
-  target.addDynamicallyLegalOp<vector::ShuffleOp>(
-      [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
-                   ? (typeConverter.isLegal(shuffleOp) &&
-                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
-                              .getRank() == 1)
-                   : true;
-      });
+    const ConversionTarget &target) {
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..76eb93e98599e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -171,6 +171,7 @@ func.func @test_0d_vector() -> vector<f32> {
 }
 
 // -----
+
 // ALL-LABEL: test_extract_strided_slice_1
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
 func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
@@ -193,6 +194,8 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
   return %0 : vector<2x2xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
 func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
@@ -205,6 +208,7 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
 }
 
 // -----
+
 // ALL-LABEL: test_extract_strided_slice_2
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
 func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -228,6 +232,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
 }
 
 // -----
+
 // ALL-LABEL: test_vector_shuffle
 // ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -252,6 +257,7 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 }
 
 // -----
+
 // ALL-LABEL: test_vector_extract
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -273,6 +279,8 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   return %0 : vector<8x2xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_vector_extract_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
 func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
@@ -283,7 +291,9 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
   // ALL: return %[[RES]] : vector<8x[2]xf32>
   return %0 : vector<8x[2]xf32>
 }
+
 // -----
+
 // ALL-LABEL: test_vector_insert
 // ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
 func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -312,6 +322,8 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
   return %0 : vector<2x8x4xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_vector_insert_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
 func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
@@ -385,6 +397,7 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
 }
 
 // -----
+
 // ALL-LABEL: test_vector_bitcast
 // ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
 func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..7d40a416e4128 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,17 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
-#include <type_traits>
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -866,10 +862,15 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+    vector::populateVectorLinearizeBitWidthTargetAndConverter(
+        typeConverter, target, targetVectorBitwidth);
+
+    vector::populateVectorLinearizeBasePatterns(typeConverter, patterns,
+                                                target);
+
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
+                                                          patterns, target);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();

>From 86ceb57b31dffd6a433ac11cf748dc4d5b1ea2e7 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 10:49:23 -0700
Subject: [PATCH 2/3] clang-format

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3a80ce815b766..e24c8ee961c51 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -450,7 +450,7 @@ struct LinearizeVectorBitCast final
 } // namespace
 
 /// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal. 
+/// or equal to `targetBitWidth`, its defining op is considered legal.
 static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
 
   VectorType vecType = dyn_cast<VectorType>(type);

>From be48849486b1c1ae68568dee941acc2bc7d49951 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Apr 2025 13:26:11 -0700
Subject: [PATCH 3/3] push further with the separation of concerns

---
 .../Vector/Transforms/VectorRewritePatterns.h |  31 ++--
 .../Vector/Transforms/VectorLinearize.cpp     | 167 ++++++------------
 mlir/test/Dialect/Vector/linearize.mlir       |   7 +-
 .../Dialect/Vector/TestVectorTransforms.cpp   | 144 +++++++++++++--
 4 files changed, 205 insertions(+), 144 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index d9a0791cdea33..91f77307ddf8b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,24 +392,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populate `typeConverter` and `conversionTarget` with the definition of
-/// legal types and operations, for the specific case where vectors with
-/// trailing dimensions of size greater than `targetBitWidth` are legal.
-void populateVectorLinearizeBitWidthTargetAndConverter(
-    TypeConverter &typeConverter, ConversionTarget &conversionTarget,
-    unsigned targetBitWidth);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
-/// converting ConstantLike, Vectorizable, and vector::BitCast.
+/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
+/// This registers (1) which operations are legal and hence should not be
+/// linearized, (2) what the converted types are (rank-1 vectors), and (3) how
+/// to materialize the conversion (with vector.shape_cast).
+///
+/// Note: a user can extend the set of legal operations, for example when
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+void populateForVectorLinearize(TypeConverter &typeConverter,
+                                ConversionTarget &conversionTarget);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
+/// contains patterns for converting ConstantLike, Vectorizable, and
+/// vector::BitCast ops.
 void populateVectorLinearizeBasePatterns(const TypeConverter &,
-                                         RewritePatternSet &patterns,
-                                         const ConversionTarget &);
+                                         const ConversionTarget &,
+                                         RewritePatternSet &patterns);
 
 /// Populates `patterns` for linearizing ND (N >= 2) vector operations
 /// to 1D vector shuffle operations.
 void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
-                                                   RewritePatternSet &patterns,
-                                                   const ConversionTarget &);
+                                                   const ConversionTarget &,
+                                                   RewritePatternSet &patterns);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index e24c8ee961c51..67e15852dc5ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -62,12 +62,10 @@ struct LinearizeConstantLike final
     if (op->getNumResults() != 1)
       return rewriter.notifyMatchFailure(loc, "expected 1 result");
 
-    const TypeConverter &converter = *getTypeConverter();
+    const TypeConverter &typeConverter = *getTypeConverter();
     auto resType =
-        converter.convertType<VectorType>(op->getResult(0).getType());
-
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+        typeConverter.convertType<VectorType>(op->getResult(0).getType());
+    assert(resType && "expected 1-D vector type");
 
     StringAttr attrName = rewriter.getStringAttr("value");
     Attribute value = op->getAttr(attrName);
@@ -80,7 +78,7 @@ struct LinearizeConstantLike final
       return failure();
 
     FailureOr<Operation *> convertResult =
-        convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
+        convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
     if (failed(convertResult))
       return failure();
 
@@ -244,14 +242,6 @@ struct LinearizeVectorShuffle final
     VectorType dstType =
         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
     assert(dstType && "vector type destination expected.");
-    // The assert is used because vector.shuffle does not support scalable
-    // vectors.
-    bool scalable = shuffleOp.getV1VectorType().isScalable() ||
-                    shuffleOp.getV2VectorType().isScalable() ||
-                    dstType.isScalable();
-    if (scalable)
-      return rewriter.notifyMatchFailure(shuffleOp,
-                                         "scalable vectors are not supported.");
 
     Value vec1 = adaptor.getV1();
     Value vec2 = adaptor.getV2();
@@ -270,7 +260,7 @@ struct LinearizeVectorShuffle final
     }
 
     // For each value in the mask, we generate the indices of the source vectors
-    // that needs to be shuffled to the destination vector. If shuffleSliceLen >
+    // that need to be shuffled to the destination vector. If shuffleSliceLen >
     // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
     // elements) instead of scalars.
     ArrayRef<int64_t> mask = shuffleOp.getMask();
@@ -309,14 +299,7 @@ struct LinearizeVectorExtract final
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "expected n-D vector type.");
-
-    if (extractOp.getVector().getType().isScalable() ||
-        cast<VectorType>(dstTy).isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
+    assert(dstTy && "expected 1-D vector type");
 
     // Dynamic position is not supported.
     if (extractOp.hasDynamicPosition())
@@ -367,9 +350,6 @@ struct LinearizeVectorInsert final
     VectorType dstTy = getTypeConverter()->convertType<VectorType>(
         insertOp.getDestVectorType());
     assert(dstTy && "vector type destination expected.");
-    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
-      return rewriter.notifyMatchFailure(insertOp,
-                                         "scalable vectors are not supported.");
 
     // dynamic position is not supported
     if (insertOp.hasDynamicPosition())
@@ -436,11 +416,8 @@ struct LinearizeVectorBitCast final
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = castOp.getLoc();
     auto resType = getTypeConverter()->convertType(castOp.getType());
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type.");
-
+    assert(resType && "expected 1-D vector type");
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
     return mlir::success();
@@ -449,56 +426,15 @@ struct LinearizeVectorBitCast final
 
 } // namespace
 
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
-
-  VectorType vecType = dyn_cast<VectorType>(type);
-
-  if (!vecType)
-    return true;
-
-  // The width of the type 'index' is unbounded (and therefore potentially above
-  // the target width).
-  if (vecType.getElementType().isIndex())
-    return true;
-
-  unsigned finalDimSize =
-      vecType.getRank() == 0 ? 0 : vecType.getShape().back();
-
-  unsigned trailingVecDimBitWidth =
-      finalDimSize * vecType.getElementTypeBitWidth();
-
-  return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static SmallVector<std::pair<Type, unsigned>>
-getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
-
-  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
-    auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
-                 ? targetBitWidth + 1
-                 : targetBitWidth;
-    return {{insertOp.getValueToStoreType(), w}};
-  }
-  auto resultTypes = op->getResultTypes();
-  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
-  resultsWithBitWidth.reserve(resultTypes.size());
-  for (Type type : resultTypes) {
-    resultsWithBitWidth.push_back({type, targetBitWidth});
-  }
-  return resultsWithBitWidth;
-}
-
 /// Return true if the operation `op` does not support scalable vectors and
-/// has at least 1 scalable vector result.
-static bool legalBecauseScalable(Operation *op) {
-
-  bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
-                           op->hasTrait<OpTrait::Vectorizable>() ||
-                           isa<vector::BitCastOp>(op);
-
-  if (scalableSupported)
+/// has at least 1 scalable vector result. These ops should all eventually
+/// support scalable vectors, and this function should be removed.
+static bool isNotLinearizableBecauseScalable(Operation *op) {
+
+  bool unsupported =
+      isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
+          op);
+  if (!unsupported)
     return false;
 
   // Check if any of the results is a scalable vector type.
@@ -512,73 +448,74 @@ static bool legalBecauseScalable(Operation *op) {
   return containsScalableResult;
 }
 
-static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+static bool isNotLinearizable(Operation *op) {
 
   // Only ops that are in the vector dialect, are ConstantLike, or
-  // are Vectorizable might be linearized currently, so legalize the others.
-  bool opIsVectorDialect = op->getDialect()->getNamespace() ==
-                           vector::VectorDialect::getDialectNamespace();
-  if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
-      !op->hasTrait<OpTrait::Vectorizable>())
+  // are Vectorizable might be linearized currently.
+  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
+  StringRef opDialect = op->getDialect()->getNamespace();
+  bool unsupported = (opDialect != vectorDialect) &&
+                     !op->hasTrait<OpTrait::ConstantLike>() &&
+                     !op->hasTrait<OpTrait::Vectorizable>();
+  if (unsupported)
     return true;
 
-  // Some ops will not be linearized if they have scalable vector results.
-  if (legalBecauseScalable(op))
+  // Some ops currently don't support scalable vectors.
+  if (isNotLinearizableBecauseScalable(op))
     return true;
 
-  // Check on bitwidths.
-  auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
-  return std::any_of(typesToCheck.begin(), typesToCheck.end(),
-                     [&](std::pair<Type, unsigned> typeWidth) {
-                       return legalBecauseOfBitwidth(typeWidth.first,
-                                                     typeWidth.second);
-                     });
+  return false;
 }
 
-void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
-    TypeConverter &typeConverter, ConversionTarget &target,
-    unsigned targetBitWidth) {
+void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
+                                              ConversionTarget &target) {
 
-  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
-    if (!isLinearizableVector(type))
+  auto convertType = [](Type type) -> std::optional<Type> {
+    VectorType vectorType = dyn_cast<VectorType>(type);
+    if (!vectorType || !isLinearizableVector(vectorType))
       return type;
 
-    return VectorType::get(type.getNumElements(), type.getElementType(),
-                           type.isScalable());
-  });
+    VectorType linearizedType =
+        VectorType::get(vectorType.getNumElements(),
+                        vectorType.getElementType(), vectorType.isScalable());
+    return linearizedType;
+  };
+  typeConverter.addConversion(convertType);
 
   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                             Location loc) -> Value {
-    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
-        !isa<VectorType>(type))
+    if (inputs.size() != 1)
       return nullptr;
-    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
-  };
 
+    Value value = inputs.front();
+    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
+      return nullptr;
+
+    return builder.create<vector::ShapeCastOp>(loc, type, value);
+  };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
 
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
-        if (isDynamicallyLegal)
+        if (isNotLinearizable(op))
           return true;
-
-        bool shapeUnchanged = typeConverter.isLegal(op);
-        return shapeUnchanged;
+        // This will return true if, for all operand and result types `t`,
+        // convertType(t) = t. This is true if there are no rank>=2 vectors.
+        return typeConverter.isLegal(op);
       });
 }
 
 void mlir::vector::populateVectorLinearizeBasePatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    const ConversionTarget &target) {
+    const TypeConverter &typeConverter, const ConversionTarget &target,
+    RewritePatternSet &patterns) {
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
                LinearizeVectorBitCast>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    const ConversionTarget &target) {
+    const TypeConverter &typeConverter, const ConversionTarget &target,
+    RewritePatternSet &patterns) {
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 76eb93e98599e..b3f2dddaee356 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+
+// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
+// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
 
 // ALL-LABEL: test_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
 
 // ALL-LABEL: test_index_no_linearize
 func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
-    // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
     %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
     return %0 : vector<2x2xindex>
 }
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 7d40a416e4128..ba5d82ad38585 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -835,16 +835,98 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-struct TestVectorLinearize final
-    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+// TODO: move this code into the user project.
+namespace vendor {
 
-  TestVectorLinearize() = default;
-  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+/// Get the (type, bit-width bound) pairs to check for a sufficiently small
+/// inner-most dimension: op results, or the stored value for vector.insert.
+static SmallVector<std::pair<Type, unsigned>>
+getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
 
-  StringRef getArgument() const override { return "test-vector-linearize"; }
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                     ? targetBitWidth + 1
+                     : targetBitWidth;
+    return {{insertOp.getValueToStoreType(), w}};
+  }
+
+  auto resultTypes = op->getResultTypes();
+  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+  resultsWithBitWidth.reserve(resultTypes.size());
+  for (Type type : resultTypes) {
+    resultsWithBitWidth.push_back({type, targetBitWidth});
+  }
+  return resultsWithBitWidth;
+}
+
+/// Return true if `type` is a rank>=1 VectorType with `index` elements or with
+/// a trailing dimension of (bit) size >= `targetBitWidth`; the op stays legal.
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                            unsigned targetBitWidth) {
+
+  VectorType vecType = dyn_cast<VectorType>(type);
+
+  // This check only applies to vectors of rank >= 1; defer to other checks.
+  if (!vecType || vecType.getRank() == 0)
+    return false;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  unsigned finalDimSize = vecType.getShape().back();
+  unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+  unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+  return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                            unsigned targetBitWidth) {
+  // Check on bitwidths.
+  SmallVector<std::pair<Type, unsigned>> toCheck =
+      getTypeBitWidthBoundPairs(op, targetBitWidth);
+  return std::any_of(toCheck.begin(), toCheck.end(),
+                     [&](std::pair<Type, unsigned> typeWidth) {
+                       return isNotLinearizableBecauseLargeInnerDimension(
+                           typeWidth.first, typeWidth.second);
+                     });
+}
+
+void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+                                     ConversionTarget &target,
+                                     unsigned targetBitWidth) {
+
+  // The general-purpose definition of what ops are legal must come first.
+  populateForVectorLinearize(typeConverter, target);
+
+  // Extend the set of legal ops to include those with large inner-most
+  // dimensions on selected operands/results.
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
+        if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+          return true;
+        }
+        return {};
+      });
+}
+
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-contrained-vector-linearize";
+  }
   StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "in inner-most dimension's bit width.";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<vector::VectorDialect>();
@@ -862,14 +944,48 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    vector::populateVectorLinearizeBitWidthTargetAndConverter(
-        typeConverter, target, targetVectorBitwidth);
+    populateWithBitWidthConstraints(typeConverter, target,
+                                                      targetVectorBitwidth);
 
-    vector::populateVectorLinearizeBasePatterns(typeConverter, patterns,
-                                                target);
+    vector::populateVectorLinearizeBasePatterns(typeConverter, target,
+                                                patterns);
 
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
-                                                          patterns, target);
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
+                                                          patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace vendor
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  TestVectorLinearize() = default;
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+
+    vector::populateForVectorLinearize(converter, target);
+
+    vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
+                                                          patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -950,6 +1066,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorLinearize>();
 
+  PassRegistration<vendor::TestVectorBitWidthLinearize>();
+
   PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test



More information about the Mlir-commits mailing list