[Mlir-commits] [mlir] 3e5640b - [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) (#123526)
llvmlistbot at llvm.org
Sun Feb 2 06:58:42 PST 2025
Author: Andrzej Warzyński
Date: 2025-02-02T14:58:38Z
New Revision: 3e5640b22d3978571816b9a0468a4aed27cdd82c
URL: https://github.com/llvm/llvm-project/commit/3e5640b22d3978571816b9a0468a4aed27cdd82c
DIFF: https://github.com/llvm/llvm-project/commit/3e5640b22d3978571816b9a0468a4aed27cdd82c.diff
LOG: [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) (#123526)
This is PR 1 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no
major functional changes are made.
This PR renames:
* `srcBits`/`dstBits` + `oldElementType`/`newElementType`
to `emulatedBits`/`containerBits` + `emulatedElemTy`/`containerElemTy`,
respectively, to improve consistency in naming within the file. This is
illustrated below:
```cpp
// Extracted from VectorEmulateNarrowType.cpp
// BEFORE (mixing old/new and src/dst):
// Type oldElementType = op.getType().getElementType();
// Type newElementType = convertedType.getElementType();
// int srcBits = oldElementType.getIntOrFloatBitWidth();
// int dstBits = newElementType.getIntOrFloatBitWidth();
// AFTER (consistently using emulated/container):
Type emulatedElemTy = op.getType().getElementType();
Type containerElemTy = convertedType.getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();
```
Also adds some comments and unifies related "rewriter notification"
messages.
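To make the intent of the new names concrete, here is a minimal standalone
sketch of the per-element alignment check and packing factor that the renamed
variables express (the helper name is hypothetical, not code from the patch);
e.g. `i4` elements emulated inside `i8` containers:
```cpp
#include <cassert>

// How many emulated (narrow) elements fit in one container element.
// Mirrors the patch's check: containerBits % emulatedBits must be 0,
// otherwise the pattern reports "bit-wise misalignment" and bails out.
int numEmulatedPerContainer(int emulatedBits, int containerBits) {
  assert(containerBits % emulatedBits == 0 &&
         "impossible to pack emulated elements into container elements");
  return containerBits / emulatedBits;
}

int main() {
  assert(numEmulatedPerContainer(4, 8) == 2); // two i4 values per i8
  assert(numEmulatedPerContainer(2, 8) == 4); // four i2 values per i8
  // numEmulatedPerContainer(3, 8) would assert: 8 % 3 != 0.
  return 0;
}
```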
**GitHub issue to track this work:**
* https://github.com/llvm/llvm-project/issues/123630
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7ca88f1e0a0df98..63365cb54461244 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
+
auto valueToStore = cast<VectorValue>(op.getValueToStore());
- auto oldElementType = valueToStore.getType().getElementType();
- auto newElementType =
+ auto containerElemTy =
cast<MemRefType>(adaptor.getBase().getType()).getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = containerElemTy.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
+ // Check per-element alignment.
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ op, "impossible to pack emulated elements into container elements "
+ "(bit-wise misalignment)");
}
- int numSrcElemsPerDest = dstBits / srcBits;
+ int numSrcElemsPerDest = containerBits / emulatedBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Basic case: storing full bytes.
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, newElementType),
+ loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), memrefBase,
@@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy =
+ cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+ Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = containerElemTy.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
+ // Check per-element alignment.
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ op, "impossible to pack emulated elements into container elements "
+ "(bit-wise misalignment)");
}
- int scale = dstBits / srcBits;
+ int scale = containerBits / emulatedBits;
int origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
@@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -706,7 +711,7 @@ struct ConvertVectorMaskedStore final
return failure();
auto numElements = (origElements + scale - 1) / scale;
- auto newType = VectorType::get(numElements, newElementType);
+ auto newType = VectorType::get(numElements, containerElemTy);
auto passThru = rewriter.create<arith::ConstantOp>(
loc, newType, rewriter.getZeroAttr(newType));
@@ -714,7 +719,7 @@ struct ConvertVectorMaskedStore final
loc, newType, adaptor.getBase(), linearizedIndices,
newMask.value()->getResult(0), passThru);
- auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
Value valueToStore =
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
valueToStore = rewriter.create<arith::SelectOp>(
@@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy =
+ cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+ Type emulatedElemTy = op.getType().getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = containerElemTy.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
+ // Check per-element alignment.
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ op, "impossible to pack emulated elements into container elements "
+ "(bit-wise misalignment)");
}
- int scale = dstBits / srcBits;
+ int scale = containerBits / emulatedBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
llvm::divideCeil(maxintraDataOffset + origElements, scale);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
- numElements, oldElementType, newElementType);
+ numElements, emulatedElemTy, containerElemTy);
if (!foldedIntraVectorOffset) {
auto resultVector = rewriter.create<arith::ConstantOp>(
@@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
+ auto containerElemTy =
+ cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+ Type emulatedElemTy = op.getType().getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+ // Check per-element alignment.
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ op, "impossible to pack emulated elements into container elements "
+ "(bit-wise misalignment)");
}
- int scale = dstBits / srcBits;
+ int scale = containerBits / emulatedBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final
auto numElements =
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
- auto loadType = VectorType::get(numElements, newElementType);
- auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+ auto loadType = VectorType::get(numElements, containerElemTy);
+ auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
- Type oldElementType = op.getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
-
- if (dstBits % srcBits != 0) {
+ auto containerElemTy =
+ cast<MemRefType>(adaptor.getSource().getType()).getElementType();
+ Type emulatedElemTy = op.getType().getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+ // Check per-element alignment.
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ op, "impossible to pack emulated elements into container elements "
+ "(bit-wise misalignment)");
}
- int scale = dstBits / srcBits;
+ int scale = containerBits / emulatedBits;
auto origElements = op.getVectorType().getNumElements();
bool isAlignedEmulation = origElements % scale == 0;
- auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+ auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
adaptor.getPadding());
auto stridedMetadata =
@@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto newRead = rewriter.create<vector::TransferReadOp>(
- loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
+ loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
newPadding);
auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements * scale, oldElementType), newRead);
+ loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
Value result = bitCast->getResult(0);
if (!foldedIntraVectorOffset) {
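For readers skimming the hunks above: the `llvm::divideCeil(...)` and
`(origElements + scale - 1) / scale` expressions round the emulated element
count up to whole container elements. A minimal standalone restatement
(hypothetical helper name, not code from the patch):
```cpp
#include <cassert>

// Round the emulated element count up to whole container elements,
// i.e. ceiling division, as llvm::divideCeil does in the patterns above.
int containerElemsNeeded(int origElements, int scale) {
  return (origElements + scale - 1) / scale;
}

int main() {
  // 5 i4 elements with scale = 8/4 = 2 need 3 i8 container elements.
  assert(containerElemsNeeded(5, 2) == 3);
  assert(containerElemsNeeded(4, 2) == 2); // aligned case: no padding
  return 0;
}
```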