[Mlir-commits] [mlir] bad8bf5 - [mlir][vector] Linearization: push 'bit width' logic out of patterns (#136581)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 30 09:05:44 PDT 2025


Author: James Newling
Date: 2025-04-30T09:05:40-07:00
New Revision: bad8bf56d3e4f107423b307f5f75564296703a76

URL: https://github.com/llvm/llvm-project/commit/bad8bf56d3e4f107423b307f5f75564296703a76
DIFF: https://github.com/llvm/llvm-project/commit/bad8bf56d3e4f107423b307f5f75564296703a76.diff

LOG: [mlir][vector] Linearization: push 'bit width' logic out of patterns (#136581)

[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.

In https://github.com/llvm/llvm-project/pull/83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
`targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition, the tests need to be
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is an NFC). I'm happy to help make it
easier to do this final step!

The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out of the patterns, and into the conversion target,
which the end user can control outside of core MLIR.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7a079dcc6affc..f1100d5cf8b68 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,18 +406,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
+/// This registers (1) which operations are legal and hence should not be
+/// linearized, (2) what converted types are (rank-1 vectors) and how to
+/// materialize the conversion (with shape_cast).
+///
+/// Note: the set of legal operations can be extended by a user if, for example,
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+void populateForVectorLinearize(TypeConverter &typeConverter,
+                                ConversionTarget &conversionTarget);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
+/// contains patterns for converting ConstantLike, Vectorizable, and
+/// vector::BitCast ops.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+                                         const ConversionTarget &,
+                                         RewritePatternSet &patterns);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+                                                   const ConversionTarget &,
+                                                   RewritePatternSet &patterns);
 
 } // namespace vector
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..67e15852dc5ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
+#include <limits>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
-  auto resultTypes = op->getResultTypes();
-  for (auto resType : resultTypes) {
-    VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
-      return false;
-    // There are no dimension to fold if it is a 0-D vector.
-    if (vecType.getRank() == 0)
-      return false;
-    unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
-    if (trailingVecDimBitWidth >= targetBitWidth)
-      return false;
-  }
-  return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
-  VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
-    return false;
-  // There are no dimension to fold if it is a 0-D vector.
-  if (vecType.getRank() == 0)
-    return false;
-  unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
-  return trailingVecDimBitWidth <= targetBitWidth;
-}
-
 static FailureOr<Attribute>
 linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                    VectorType resType, Attribute value) {
+
   if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
     if (resType.isScalable() && !isa<SplatElementsAttr>(value))
       return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
 }
 
 namespace {
+
 struct LinearizeConstantLike final
     : OpTraitConversionPattern<OpTrait::ConstantLike> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
-  LinearizeConstantLike(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeConstantLike(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -93,16 +62,10 @@ struct LinearizeConstantLike final
     if (op->getNumResults() != 1)
       return rewriter.notifyMatchFailure(loc, "expected 1 result");
 
-    const TypeConverter &converter = *getTypeConverter();
+    const TypeConverter &typeConverter = *getTypeConverter();
     auto resType =
-        converter.convertType<VectorType>(op->getResult(0).getType());
-
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
+        typeConverter.convertType<VectorType>(op->getResult(0).getType());
+    assert(resType && "expected 1-D vector type");
 
     StringAttr attrName = rewriter.getStringAttr("value");
     Attribute value = op->getAttr(attrName);
@@ -115,7 +78,7 @@ struct LinearizeConstantLike final
       return failure();
 
     FailureOr<Operation *> convertResult =
-        convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
+        convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
     if (failed(convertResult))
       return failure();
 
@@ -124,9 +87,6 @@ struct LinearizeConstantLike final
     rewriter.replaceOp(op, newOp);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 struct LinearizeVectorizable final
@@ -134,18 +94,12 @@ struct LinearizeVectorizable final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
 public:
-  LinearizeVectorizable(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorizable(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
     if (failed(newOp))
@@ -154,9 +108,6 @@ struct LinearizeVectorizable final
     rewriter.replaceOp(op, (*newOp)->getResults());
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +124,10 @@ struct LinearizeVectorizable final
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+                                     MLIRContext *context,
+                                     PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +138,6 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +214,6 @@ struct LinearizeVectorExtractStridedSlice final
         extractOp, dstType, srcVector, srcVector, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +234,7 @@ struct LinearizeVectorShuffle final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -300,15 +242,6 @@ struct LinearizeVectorShuffle final
     VectorType dstType =
         getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
     assert(dstType && "vector type destination expected.");
-    // The assert is used because vector.shuffle does not support scalable
-    // vectors.
-    assert(!(shuffleOp.getV1VectorType().isScalable() ||
-             shuffleOp.getV2VectorType().isScalable() ||
-             dstType.isScalable()) &&
-           "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
 
     Value vec1 = adaptor.getV1();
     Value vec2 = adaptor.getV2();
@@ -327,7 +260,7 @@ struct LinearizeVectorShuffle final
     }
 
     // For each value in the mask, we generate the indices of the source vectors
-    // that needs to be shuffled to the destination vector. If shuffleSliceLen >
+    // that need to be shuffled to the destination vector. If shuffleSliceLen >
     // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
     // elements) instead of scalars.
     ArrayRef<int64_t> mask = shuffleOp.getMask();
@@ -343,9 +276,6 @@ struct LinearizeVectorShuffle final
                                                    vec2, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,23 +294,12 @@ struct LinearizeVectorExtract final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "expected n-D vector type.");
-
-    if (extractOp.getVector().getType().isScalable() ||
-        cast<VectorType>(dstTy).isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
+    assert(dstTy && "expected 1-D vector type");
 
     // Dynamic position is not supported.
     if (extractOp.hasDynamicPosition())
@@ -405,9 +324,6 @@ struct LinearizeVectorExtract final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,22 +343,13 @@ struct LinearizeVectorInsert final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType dstTy = getTypeConverter()->convertType<VectorType>(
         insertOp.getDestVectorType());
     assert(dstTy && "vector type destination expected.");
-    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
-      return rewriter.notifyMatchFailure(insertOp,
-                                         "scalable vectors are not supported.");
-
-    if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
-                                         targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          insertOp, "Can't flatten since targetBitWidth < OpSize");
 
     // dynamic position is not supported
     if (insertOp.hasDynamicPosition())
@@ -471,11 +378,11 @@ struct LinearizeVectorInsert final
     }
 
     llvm::SmallVector<int64_t, 2> indices(dstSize);
-    auto origValsUntil = indices.begin();
+    auto *origValsUntil = indices.begin();
     std::advance(origValsUntil, linearizedOffset);
     std::iota(indices.begin(), origValsUntil,
               0); // original values that remain [0, offset)
-    auto newValsUntil = origValsUntil;
+    auto *newValsUntil = origValsUntil;
     std::advance(newValsUntil, srcSize);
     std::iota(origValsUntil, newValsUntil,
               dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +395,6 @@ struct LinearizeVectorInsert final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,82 +412,111 @@ struct LinearizeVectorBitCast final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = castOp.getLoc();
     auto resType = getTypeConverter()->convertType(castOp.getType());
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type.");
-
-    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
+    assert(resType && "expected 1-D vector type");
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
     return mlir::success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 } // namespace
 
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result. These ops should all eventually
+/// support scalable vectors, and this function should be removed.
+static bool isNotLinearizableBecauseScalable(Operation *op) {
+
+  bool unsupported =
+      isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
+          op);
+  if (!unsupported)
+    return false;
+
+  // Check if any of the results is a scalable vector type.
+  auto types = op->getResultTypes();
+  bool containsScalableResult =
+      std::any_of(types.begin(), types.end(), [](Type type) {
+        auto vecType = dyn_cast<VectorType>(type);
+        return vecType && vecType.isScalable();
+      });
+
+  return containsScalableResult;
+}
+
+static bool isNotLinearizable(Operation *op) {
+
+  // Only ops that are in the vector dialect, are ConstantLike, or
+  // are Vectorizable might be linearized currently.
+  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
+  StringRef opDialect = op->getDialect()->getNamespace();
+  bool unsupported = (opDialect != vectorDialect) &&
+                     !op->hasTrait<OpTrait::ConstantLike>() &&
+                     !op->hasTrait<OpTrait::Vectorizable>();
+  if (unsupported)
+    return true;
+
+  // Some ops currently don't support scalable vectors.
+  if (isNotLinearizableBecauseScalable(op))
+    return true;
+
+  return false;
+}
 
-  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
-    if (!isLinearizableVector(type))
+void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
+                                              ConversionTarget &target) {
+
+  auto convertType = [](Type type) -> std::optional<Type> {
+    VectorType vectorType = dyn_cast<VectorType>(type);
+    if (!vectorType || !isLinearizableVector(vectorType))
       return type;
 
-    return VectorType::get(type.getNumElements(), type.getElementType(),
-                           type.isScalable());
-  });
+    VectorType linearizedType =
+        VectorType::get(vectorType.getNumElements(),
+                        vectorType.getElementType(), vectorType.isScalable());
+    return linearizedType;
+  };
+  typeConverter.addConversion(convertType);
 
   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                             Location loc) -> Value {
-    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
-        !isa<VectorType>(type))
+    if (inputs.size() != 1)
+      return nullptr;
+
+    Value value = inputs.front();
+    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
       return nullptr;
 
-    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+    return builder.create<vector::ShapeCastOp>(loc, type, value);
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
-             op->hasTrait<OpTrait::ConstantLike>() ||
-             op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? typeConverter.isLegal(op)
-                      : true);
-        }
-        return std::nullopt;
+        if (isNotLinearizable(op))
+          return true;
+        // This will return true if, for all operand and result types `t`,
+        // convertType(t) = t. This is true if there are no rank>=2 vectors.
+        return typeConverter.isLegal(op);
       });
+}
 
+void mlir::vector::populateVectorLinearizeBasePatterns(
+    const TypeConverter &typeConverter, const ConversionTarget &target,
+    RewritePatternSet &patterns) {
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned int targetBitWidth) {
-  target.addDynamicallyLegalOp<vector::ShuffleOp>(
-      [=](vector::ShuffleOp shuffleOp) -> bool {
-        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
-                   ? (typeConverter.isLegal(shuffleOp) &&
-                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
-                              .getRank() == 1)
-                   : true;
-      });
+    const TypeConverter &typeConverter, const ConversionTarget &target,
+    RewritePatternSet &patterns) {
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext(), targetBitWidth);
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..06eaf58b225ae 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+
+// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128  -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
+// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
 
 // ALL-LABEL: test_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
 
 // ALL-LABEL: test_index_no_linearize
 func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
-    // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+    // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
     %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
     return %0 : vector<2x2xindex>
 }
@@ -171,6 +172,7 @@ func.func @test_0d_vector() -> vector<f32> {
 }
 
 // -----
+
 // ALL-LABEL: test_extract_strided_slice_1
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
 func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
@@ -193,6 +195,8 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
   return %0 : vector<2x2xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
 func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
@@ -205,6 +209,7 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
 }
 
 // -----
+
 // ALL-LABEL: test_extract_strided_slice_2
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
 func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -228,6 +233,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
 }
 
 // -----
+
 // ALL-LABEL: test_vector_shuffle
 // ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -252,6 +258,7 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
 }
 
 // -----
+
 // ALL-LABEL: test_vector_extract
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -273,6 +280,8 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   return %0 : vector<8x2xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_vector_extract_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
 func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
@@ -283,7 +292,9 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
   // ALL: return %[[RES]] : vector<8x[2]xf32>
   return %0 : vector<8x[2]xf32>
 }
+
 // -----
+
 // ALL-LABEL: test_vector_insert
 // ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
 func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -312,6 +323,8 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
   return %0 : vector<2x8x4xf32>
 }
 
+// -----
+
 // ALL-LABEL:   func.func @test_vector_insert_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
 func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
@@ -385,6 +398,7 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
 }
 
 // -----
+
 // ALL-LABEL: test_vector_bitcast
 // ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
 func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03f907e46c2c6..eda2594fbc7c7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,17 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
-#include <type_traits>
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -840,16 +836,98 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-struct TestVectorLinearize final
-    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+// TODO: move this code into the user project.
+namespace vendor {
 
-  TestVectorLinearize() = default;
-  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+/// Get the set of operand/result types to check for sufficiently
+/// small inner-most dimension size.
+static SmallVector<std::pair<Type, unsigned>>
+getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
 
-  StringRef getArgument() const override { return "test-vector-linearize"; }
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                     ? targetBitWidth + 1
+                     : targetBitWidth;
+    return {{insertOp.getValueToStoreType(), w}};
+  }
+
+  auto resultTypes = op->getResultTypes();
+  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+  resultsWithBitWidth.reserve(resultTypes.size());
+  for (Type type : resultTypes) {
+    resultsWithBitWidth.push_back({type, targetBitWidth});
+  }
+  return resultsWithBitWidth;
+}
+
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal.
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                            unsigned targetBitWidth) {
+
+  VectorType vecType = dyn_cast<VectorType>(type);
+
+  // Not linearizable for reasons other than what this function checks.
+  if (!vecType || vecType.getRank() == 0)
+    return false;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  unsigned finalDimSize = vecType.getShape().back();
+  unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+  unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+  return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                            unsigned targetBitWidth) {
+  // Check on bitwidths.
+  SmallVector<std::pair<Type, unsigned>> toCheck =
+      getTypeBitWidthBoundPairs(op, targetBitWidth);
+  return std::any_of(toCheck.begin(), toCheck.end(),
+                     [&](std::pair<Type, unsigned> typeWidth) {
+                       return isNotLinearizableBecauseLargeInnerDimension(
+                           typeWidth.first, typeWidth.second);
+                     });
+}
+
+void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+                                     ConversionTarget &target,
+                                     unsigned targetBitWidth) {
+
+  // The general purpose definition of what ops are legal must come first.
+  populateForVectorLinearize(typeConverter, target);
+
+  // Extend the set of legal ops to include those with large inner-most
+  // dimensions on selected operands/results.
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
+        if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+          return true;
+        }
+        return {};
+      });
+}
+
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-constrained-vector-linearize";
+  }
   StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "in inner-most dimension's bit width.";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<vector::VectorDialect>();
@@ -867,10 +945,49 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectorBitwidth);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-        typeConverter, patterns, target, targetVectorBitwidth);
+    populateWithBitWidthConstraints(typeConverter, target,
+                                    targetVectorBitwidth);
+
+    vector::populateVectorLinearizeBasePatterns(typeConverter, target,
+                                                patterns);
+
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
+                                                          patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace vendor
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  TestVectorLinearize() = default;
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+
+    vector::populateForVectorLinearize(converter, target);
+
+    vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
+    vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
+                                                          patterns);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
@@ -950,6 +1067,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorLinearize>();
 
+  PassRegistration<vendor::TestVectorBitWidthLinearize>();
+
   PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test


        


More information about the Mlir-commits mailing list