[Mlir-commits] [mlir] andrzej/refactor narrow type 2 (PR #123527)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 19 12:50:51 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)**
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)**
---
Full diff: https://github.com/llvm/llvm-project/pull/123527.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+74-63)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..4e0be258954496 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
OpFoldResult linearizedIndices,
int64_t numEmultedElementsToLoad, Type origElemType,
Type emulatedElemType) {
- auto scale = emulatedElemType.getIntOrFloatBitWidth() /
- origElemType.getIntOrFloatBitWidth();
+ auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+ origElemType.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+ loc,
+ VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+ origElemType),
newLoad);
}
@@ -314,14 +316,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int elementsPerContainerType = newBits / oldBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +339,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
+ if (origElements % elementsPerContainerType != 0)
return failure();
auto stridedMetadata =
@@ -346,13 +348,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
+ auto numElements = origElements / elementsPerContainerType;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
@@ -385,17 +387,17 @@ struct ConvertVectorMaskedStore final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int elementsPerContainerType = newBits / oldBits;
int origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
+ if (origElements % elementsPerContainerType != 0)
return failure();
auto stridedMetadata =
@@ -404,7 +406,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -444,12 +446,13 @@ struct ConvertVectorMaskedStore final
//
// FIXME: Make an example based on the comment above work (see #115460 for
// reproducer).
- FailureOr<Operation *> newMask =
- getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ FailureOr<Operation *> newMask = getCompressedMaskOp(
+ rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
if (failed(newMask))
return failure();
- auto numElements = (origElements + scale - 1) / scale;
+ auto numElements = (origElements + elementsPerContainerType - 1) /
+ elementsPerContainerType;
auto newType = VectorType::get(numElements, newElementType);
auto passThru = rewriter.create<arith::ConstantOp>(
loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +461,8 @@ struct ConvertVectorMaskedStore final
loc, newType, adaptor.getBase(), linearizedIndices,
newMask.value()->getResult(0), passThru);
- auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitCastType =
+ VectorType::get(numElements * elementsPerContainerType, oldElementType);
Value valueToStore =
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
valueToStore = rewriter.create<arith::SelectOp>(
@@ -493,14 +497,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int elementsPerContainerType = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -532,7 +536,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// compile time as they must be constants.
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -541,7 +545,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -553,9 +557,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
: 0;
// Always load enough elements which can cover the original elements.
- int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
- auto numElements =
- llvm::divideCeil(maxintraDataOffset + origElements, scale);
+ int64_t maxintraDataOffset =
+ foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+ auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+ elementsPerContainerType);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);
@@ -596,14 +601,14 @@ struct ConvertVectorMaskedLoad final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int elementsPerContainerType = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -649,7 +654,7 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -657,7 +662,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -668,18 +673,21 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
- FailureOr<Operation *> newMask = getCompressedMaskOp(
- rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+ int64_t maxIntraDataOffset =
+ foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+ FailureOr<Operation *> newMask =
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+ elementsPerContainerType, maxIntraDataOffset);
if (failed(newMask))
return failure();
Value passthru = op.getPassThru();
- auto numElements =
- llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+ auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+ elementsPerContainerType);
auto loadType = VectorType::get(numElements, newElementType);
- auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitcastType =
+ VectorType::get(numElements * elementsPerContainerType, oldElementType);
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +714,8 @@ struct ConvertVectorMaskedLoad final
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
Value mask = op.getMask();
- auto newSelectMaskType =
- VectorType::get(numElements * scale, rewriter.getI1Type());
+ auto newSelectMaskType = VectorType::get(
+ numElements * elementsPerContainerType, rewriter.getI1Type());
// TODO: try to fold if op's mask is constant
auto emptyMask = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -758,18 +766,18 @@ struct ConvertVectorTransferRead final
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int elementsPerContainerType = newBits / oldBits;
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
adaptor.getPadding());
@@ -781,7 +789,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -792,9 +800,10 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
- auto numElements =
- llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+ int64_t maxIntraDataOffset =
+ foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+ auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+ elementsPerContainerType);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +811,9 @@ struct ConvertVectorTransferRead final
newPadding);
auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements * scale, oldElementType), newRead);
+ loc,
+ VectorType::get(numElements * elementsPerContainerType, oldElementType),
+ newRead);
Value result = bitCast->getResult(0);
if (!foldedIntraVectorOffset) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/123527
More information about the Mlir-commits
mailing list