[Mlir-commits] [mlir] bad8bf5 - [mlir][vector] Linearization: push 'bit width' logic out of patterns (#136581)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 30 09:05:44 PDT 2025
Author: James Newling
Date: 2025-04-30T09:05:40-07:00
New Revision: bad8bf56d3e4f107423b307f5f75564296703a76
URL: https://github.com/llvm/llvm-project/commit/bad8bf56d3e4f107423b307f5f75564296703a76
DIFF: https://github.com/llvm/llvm-project/commit/bad8bf56d3e4f107423b307f5f75564296703a76.diff
LOG: [mlir][vector] Linearization: push 'bit width' logic out of patterns (#136581)
[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.
In https://github.com/llvm/llvm-project/pull/83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
`targetVectorBitWidth` at all).
As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition the tests need to be
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is a NFC). I'm happy to help make it
easier to do this final step!
The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out of the patterns, and into the conversion target,
which the end user can control outside of core MLIR.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7a079dcc6affc..f1100d5cf8b68 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,18 +406,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
- const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth);
+/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
+/// This registers (1) which operations are legal and hence should not be
+/// linearized, (2) what converted types are (rank-1 vectors) and how to
+/// materialize the conversion (with shape_cast)
+///
+/// Note: the set of legal operations can be extended by a user if for example
+/// certain rank>1 vectors are considered valid, by adding additional
+/// dynamically legal ops to `conversionTarget`.
+void populateForVectorLinearize(TypeConverter &typeConverter,
+ ConversionTarget &conversionTarget);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
+/// contains patterns for converting ConstantLike, Vectorizable, and
+/// vector::BitCast ops.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+ const ConversionTarget &,
+ RewritePatternSet &patterns);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+ const ConversionTarget &,
+ RewritePatternSet &patterns);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..67e15852dc5ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
+#include <limits>
#include <numeric>
+#include <optional>
using namespace mlir;
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
- auto resultTypes = op->getResultTypes();
- for (auto resType : resultTypes) {
- VectorType vecType = dyn_cast<VectorType>(resType);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
- return false;
- // There are no dimension to fold if it is a 0-D vector.
- if (vecType.getRank() == 0)
- return false;
- unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
- if (trailingVecDimBitWidth >= targetBitWidth)
- return false;
- }
- return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
- VectorType vecType = dyn_cast<VectorType>(t);
- // Reject index since getElementTypeBitWidth will abort for Index types.
- if (!vecType || vecType.getElementType().isIndex())
- return false;
- // There are no dimension to fold if it is a 0-D vector.
- if (vecType.getRank() == 0)
- return false;
- unsigned trailingVecDimBitWidth =
- vecType.getShape().back() * vecType.getElementTypeBitWidth();
- return trailingVecDimBitWidth <= targetBitWidth;
-}
-
static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
VectorType resType, Attribute value) {
+
if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
if (resType.isScalable() && !isa<SplatElementsAttr>(value))
return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
}
namespace {
+
struct LinearizeConstantLike final
: OpTraitConversionPattern<OpTrait::ConstantLike> {
using OpTraitConversionPattern::OpTraitConversionPattern;
- LinearizeConstantLike(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeConstantLike(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@@ -93,16 +62,10 @@ struct LinearizeConstantLike final
if (op->getNumResults() != 1)
return rewriter.notifyMatchFailure(loc, "expected 1 result");
- const TypeConverter &converter = *getTypeConverter();
+ const TypeConverter &typeConverter = *getTypeConverter();
auto resType =
- converter.convertType<VectorType>(op->getResult(0).getType());
-
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
+ typeConverter.convertType<VectorType>(op->getResult(0).getType());
+ assert(resType && "expected 1-D vector type");
StringAttr attrName = rewriter.getStringAttr("value");
Attribute value = op->getAttr(attrName);
@@ -115,7 +78,7 @@ struct LinearizeConstantLike final
return failure();
FailureOr<Operation *> convertResult =
- convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
+ convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
if (failed(convertResult))
return failure();
@@ -124,9 +87,6 @@ struct LinearizeConstantLike final
rewriter.replaceOp(op, newOp);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
struct LinearizeVectorizable final
@@ -134,18 +94,12 @@ struct LinearizeVectorizable final
using OpTraitConversionPattern::OpTraitConversionPattern;
public:
- LinearizeVectorizable(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpTraitConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeVectorizable(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
@@ -154,9 +108,6 @@ struct LinearizeVectorizable final
rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +124,10 @@ struct LinearizeVectorizable final
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
- LinearizeVectorExtractStridedSlice(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+ MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +138,6 @@ struct LinearizeVectorExtractStridedSlice final
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- extractOp, "Can't flatten since targetBitWidth <= OpSize");
ArrayAttr offsets = extractOp.getOffsets();
ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +214,6 @@ struct LinearizeVectorExtractStridedSlice final
extractOp, dstType, srcVector, srcVector, indices);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +234,7 @@ struct LinearizeVectorShuffle final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -300,15 +242,6 @@ struct LinearizeVectorShuffle final
VectorType dstType =
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
assert(dstType && "vector type destination expected.");
- // The assert is used because vector.shuffle does not support scalable
- // vectors.
- assert(!(shuffleOp.getV1VectorType().isScalable() ||
- shuffleOp.getV2VectorType().isScalable() ||
- dstType.isScalable()) &&
- "scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
Value vec1 = adaptor.getV1();
Value vec2 = adaptor.getV2();
@@ -327,7 +260,7 @@ struct LinearizeVectorShuffle final
}
// For each value in the mask, we generate the indices of the source vectors
- // that needs to be shuffled to the destination vector. If shuffleSliceLen >
+ // that need to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
ArrayRef<int64_t> mask = shuffleOp.getMask();
@@ -343,9 +276,6 @@ struct LinearizeVectorShuffle final
vec2, indices);
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,23 +294,12 @@ struct LinearizeVectorExtract final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(extractOp,
- "expected n-D vector type.");
-
- if (extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstTy).isScalable())
- return rewriter.notifyMatchFailure(extractOp,
- "scalable vectors are not supported.");
- if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- extractOp, "Can't flatten since targetBitWidth <= OpSize");
+ assert(dstTy && "expected 1-D vector type");
// Dynamic position is not supported.
if (extractOp.hasDynamicPosition())
@@ -405,9 +324,6 @@ struct LinearizeVectorExtract final
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,22 +343,13 @@ struct LinearizeVectorInsert final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType dstTy = getTypeConverter()->convertType<VectorType>(
insertOp.getDestVectorType());
assert(dstTy && "vector type destination expected.");
- if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
- return rewriter.notifyMatchFailure(insertOp,
- "scalable vectors are not supported.");
-
- if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
- targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- insertOp, "Can't flatten since targetBitWidth < OpSize");
// dynamic position is not supported
if (insertOp.hasDynamicPosition())
@@ -471,11 +378,11 @@ struct LinearizeVectorInsert final
}
llvm::SmallVector<int64_t, 2> indices(dstSize);
- auto origValsUntil = indices.begin();
+ auto *origValsUntil = indices.begin();
std::advance(origValsUntil, linearizedOffset);
std::iota(indices.begin(), origValsUntil,
0); // original values that remain [0, offset)
- auto newValsUntil = origValsUntil;
+ auto *newValsUntil = origValsUntil;
std::advance(newValsUntil, srcSize);
std::iota(origValsUntil, newValsUntil,
dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +395,6 @@ struct LinearizeVectorInsert final
return success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
/// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,82 +412,111 @@ struct LinearizeVectorBitCast final
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ : OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = castOp.getLoc();
auto resType = getTypeConverter()->convertType(castOp.getType());
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type.");
-
- if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
- return rewriter.notifyMatchFailure(
- loc, "Can't flatten since targetBitWidth <= OpSize");
-
+ assert(resType && "expected 1-D vector type");
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
return mlir::success();
}
-
-private:
- unsigned targetVectorBitWidth;
};
} // namespace
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
- TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned targetBitWidth) {
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result. These ops should all eventually
+/// support scalable vectors, and this function should be removed.
+static bool isNotLinearizableBecauseScalable(Operation *op) {
+
+ bool unsupported =
+ isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
+ op);
+ if (!unsupported)
+ return false;
+
+ // Check if any of the results is a scalable vector type.
+ auto types = op->getResultTypes();
+ bool containsScalableResult =
+ std::any_of(types.begin(), types.end(), [](Type type) {
+ auto vecType = dyn_cast<VectorType>(type);
+ return vecType && vecType.isScalable();
+ });
+
+ return containsScalableResult;
+}
+
+static bool isNotLinearizable(Operation *op) {
+
+ // Only ops that are in the vector dialect, are ConstantLike, or
+ // are Vectorizable might be linearized currently.
+ StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
+ StringRef opDialect = op->getDialect()->getNamespace();
+ bool unsupported = (opDialect != vectorDialect) &&
+ !op->hasTrait<OpTrait::ConstantLike>() &&
+ !op->hasTrait<OpTrait::Vectorizable>();
+ if (unsupported)
+ return true;
+
+ // Some ops currently don't support scalable vectors.
+ if (isNotLinearizableBecauseScalable(op))
+ return true;
+
+ return false;
+}
- typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
- if (!isLinearizableVector(type))
+void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
+ ConversionTarget &target) {
+
+ auto convertType = [](Type type) -> std::optional<Type> {
+ VectorType vectorType = dyn_cast<VectorType>(type);
+ if (!vectorType || !isLinearizableVector(vectorType))
return type;
- return VectorType::get(type.getNumElements(), type.getElementType(),
- type.isScalable());
- });
+ VectorType linearizedType =
+ VectorType::get(vectorType.getNumElements(),
+ vectorType.getElementType(), vectorType.isScalable());
+ return linearizedType;
+ };
+ typeConverter.addConversion(convertType);
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
- if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
- !isa<VectorType>(type))
+ if (inputs.size() != 1)
+ return nullptr;
+
+ Value value = inputs.front();
+ if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
return nullptr;
- return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+ return builder.create<vector::ShapeCastOp>(loc, type, value);
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
- op->hasTrait<OpTrait::ConstantLike>() ||
- op->hasTrait<OpTrait::Vectorizable>())) {
- return (isLessThanTargetBitWidth(op, targetBitWidth)
- ? typeConverter.isLegal(op)
- : true);
- }
- return std::nullopt;
+ if (isNotLinearizable(op))
+ return true;
+ // This will return true if, for all operand and result types `t`,
+ // convertType(t) = t. This is true if there are no rank>=2 vectors.
+ return typeConverter.isLegal(op);
});
+}
+void mlir::vector::populateVectorLinearizeBasePatterns(
+ const TypeConverter &typeConverter, const ConversionTarget &target,
+ RewritePatternSet &patterns) {
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
- targetBitWidth);
+ LinearizeVectorBitCast>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- const TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target, unsigned int targetBitWidth) {
- target.addDynamicallyLegalOp<vector::ShuffleOp>(
- [=](vector::ShuffleOp shuffleOp) -> bool {
- return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
- ? (typeConverter.isLegal(shuffleOp) &&
- cast<mlir::VectorType>(shuffleOp.getResult().getType())
- .getRank() == 1)
- : true;
- });
+ const TypeConverter &typeConverter, const ConversionTarget &target,
+ RewritePatternSet &patterns) {
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..06eaf58b225ae 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
-// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
+
+// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
+// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
@@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>
// ALL-LABEL: test_index_no_linearize
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
- // ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+ // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
@@ -171,6 +172,7 @@ func.func @test_0d_vector() -> vector<f32> {
}
// -----
+
// ALL-LABEL: test_extract_strided_slice_1
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
@@ -193,6 +195,8 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
return %0 : vector<2x2xf32>
}
+// -----
+
// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
@@ -205,6 +209,7 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
}
// -----
+
// ALL-LABEL: test_extract_strided_slice_2
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -228,6 +233,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
}
// -----
+
// ALL-LABEL: test_vector_shuffle
// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -252,6 +258,7 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
}
// -----
+
// ALL-LABEL: test_vector_extract
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -273,6 +280,8 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
return %0 : vector<8x2xf32>
}
+// -----
+
// ALL-LABEL: func.func @test_vector_extract_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> {
@@ -283,7 +292,9 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
// ALL: return %[[RES]] : vector<8x[2]xf32>
return %0 : vector<8x[2]xf32>
}
+
// -----
+
// ALL-LABEL: test_vector_insert
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -312,6 +323,8 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
return %0 : vector<2x8x4xf32>
}
+// -----
+
// ALL-LABEL: func.func @test_vector_insert_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> {
@@ -385,6 +398,7 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
}
// -----
+
// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03f907e46c2c6..eda2594fbc7c7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,17 +7,13 @@
//===----------------------------------------------------------------------===//
#include <optional>
-#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -840,16 +836,98 @@ struct TestVectorEmulateMaskedLoadStore final
}
};
-struct TestVectorLinearize final
- : public PassWrapper<TestVectorLinearize, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+// TODO: move this code into the user project.
+namespace vendor {
- TestVectorLinearize() = default;
- TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+/// Get the set of operand/result types to check for sufficiently
+/// small inner-most dimension size.
+static SmallVector<std::pair<Type, unsigned>>
+getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
- StringRef getArgument() const override { return "test-vector-linearize"; }
+ if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+ unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+ ? targetBitWidth + 1
+ : targetBitWidth;
+ return {{insertOp.getValueToStoreType(), w}};
+ }
+
+ auto resultTypes = op->getResultTypes();
+ SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+ resultsWithBitWidth.reserve(resultTypes.size());
+ for (Type type : resultTypes) {
+ resultsWithBitWidth.push_back({type, targetBitWidth});
+ }
+ return resultsWithBitWidth;
+}
+
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal.
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Type type,
+ unsigned targetBitWidth) {
+
+ VectorType vecType = dyn_cast<VectorType>(type);
+
+ // Not linearizable for reasons other than what this function checks.
+ if (!vecType || vecType.getRank() == 0)
+ return false;
+
+ // The width of the type 'index' is unbounded (and therefore potentially above
+ // the target width).
+ if (vecType.getElementType().isIndex())
+ return true;
+
+ unsigned finalDimSize = vecType.getShape().back();
+ unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+ unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+ return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static bool
+isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+ unsigned targetBitWidth) {
+ // Check on bitwidths.
+ SmallVector<std::pair<Type, unsigned>> toCheck =
+ getTypeBitWidthBoundPairs(op, targetBitWidth);
+ return std::any_of(toCheck.begin(), toCheck.end(),
+ [&](std::pair<Type, unsigned> typeWidth) {
+ return isNotLinearizableBecauseLargeInnerDimension(
+ typeWidth.first, typeWidth.second);
+ });
+}
+
+void populateWithBitWidthConstraints(TypeConverter &typeConverter,
+ ConversionTarget &target,
+ unsigned targetBitWidth) {
+
+ // The general purpose definition of what ops are legal must come first.
+ populateForVectorLinearize(typeConverter, target);
+
+ // Extend the set of legal ops to include those with large inner-most
+ // dimensions on selected operands/results.
+ target.markUnknownOpDynamicallyLegal(
+ [=](Operation *op) -> std::optional<bool> {
+ if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
+ return true;
+ }
+ return {};
+ });
+}
+
+struct TestVectorBitWidthLinearize final
+ : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+ TestVectorBitWidthLinearize() = default;
+ TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+ : PassWrapper(pass) {}
+
+ StringRef getArgument() const override {
+ return "test-bit-width-constrained-vector-linearize";
+ }
StringRef getDescription() const override {
- return "Linearizes ND vectors for N >= 2 into 1D vectors";
+ return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+ "in inner-most dimension's bit width.";
}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<vector::VectorDialect>();
@@ -867,10 +945,49 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeTypeConversionsAndLegality(
- typeConverter, patterns, target, targetVectorBitwidth);
- vector::populateVectorLinearizeShuffleLikeOpsPatterns(
- typeConverter, patterns, target, targetVectorBitwidth);
+ populateWithBitWidthConstraints(typeConverter, target,
+ targetVectorBitwidth);
+
+ vector::populateVectorLinearizeBasePatterns(typeConverter, target,
+ patterns);
+
+ vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
+ patterns);
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace vendor
+
+struct TestVectorLinearize final
+ : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+ TestVectorLinearize() = default;
+
+ StringRef getArgument() const override { return "test-vector-linearize"; }
+ StringRef getDescription() const override {
+ return "Linearizes ND vectors for N >= 2 into 1D vectors";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+
+ vector::populateForVectorLinearize(converter, target);
+
+ vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
+ vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
+ patterns);
+
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
@@ -950,6 +1067,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorLinearize>();
+ PassRegistration<vendor::TestVectorBitWidthLinearize>();
+
PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
More information about the Mlir-commits
mailing list