[Mlir-commits] [mlir] [vector][linearize] Refactor code to push target bit width out of core code (PR #136581)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 21 11:39:22 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: James Newling (newling)
<details>
<summary>Changes</summary>
Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results of operations.
In https://github.com/llvm/llvm-project/pull/83314 an option to ignore (legalize) operations with large inner-most dimensions was added. This PR is a step towards moving that option out of upstream MLIR. The motivation is to reduce non-core functionality (I would like to use this pass, but would prefer not to deal with `targetVectorBitWidth` at all).
As a follow-up to this PR, I propose that users of `targetVectorBitWidth` move `legalBecauseOfBitwidth` into their own code bases, after which it can be removed from upstream.
The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out of the patterns and into the conversion target, which the end user can control outside of core MLIR.
---
Patch is 26.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136581.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+18-12)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+130-134)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+13)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+9-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..d9a0791cdea33 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,18 +392,24 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
- const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+/// Populate `typeConverter` and `conversionTarget` with the definition of
+/// legal types and operations, for the case where ops whose vectors have a
+/// trailing dimension of at least `targetBitWidth` bits are legal.
+void populateVectorLinearizeBitWidthTargetAndConverter(
+ TypeConverter &typeConverter, ConversionTarget &conversionTarget,
+ unsigned targetBitWidth);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
+/// converting ConstantLike, Vectorizable, and vector::BitCast.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+ RewritePatternSet &patterns,
+ const ConversionTarget &);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+ RewritePatternSet &patterns,
+ const ConversionTarget &);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..e24c8ee961c51 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
+#include <limits>
#include <numeric>
+#include <optional>
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
- auto resultTypes = op->getResultTypes();
- for (auto resType : resultTypes) {
- VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
- return false;
- // There are no dimension to fold if it is a 0-D vector.
- if (vecType.getRank() == 0)
- return false;
- unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
- if (trailingVecDimBitWidth >= targetBitWidth)
- return false;
- }
- return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
- VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
- return false;
- // There are no dimension to fold if it is a 0-D vector.
- if (vecType.getRank() == 0)
- return false;
- unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
- return trailingVecDimBitWidth <= targetBitWidth;
-}
-
static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
VectorType resType, Attribute value) {
+
if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
if (resType.isScalable() && !isa<SplatElementsAttr>(value))
return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
}
namespace {
+
struct LinearizeConstantLike final
: OpTraitConversionPattern<OpTrait::ConstantLike> {
using OpTraitConversionPattern::OpTraitConversionPattern;
- LinearizeConstantLike(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeConstantLike(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -100,10 +69,6 @@ struct LinearizeConstantLike final
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
-
StringAttr attrName = rewriter.getStringAttr("value");
Attribute value = op->getAttr(attrName);
if (!value)
@@ -124,9 +89,6 @@ struct LinearizeConstantLike final
rewriter.replaceOp(op, newOp);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
struct LinearizeVectorizable final
@@ -134,18 +96,12 @@ struct LinearizeVectorizable final
using OpTraitConversionPattern::OpTraitConversionPattern;
public:
- LinearizeVectorizable(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeVectorizable(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
@@ -154,9 +110,6 @@ struct LinearizeVectorizable final
rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +126,10 @@ struct LinearizeVectorizable final
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtractStridedSlice(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +140,6 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- extractOp, "Can't flatten since targetBitWidth <= OpSize");
ArrayAttr offsets = extractOp.getOffsets();
ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +216,6 @@ struct LinearizeVectorExtractStridedSlice final
extractOp, dstType, srcVector, srcVector, indices);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +236,7 @@ struct LinearizeVectorShuffle final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -302,13 +246,12 @@ struct LinearizeVectorShuffle final
assert(dstType && "vector type destination expected.");
// The assert is used because vector.shuffle does not support scalable
// vectors.
- assert(!(shuffleOp.getV1VectorType().isScalable() ||
- shuffleOp.getV2VectorType().isScalable() ||
- dstType.isScalable()) &&
- "scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+ bool scalable = shuffleOp.getV1VectorType().isScalable() ||
+ shuffleOp.getV2VectorType().isScalable() ||
+ dstType.isScalable();
+ if (scalable)
+ return rewriter.notifyMatchFailure(shuffleOp,
+ "scalable vectors are not supported.");
Value vec1 = adaptor.getV1();
Value vec2 = adaptor.getV2();
@@ -343,9 +286,6 @@ struct LinearizeVectorShuffle final
vec2, indices);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,8 +304,7 @@ struct LinearizeVectorExtract final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -378,9 +317,6 @@ struct LinearizeVectorExtract final
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- extractOp, "Can't flatten since targetBitWidth <= OpSize");
// Dynamic position is not supported.
if (extractOp.hasDynamicPosition())
@@ -405,9 +341,6 @@ struct LinearizeVectorExtract final
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,8 +360,7 @@ struct LinearizeVectorInsert final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -439,11 +371,6 @@ struct LinearizeVectorInsert final
return rewriter.notifyMatchFailure(insertOp,
"scalable vectors are not supported.");
- if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
- targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- insertOp, "Can't flatten since targetBitWidth < OpSize");
-
// dynamic position is not supported
if (insertOp.hasDynamicPosition())
return rewriter.notifyMatchFailure(insertOp,
@@ -471,11 +398,11 @@ struct LinearizeVectorInsert final
}
llvm::SmallVector<int64_t, 2> indices(dstSize);
- auto origValsUntil = indices.begin();
+ auto *origValsUntil = indices.begin();
std::advance(origValsUntil, linearizedOffset);
std::iota(indices.begin(), origValsUntil,
0); // original values that remain [0, offset)
- auto newValsUntil = origValsUntil;
+ auto *newValsUntil = origValsUntil;
std::advance(newValsUntil, srcSize);
std::iota(origValsUntil, newValsUntil,
dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +415,6 @@ struct LinearizeVectorInsert final
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,8 +432,7 @@ struct LinearizeVectorBitCast final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type.");
- if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
-
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
return mlir::success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
} // namespace
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+/// Return true (i.e. the op producing `type` may stay legal) if `type` is not a
+/// vector, has index elements, or its trailing dim is >= `targetBitWidth` bits.
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+
+ VectorType vecType = dyn_cast<VectorType>(type);
+
+ if (!vecType)
+ return true;
+
+ // The width of the type 'index' is unbounded (and therefore potentially above
+ // the target width).
+ if (vecType.getElementType().isIndex())
+ return true;
+
+ unsigned finalDimSize =
+ vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+
+ unsigned trailingVecDimBitWidth =
+ finalDimSize * vecType.getElementTypeBitWidth();
+
+ return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+
+ if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+ auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
+ ? targetBitWidth + 1
+ : targetBitWidth;
+ return {{insertOp.getValueToStoreType(), w}};
+ }
+ auto resultTypes = op->getResultTypes();
+ SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+ resultsWithBitWidth.reserve(resultTypes.size());
+ for (Type type : resultTypes) {
+ resultsWithBitWidth.push_back({type, targetBitWidth});
+ }
+ return resultsWithBitWidth;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+ bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+ op->hasTrait<OpTrait::Vectorizable>() ||
+ isa<vector::BitCastOp>(op);
+
+ if (scalableSupported)
+ return false;
+
+ // Check if any of the results is a scalable vector type.
+ auto types = op->getResultTypes();
+ bool containsScalableResult =
+ std::any_of(types.begin(), types.end(), [](Type type) {
+ auto vecType = dyn_cast<VectorType>(type);
+ return vecType && vecType.isScalable();
+ });
+
+ return containsScalableResult;
+}
+
+static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+
+ // Only ops that are in the vector dialect, are ConstantLike, or
+ // are Vectorizable might be linearized currently, so legalize the others.
+ bool opIsVectorDialect = op->getDialect()->getNamespace() ==
+ vector::VectorDialect::getDialectNamespace();
+ if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
+ !op->hasTrait<OpTrait::Vectorizable>())
+ return true;
+
+ // Some ops will not be linearized if they have scalable vector results.
+ if (legalBecauseScalable(op))
+ return true;
+
+ // Check on bitwidths.
+ auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
+ return std::any_of(typesToCheck.begin(), typesToCheck.end(),
+ [&](std::pair<Type, unsigned> typeWidth) {
+ return legalBecauseOfBitwidth(typeWidth.first,
+ typeWidth.second);
+ });
+}
+
+void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
+ TypeConverter &typeConverter, ConversionTarget &target,
+ unsigned targetBitWidth) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
@@ -550,40 +552,34 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
!isa<VectorType>(type))
return nullptr;
-
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
+
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
- op->hasTrait<OpTrait::ConstantLike>() ||
- op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
- ? typeConverter.isLegal(op)
- : true);
- }
- return std::nullopt;
+ bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
+ if (isDynamicallyLegal)
+ return true;
+
+ bool shapeUnchanged = typeConverter.isLegal(op);
+ return shapeUnchanged;
});
+}
+void mlir::vector::populateVectorLinearizeBasePatterns(
+ const TypeC...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/136581
More information about the Mlir-commits
mailing list