[Mlir-commits] [mlir] [vector][linearize] Refactor code to push target bit width out of patterns (PR #136581)
James Newling
llvmlistbot at llvm.org
Tue Apr 22 13:32:43 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/136581
>From b523a5a654c9889e30f3f822af1cc30257aeb107 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 09:32:22 -0700
Subject: [PATCH 1/4] factorize out the logic about bounded bitwidth
---
.../Vector/Transforms/VectorRewritePatterns.h | 30 +-
.../Vector/Transforms/VectorLinearize.cpp | 264 +++++++++---------
mlir/test/Dialect/Vector/linearize.mlir | 13 +
.../Dialect/Vector/TestVectorTransforms.cpp | 17 +-
4 files changed, 170 insertions(+), 154 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..d9a0791cdea33 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,18 +392,24 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
- const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+/// Populate `typeConverter` and `conversionTarget` with the definition of
+/// legal types and operations, for the specific case where vectors with
+/// trailing dimensions of size greater than `targetBitWidth` are legal.
+void populateVectorLinearizeBitWidthTargetAndConverter(
+ TypeConverter &typeConverter, ConversionTarget &conversionTarget,
+ unsigned targetBitWidth);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
+/// converting ConstantLike, Vectorizable, and vector::BitCast.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+ RewritePatternSet &patterns,
+ const ConversionTarget &);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+ RewritePatternSet &patterns,
+ const ConversionTarget &);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..3a80ce815b766 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
+#include <limits>
#include <numeric>
+#include <optional>
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
- auto resultTypes = op->getResultTypes();
- for (auto resType : resultTypes) {
- VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
- return false;
- // There are no dimension to fold if it is a 0-D vector.
- if (vecType.getRank() == 0)
- return false;
- unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
- if (trailingVecDimBitWidth >= targetBitWidth)
- return false;
- }
- return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
- VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
- return false;
- // There are no dimension to fold if it is a 0-D vector.
- if (vecType.getRank() == 0)
- return false;
- unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
- return trailingVecDimBitWidth <= targetBitWidth;
-}
-
static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
VectorType resType, Attribute value) {
+
if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
if (resType.isScalable() && !isa<SplatElementsAttr>(value))
return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
}
namespace {
+
struct LinearizeConstantLike final
: OpTraitConversionPattern<OpTrait::ConstantLike> {
using OpTraitConversionPattern::OpTraitConversionPattern;
- LinearizeConstantLike(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeConstantLike(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -100,10 +69,6 @@ struct LinearizeConstantLike final
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
-
StringAttr attrName = rewriter.getStringAttr("value");
Attribute value = op->getAttr(attrName);
if (!value)
@@ -124,9 +89,6 @@ struct LinearizeConstantLike final
rewriter.replaceOp(op, newOp);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
struct LinearizeVectorizable final
@@ -134,18 +96,12 @@ struct LinearizeVectorizable final
using OpTraitConversionPattern::OpTraitConversionPattern;
public:
- LinearizeVectorizable(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeVectorizable(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
@@ -154,9 +110,6 @@ struct LinearizeVectorizable final
rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +126,10 @@ struct LinearizeVectorizable final
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtractStridedSlice(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +140,6 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- extractOp, "Can't flatten since targetBitWidth <= OpSize");
ArrayAttr offsets = extractOp.getOffsets();
ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +216,6 @@ struct LinearizeVectorExtractStridedSlice final
extractOp, dstType, srcVector, srcVector, indices);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +236,7 @@ struct LinearizeVectorShuffle final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -302,13 +246,12 @@ struct LinearizeVectorShuffle final
assert(dstType && "vector type destination expected.");
// The assert is used because vector.shuffle does not support scalable
// vectors.
- assert(!(shuffleOp.getV1VectorType().isScalable() ||
- shuffleOp.getV2VectorType().isScalable() ||
- dstType.isScalable()) &&
- "scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+ bool scalable = shuffleOp.getV1VectorType().isScalable() ||
+ shuffleOp.getV2VectorType().isScalable() ||
+ dstType.isScalable();
+ if (scalable)
+ return rewriter.notifyMatchFailure(shuffleOp,
+ "scalable vectors are not supported.");
Value vec1 = adaptor.getV1();
Value vec2 = adaptor.getV2();
@@ -343,9 +286,6 @@ struct LinearizeVectorShuffle final
vec2, indices);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,8 +304,7 @@ struct LinearizeVectorExtract final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -378,9 +317,6 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- extractOp, "Can't flatten since targetBitWidth <= OpSize");
// Dynamic position is not supported.
if (extractOp.hasDynamicPosition())
@@ -405,9 +341,6 @@ struct LinearizeVectorExtract final
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,8 +360,7 @@ struct LinearizeVectorInsert final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -439,11 +371,6 @@ struct LinearizeVectorInsert final
return rewriter.notifyMatchFailure(insertOp,
"scalable vectors are not supported.");
- if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
- targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- insertOp, "Can't flatten since targetBitWidth < OpSize");
-
// dynamic position is not supported
if (insertOp.hasDynamicPosition())
return rewriter.notifyMatchFailure(insertOp,
@@ -471,11 +398,11 @@ struct LinearizeVectorInsert final
}
llvm::SmallVector<int64_t, 2> indices(dstSize);
- auto origValsUntil = indices.begin();
+ auto *origValsUntil = indices.begin();
std::advance(origValsUntil, linearizedOffset);
std::iota(indices.begin(), origValsUntil,
0); // original values that remain [0, offset)
- auto newValsUntil = origValsUntil;
+ auto *newValsUntil = origValsUntil;
std::advance(newValsUntil, srcSize);
std::iota(origValsUntil, newValsUntil,
dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +415,6 @@ struct LinearizeVectorInsert final
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,8 +432,7 @@ struct LinearizeVectorBitCast final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type.");
- if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
-
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
return mlir::success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
} // namespace
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal.
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+
+ VectorType vecType = dyn_cast<VectorType>(type);
+
+ if (!vecType)
+ return true;
+
+ // The width of the type 'index' is unbounded (and therefore potentially above
+ // the target width).
+ if (vecType.getElementType().isIndex())
+ return true;
+
+ unsigned finalDimSize =
+ vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+
+ unsigned trailingVecDimBitWidth =
+ finalDimSize * vecType.getElementTypeBitWidth();
+
+ return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+
+ if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+ auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
+ ? targetBitWidth + 1
+ : targetBitWidth;
+ return {{insertOp.getValueToStoreType(), w}};
+ }
+ auto resultTypes = op->getResultTypes();
+ SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+ resultsWithBitWidth.reserve(resultTypes.size());
+ for (Type type : resultTypes) {
+ resultsWithBitWidth.push_back({type, targetBitWidth});
+ }
+ return resultsWithBitWidth;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+ bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+ op->hasTrait<OpTrait::Vectorizable>() ||
+ isa<vector::BitCastOp>(op);
+
+ if (scalableSupported)
+ return false;
+
+ // Check if any of the results is a scalable vector type.
+ auto types = op->getResultTypes();
+ bool containsScalableResult =
+ std::any_of(types.begin(), types.end(), [](Type type) {
+ auto vecType = dyn_cast<VectorType>(type);
+ return vecType && vecType.isScalable();
+ });
+
+ return containsScalableResult;
+}
+
+static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+
+ // Only ops that are in the vector dialect, are ConstantLike, or
+ // are Vectorizable might be linearized currently, so legalize the others.
+ bool opIsVectorDialect = op->getDialect()->getNamespace() ==
+ vector::VectorDialect::getDialectNamespace();
+ if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
+ !op->hasTrait<OpTrait::Vectorizable>())
+ return true;
+
+ // Some ops will not be linearized if they have scalable vector results.
+ if (legalBecauseScalable(op))
+ return true;
+
+ // Check on bitwidths.
+ auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
+ return std::any_of(typesToCheck.begin(), typesToCheck.end(),
+ [&](std::pair<Type, unsigned> typeWidth) {
+ return legalBecauseOfBitwidth(typeWidth.first,
+ typeWidth.second);
+ });
+}
+
+void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
+ TypeConverter &typeConverter, ConversionTarget &target,
+ unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
@@ -550,40 +552,34 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
!isa<VectorType>(type))
return nullptr;
-
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
+
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
- op->hasTrait<OpTrait::ConstantLike>() ||
- op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
- ? typeConverter.isLegal(op)
- : true);
- }
- return std::nullopt;
+ bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
+ if (isDynamicallyLegal)
+ return true;
+
+ bool shapeUnchanged = typeConverter.isLegal(op);
+ return shapeUnchanged;
});
+}
+void mlir::vector::populateVectorLinearizeBasePatterns(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ const ConversionTarget &target) {
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
- targetBitWidth);
+ LinearizeVectorBitCast>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned int targetBitWidth) {
- target.addDynamicallyLegalOp<vector::ShuffleOp>(
- [=](vector::ShuffleOp shuffleOp) -> bool {
- return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
- ? (typeConverter.isLegal(shuffleOp) &&
- cast<mlir::VectorType>(shuffleOp.getResult().getType())
- .getRank() == 1)
- : true;
- });
+ const ConversionTarget &target) {
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..76eb93e98599e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -171,6 +171,7 @@ func.func @test_0d_vector() -> vector<f32> {
}
// -----
+
// ALL-LABEL: test_extract_strided_slice_1
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
@@ -193,6 +194,8 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
return %0 : vector<2x2xf32>
}
+// -----
+
// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
@@ -205,6 +208,7 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
}
// -----
+
// ALL-LABEL: test_extract_strided_slice_2
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -228,6 +232,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
}
// -----
+
// ALL-LABEL: test_vector_shuffle
// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -252,6 +257,7 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
}
// -----
+
// ALL-LABEL: test_vector_extract
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -273,6 +279,8 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
return %0 : vector<8x2xf32>
}
+// -----
+
// ALL-LABEL: func.func @test_vector_extract_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
@@ -283,7 +291,9 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
// ALL: return %[[RES]] : vector<8x[2]xf32>
return %0 : vector<8x[2]xf32>
}
+
// -----
+
// ALL-LABEL: test_vector_insert
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -312,6 +322,8 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
return %0 : vector<2x8x4xf32>
}
+// -----
+
// ALL-LABEL: func.func @test_vector_insert_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
@@ -385,6 +397,7 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
}
// -----
+
// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..7d40a416e4128 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,17 +7,13 @@
//===----------------------------------------------------------------------===//
#include <optional>
-#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -866,10 +862,15 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeTypeConversionsAndLegality(
- typeConverter, patterns, target, targetVectorBitwidth);
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- typeConverter, patterns, target, targetVectorBitwidth);
+ vector::populateVectorLinearizeBitWidthTargetAndConverter(
+ typeConverter, target, targetVectorBitwidth);
+
+ vector::populateVectorLinearizeBasePatterns(typeConverter, patterns,
+ target);
+
+ vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
+ patterns, target);
+
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
>From 86ceb57b31dffd6a433ac11cf748dc4d5b1ea2e7 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 10:49:23 -0700
Subject: [PATCH 2/4] clang-format
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3a80ce815b766..e24c8ee961c51 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -450,7 +450,7 @@ struct LinearizeVectorBitCast final
} // namespace
/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
+/// or equal to `targetBitWidth`, its defining op is considered legal.
static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(type);
>From be48849486b1c1ae68568dee941acc2bc7d49951 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Apr 2025 13:26:11 -0700
Subject: [PATCH 3/4] push further with the separation of concerns
---
.../Vector/Transforms/VectorRewritePatterns.h | 31 ++--
.../Vector/Transforms/VectorLinearize.cpp | 167 ++++++------------
mlir/test/Dialect/Vector/linearize.mlir | 7 +-
.../Dialect/Vector/TestVectorTransforms.cpp | 144 +++++++++++++--
4 files changed, 205 insertions(+), 144 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index d9a0791cdea33..91f77307ddf8b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,24 +392,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Populate `typeConverter` and `conversionTarget` with the definition of
-/// legal types and operations, for the specific case where vectors with
-/// trailing dimensions of size greater than `targetBitWidth` are legal.
-void populateVectorLinearizeBitWidthTargetAndConverter(
- TypeConverter &typeConverter, ConversionTarget &conversionTarget,
- unsigned targetBitWidth);
-
-/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
-/// converting ConstantLike, Vectorizable, and vector::BitCast.
+/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
+/// This registers (1) which operations are legal and hence should not be
+/// linearized, (2) what converted types are (rank-1 vectors) and how to
+/// materialize the conversion (with shape_cast).
+///
+/// Note: the set of legal operations can be extended by a user if for example
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+void populateForVectorLinearize(TypeConverter &typeConverter,
+ ConversionTarget &conversionTarget);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
+/// contains patterns for converting ConstantLike, Vectorizable, and
+/// vector::BitCast ops.
void populateVectorLinearizeBasePatterns(const TypeConverter &,
- RewritePatternSet &patterns,
- const ConversionTarget &);
+ const ConversionTarget &,
+ RewritePatternSet &patterns);
/// Populates `patterns` for linearizing ND (N >= 2) vector operations
/// to 1D vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
- RewritePatternSet &patterns,
- const ConversionTarget &);
+ const ConversionTarget &,
+ RewritePatternSet &patterns);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index e24c8ee961c51..67e15852dc5ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -62,12 +62,10 @@ struct LinearizeConstantLike final
if (op->getNumResults() != 1)
return rewriter.notifyMatchFailure(loc, "expected 1 result");
- const TypeConverter &converter = *getTypeConverter();
+ const TypeConverter &typeConverter = *getTypeConverter();
auto resType =
- converter.convertType<VectorType>(op->getResult(0).getType());
-
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type");
+ typeConverter.convertType<VectorType>(op->getResult(0).getType());
+ assert(resType && "expected 1-D vector type");
StringAttr attrName = rewriter.getStringAttr("value");
Attribute value = op->getAttr(attrName);
@@ -80,7 +78,7 @@ struct LinearizeConstantLike final
return failure();
FailureOr<Operation *> convertResult =
- convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
+ convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
if (failed(convertResult))
return failure();
@@ -244,14 +242,6 @@ struct LinearizeVectorShuffle final
VectorType dstType =
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
assert(dstType && "vector type destination expected.");
- // The assert is used because vector.shuffle does not support scalable
- // vectors.
- bool scalable = shuffleOp.getV1VectorType().isScalable() ||
- shuffleOp.getV2VectorType().isScalable() ||
- dstType.isScalable();
- if (scalable)
- return rewriter.notifyMatchFailure(shuffleOp,
- "scalable vectors are not supported.");
Value vec1 = adaptor.getV1();
Value vec2 = adaptor.getV2();
@@ -270,7 +260,7 @@ struct LinearizeVectorShuffle final
}
// For each value in the mask, we generate the indices of the source vectors
- // that needs to be shuffled to the destination vector. If shuffleSliceLen >
+ // that need to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
ArrayRef<int64_t> mask = shuffleOp.getMask();
@@ -309,14 +299,7 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(extractOp,
- "expected n-D vector type.");
-
- if (extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstTy).isScalable())
- return rewriter.notifyMatchFailure(extractOp,
- "scalable vectors are not supported.");
+ assert(dstTy && "expected 1-D vector type");
// Dynamic position is not supported.
if (extractOp.hasDynamicPosition())
@@ -367,9 +350,6 @@ struct LinearizeVectorInsert final
VectorType dstTy = getTypeConverter()->convertType<VectorType>(
insertOp.getDestVectorType());
assert(dstTy && "vector type destination expected.");
- if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
- return rewriter.notifyMatchFailure(insertOp,
- "scalable vectors are not supported.");
// dynamic position is not supported
if (insertOp.hasDynamicPosition())
@@ -436,11 +416,8 @@ struct LinearizeVectorBitCast final
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = castOp.getLoc();
auto resType = getTypeConverter()->convertType(castOp.getType());
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type.");
-
+ assert(resType && "expected 1-D vector type");
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
return mlir::success();
@@ -449,56 +426,15 @@ struct LinearizeVectorBitCast final
} // namespace
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
-
- VectorType vecType = dyn_cast<VectorType>(type);
-
- if (!vecType)
- return true;
-
- // The width of the type 'index' is unbounded (and therefore potentially above
- // the target width).
- if (vecType.getElementType().isIndex())
- return true;
-
- unsigned finalDimSize =
- vecType.getRank() == 0 ? 0 : vecType.getShape().back();
-
- unsigned trailingVecDimBitWidth =
- finalDimSize * vecType.getElementTypeBitWidth();
-
- return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static SmallVector<std::pair<Type, unsigned>>
-getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
-
- if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
- auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
- ? targetBitWidth + 1
- : targetBitWidth;
- return {{insertOp.getValueToStoreType(), w}};
- }
- auto resultTypes = op->getResultTypes();
- SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
- resultsWithBitWidth.reserve(resultTypes.size());
- for (Type type : resultTypes) {
- resultsWithBitWidth.push_back({type, targetBitWidth});
- }
- return resultsWithBitWidth;
-}
-
/// Return true if the operation `op` does not support scalable vectors and
-/// has at least 1 scalable vector result.
-static bool legalBecauseScalable(Operation *op) {
-
- bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
- op->hasTrait<OpTrait::Vectorizable>() ||
- isa<vector::BitCastOp>(op);
-
- if (scalableSupported)
+/// has at least 1 scalable vector result. These ops should all eventually
+/// support scalable vectors, and this function should be removed.
+static bool isNotLinearizableBecauseScalable(Operation *op) {
+
+ bool unsupported =
+ isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
+ op);
+ if (!unsupported)
return false;
// Check if any of the results is a scalable vector type.
@@ -512,73 +448,74 @@ static bool legalBecauseScalable(Operation *op) {
return containsScalableResult;
}
-static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+static bool isNotLinearizable(Operation *op) {
// Only ops that are in the vector dialect, are ConstantLike, or
- // are Vectorizable might be linearized currently, so legalize the others.
- bool opIsVectorDialect = op->getDialect()->getNamespace() ==
- vector::VectorDialect::getDialectNamespace();
- if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
- !op->hasTrait<OpTrait::Vectorizable>())
+ // are Vectorizable might be linearized currently.
+ StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
+ StringRef opDialect = op->getDialect()->getNamespace();
+ bool unsupported = (opDialect != vectorDialect) &&
+ !op->hasTrait<OpTrait::ConstantLike>() &&
+ !op->hasTrait<OpTrait::Vectorizable>();
+ if (unsupported)
return true;
- // Some ops will not be linearized if they have scalable vector results.
- if (legalBecauseScalable(op))
+ // Some ops currently don't support scalable vectors.
+ if (isNotLinearizableBecauseScalable(op))
return true;
- // Check on bitwidths.
- auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
- return std::any_of(typesToCheck.begin(), typesToCheck.end(),
- [&](std::pair<Type, unsigned> typeWidth) {
- return legalBecauseOfBitwidth(typeWidth.first,
- typeWidth.second);
- });
+ return false;
}
-void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
- TypeConverter &typeConverter, ConversionTarget &target,
- unsigned targetBitWidth) {
+void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
+ ConversionTarget &target) {
- typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
- if (!isLinearizableVector(type))
+ auto convertType = [](Type type) -> std::optional<Type> {
+ VectorType vectorType = dyn_cast<VectorType>(type);
+ if (!vectorType || !isLinearizableVector(vectorType))
return type;
- return VectorType::get(type.getNumElements(), type.getElementType(),
- type.isScalable());
- });
+ VectorType linearizedType =
+ VectorType::get(vectorType.getNumElements(),
+ vectorType.getElementType(), vectorType.isScalable());
+ return linearizedType;
+ };
+ typeConverter.addConversion(convertType);
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
- if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
- !isa<VectorType>(type))
+ if (inputs.size() != 1)
return nullptr;
- return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
- };
+ Value value = inputs.front();
+ if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
+ return nullptr;
+
+ return builder.create<vector::ShapeCastOp>(loc, type, value);
+ };
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
- if (isDynamicallyLegal)
+ if (isNotLinearizable(op))
return true;
-
- bool shapeUnchanged = typeConverter.isLegal(op);
- return shapeUnchanged;
+ // This will return true if, for all operand and result types `t`,
+ // convertType(t) = t. This is true if there are no rank>=2 vectors.
+ return typeConverter.isLegal(op);
});
}
void mlir::vector::populateVectorLinearizeBasePatterns(
- const TypeConverter &typeConverter, RewritePatternSet &patterns,
- const ConversionTarget &target) {
+ const TypeConverter &typeConverter, const ConversionTarget &target,
+ RewritePatternSet &patterns) {
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
LinearizeVectorBitCast>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- const TypeConverter &typeConverter, RewritePatternSet &patterns,
- const ConversionTarget &target) {
+ const TypeConverter &typeConverter, const ConversionTarget &target,
+ RewritePatternSet &patterns) {
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 76eb93e98599e..b3f2dddaee356 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+
+// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
+// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// ALL-LABEL: test_index_no_linearize
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
- // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 7d40a416e4128..ba5d82ad38585 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -835,16 +835,98 @@ struct TestVectorEmulateMaskedLoadStore final
}
};
-struct TestVectorLinearize final
- : public PassWrapper<TestVectorLinearize, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+// TODO: move this code into the user project.
+namespace vendor {
- TestVectorLinearize() = default;
- TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+/// Get the set of operand/result types to check for sufficiently
+/// small inner-most dimension size.
+static SmallVector<std::pair<Type, unsigned>>
+getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
- StringRef getArgument() const override { return "test-vector-linearize"; }
+ if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+ unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+ ? targetBitWidth + 1
+ : targetBitWidth;
+ return {{insertOp.getValueToStoreType(), w}};
+ }
+
+ auto resultTypes = op->getResultTypes();
+ SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+ resultsWithBitWidth.reserve(resultTypes.size());
+ for (Type type : resultTypes) {
+ resultsWithBitWidth.push_back({type, targetBitWidth});
+ }
+ return resultsWithBitWidth;
+}
+
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal.
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Type type,
+ unsigned targetBitWidth) {
+
+ VectorType vecType = dyn_cast<VectorType>(type);
+
+ // Non-vector and rank-0 types are never flagged by this bit-width check.
+ if (!vecType || vecType.getRank() == 0)
+ return false;
+
+ // The width of the type 'index' is unbounded (and therefore potentially above
+ // the target width).
+ if (vecType.getElementType().isIndex())
+ return true;
+
+ unsigned finalDimSize = vecType.getShape().back();
+ unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+ unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+ return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+ unsigned targetBitWidth) {
+ // Check on bitwidths.
+ SmallVector<std::pair<Type, unsigned>> toCheck =
+ getTypeBitWidthBoundPairs(op, targetBitWidth);
+ return std::any_of(toCheck.begin(), toCheck.end(),
+ [&](std::pair<Type, unsigned> typeWidth) {
+ return isNotLinearizableBecauseLargeInnerDimension(
+ typeWidth.first, typeWidth.second);
+ });
+}
+
+void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+ ConversionTarget &target,
+ unsigned targetBitWidth) {
+
+ // The general purpose definition of what ops are legal must come first.
+ populateForVectorLinearize(typeConverter, target);
+
+ // Extend the set of legal ops to include those with large inner-most
+ // dimensions on selected operands/results.
+ target.markUnknownOpDynamicallyLegal(
+ [=](Operation *op) -> std::optional<bool> {
+ if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+ return true;
+ }
+ return {};
+ });
+}
+
+struct TestVectorBitWidthLinearize final
+ : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+ TestVectorBitWidthLinearize() = default;
+ TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+ : PassWrapper(pass) {}
+
+ StringRef getArgument() const override {
+ return "test-bit-width-contrained-vector-linearize";
+ }
StringRef getDescription() const override {
- return "Linearizes ND vectors for N >= 2 into 1D vectors";
+ return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+ "in inner-most dimension's bit width.";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
@@ -862,14 +944,48 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeBitWidthTargetAndConverter(
- typeConverter, target, targetVectorBitwidth);
+ populateWithBitWidthConstraints(typeConverter, target,
+ targetVectorBitwidth);
- vector::populateVectorLinearizeBasePatterns(typeConverter, patterns,
- target);
+ vector::populateVectorLinearizeBasePatterns(typeConverter, target,
+ patterns);
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter,
- patterns, target);
+ vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
+ patterns);
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace vendor
+
+struct TestVectorLinearize final
+ : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+ TestVectorLinearize() = default;
+
+ StringRef getArgument() const override { return "test-vector-linearize"; }
+ StringRef getDescription() const override {
+ return "Linearizes ND vectors for N >= 2 into 1D vectors";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+
+ vector::populateForVectorLinearize(converter, target);
+
+ vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
+ vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
+ patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -950,6 +1066,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorLinearize>();
+ PassRegistration<vendor::TestVectorBitWidthLinearize>();
+
PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
>From 5d213711a3bd83e35f527f60ea82ec18572b9a1b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Apr 2025 13:32:32 -0700
Subject: [PATCH 4/4] clang-format
---
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ba5d82ad38585..223efbf00742b 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -945,7 +945,7 @@ struct TestVectorBitWidthLinearize final
ConversionTarget target(*context);
populateWithBitWidthConstraints(typeConverter, target,
- targetVectorBitwidth);
+ targetVectorBitwidth);
vector::populateVectorLinearizeBasePatterns(typeConverter, target,
patterns);
More information about the Mlir-commits
mailing list