[Mlir-commits] [mlir] andrzej/refactor narrow type 4 (PR #123529)
Andrzej Warzyński
llvmlistbot at llvm.org
Sun Jan 19 13:03:05 PST 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/123529
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)**
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)**
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)**
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)**
From 5edc3423c1da0c62c8a7bb5fa3e9d54855bdf3bf Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:45:51 +0000
Subject: [PATCH 1/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)
This is PR 1 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major
functional changes are made.
This PR renames `srcBits`/`dstBits` to `oldBits`/`newBits` to make the naming
within the file more consistent, as illustrated below:
```cpp
// Extracted from VectorEmulateNarrowType.cpp
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
// BEFORE (mixing old/new and src/dst):
// int srcBits = oldElementType.getIntOrFloatBitWidth();
// int dstBits = newElementType.getIntOrFloatBitWidth();
// AFTER (consistently using old/new):
int oldBits = oldElementType.getIntOrFloatBitWidth();
int newBits = newElementType.getIntOrFloatBitWidth();
```
Also adds some comments and unifies related "rewriter notification"
messages.
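For reference, the unified notification message (extracted from the diff
below):
```cpp
// BEFORE (pattern-specific wording):
// return rewriter.notifyMatchFailure(
//     op, "only dstBits % srcBits == 0 supported");
// AFTER (the same message is now used across all affected patterns):
return rewriter.notifyMatchFailure(op, "unaligned element types");
```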
---
.../Transforms/VectorEmulateNarrowType.cpp | 70 +++++++++----------
1 file changed, 35 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..70d50e1d48040c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -314,14 +314,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -385,15 +385,15 @@ struct ConvertVectorMaskedStore final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
int origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -493,14 +493,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -596,14 +596,14 @@ struct ConvertVectorMaskedLoad final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -758,14 +758,14 @@ struct ConvertVectorTransferRead final
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
auto origElements = op.getVectorType().getNumElements();
@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
From d40b31bb098e874be488182050c68b887e8d091a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:59:34 +0000
Subject: [PATCH 2/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)
This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major
functional changes are made.
This PR renames the variable `scale` to `elementsPerContainerType`. Note that
"scale" could mean either:
* "original-elements-per-emulated-type", or
* "emulated-elements-per-original-type".
While the context makes it clear that it is always the former (the original
type is always a sub-byte type and the emulated type is usually `i8`), this PR
reduces the cognitive load by making that explicit - see the snippet below.
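For illustration, the core of the rename, extracted from the diff below (the
computed value is unchanged; only the name changes):
```cpp
// Extracted from VectorEmulateNarrowType.cpp
// BEFORE:
// int scale = newBits / oldBits;
// AFTER:
int elementsPerContainerType = newBits / oldBits;
```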
---
.../Transforms/VectorEmulateNarrowType.cpp | 77 +++++++++++--------
1 file changed, 44 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 70d50e1d48040c..4e0be258954496 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
OpFoldResult linearizedIndices,
int64_t numEmultedElementsToLoad, Type origElemType,
Type emulatedElemType) {
- auto scale = emulatedElemType.getIntOrFloatBitWidth() /
- origElemType.getIntOrFloatBitWidth();
+ auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+ origElemType.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+ loc,
+ VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+ origElemType),
newLoad);
}
@@ -321,7 +323,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
if (newBits % oldBits != 0) {
return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = newBits / oldBits;
+ int elementsPerContainerType = newBits / oldBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +339,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
+ if (origElements % elementsPerContainerType != 0)
return failure();
auto stridedMetadata =
@@ -352,7 +354,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
+ auto numElements = origElements / elementsPerContainerType;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
@@ -393,9 +395,9 @@ struct ConvertVectorMaskedStore final
return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = newBits / oldBits;
+ int elementsPerContainerType = newBits / oldBits;
int origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
+ if (origElements % elementsPerContainerType != 0)
return failure();
auto stridedMetadata =
@@ -444,12 +446,13 @@ struct ConvertVectorMaskedStore final
//
// FIXME: Make an example based on the comment above work (see #115460 for
// reproducer).
- FailureOr<Operation *> newMask =
- getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ FailureOr<Operation *> newMask = getCompressedMaskOp(
+ rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
if (failed(newMask))
return failure();
- auto numElements = (origElements + scale - 1) / scale;
+ auto numElements = (origElements + elementsPerContainerType - 1) /
+ elementsPerContainerType;
auto newType = VectorType::get(numElements, newElementType);
auto passThru = rewriter.create<arith::ConstantOp>(
loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +461,8 @@ struct ConvertVectorMaskedStore final
loc, newType, adaptor.getBase(), linearizedIndices,
newMask.value()->getResult(0), passThru);
- auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitCastType =
+ VectorType::get(numElements * elementsPerContainerType, oldElementType);
Value valueToStore =
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
valueToStore = rewriter.create<arith::SelectOp>(
@@ -500,7 +504,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
if (newBits % oldBits != 0) {
return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = newBits / oldBits;
+ int elementsPerContainerType = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -532,7 +536,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// compile time as they must be constants.
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -553,9 +557,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
: 0;
// Always load enough elements which can cover the original elements.
- int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
- auto numElements =
- llvm::divideCeil(maxintraDataOffset + origElements, scale);
+ int64_t maxintraDataOffset =
+ foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+ auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+ elementsPerContainerType);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);
@@ -603,7 +608,7 @@ struct ConvertVectorMaskedLoad final
if (newBits % oldBits != 0) {
return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = newBits / oldBits;
+ int elementsPerContainerType = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -649,7 +654,7 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -668,18 +673,21 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
- FailureOr<Operation *> newMask = getCompressedMaskOp(
- rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+ int64_t maxIntraDataOffset =
+ foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+ FailureOr<Operation *> newMask =
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+ elementsPerContainerType, maxIntraDataOffset);
if (failed(newMask))
return failure();
Value passthru = op.getPassThru();
- auto numElements =
- llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+ auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+ elementsPerContainerType);
auto loadType = VectorType::get(numElements, newElementType);
- auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitcastType =
+ VectorType::get(numElements * elementsPerContainerType, oldElementType);
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +714,8 @@ struct ConvertVectorMaskedLoad final
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
Value mask = op.getMask();
- auto newSelectMaskType =
- VectorType::get(numElements * scale, rewriter.getI1Type());
+ auto newSelectMaskType = VectorType::get(
+ numElements * elementsPerContainerType, rewriter.getI1Type());
// TODO: try to fold if op's mask is constant
auto emptyMask = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -765,11 +773,11 @@ struct ConvertVectorTransferRead final
if (newBits % oldBits != 0) {
return rewriter.notifyMatchFailure(op, "unaligned element types");
}
- int scale = newBits / oldBits;
+ int elementsPerContainerType = newBits / oldBits;
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
adaptor.getPadding());
@@ -792,9 +800,10 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
- auto numElements =
- llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+ int64_t maxIntraDataOffset =
+ foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+ auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+ elementsPerContainerType);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +811,9 @@ struct ConvertVectorTransferRead final
newPadding);
auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements * scale, oldElementType), newRead);
+ loc,
+ VectorType::get(numElements * elementsPerContainerType, oldElementType),
+ newRead);
Value result = bitCast->getResult(0);
if (!foldedIntraVectorOffset) {
From e8ffbaa0e334e515887d40e6f112161a8476cc7f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 3/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)
This is PR 3 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major
functional changes are made.
1. Replaces `isUnalignedEmulation` with `isFullyAligned`
Note that `isUnalignedEmulation` is only ever computed after the following
"per-element alignment" condition has been verified:
```cpp
// Check per-element alignment.
if (newBits % oldBits != 0) {
  return rewriter.notifyMatchFailure(op, "unaligned element types");
}
// (...)
bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
```
Given that `isUnalignedEmulation` captures only one of the two conditions
required for "full alignment", a strictly accurate name would be something
like `isPartiallyUnalignedEmulation`. Instead, since the per-element condition
has already been verified at that point, I've flipped the remaining condition
and renamed the variable to `isFullyAligned`:
```cpp
bool isFullyAligned = origElements % elementsPerContainerType == 0;
```
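The flipped flag is then used to guard the sub-vector extraction for the
unaligned case (adapted from `ConvertVectorLoad` in the diff below):
```cpp
// Extract the sub-vector only when the emulation is not fully aligned
// (previously guarded by `isUnalignedEmulation`).
if (!isFullyAligned) {
  result =
      staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                 *foldedIntraVectorOffset, origElements);
}
```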
2. In addition:
* Unifies various comments throughout the file (for consistency).
* Adds new comments throughout the file and adds TODOs where high-level
comments are missing.
---
.../Transforms/VectorEmulateNarrowType.cpp | 105 ++++++++++--------
1 file changed, 61 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 4e0be258954496..b6c109cc742eca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,10 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -300,6 +304,7 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//
+// TODO: Document-me
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -370,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//
+// TODO: Document-me
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -481,6 +487,7 @@ struct ConvertVectorMaskedStore final
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
+// TODO: Document-me
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -536,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// compile time as they must be constants.
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % elementsPerContainerType == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -552,9 +560,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isUnalignedEmulation
- ? getConstantIntValue(linearizedInfo.intraDataOffset)
- : 0;
+ isFullyAligned ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
// Always load enough elements which can cover the original elements.
int64_t maxintraDataOffset =
@@ -571,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isFullyAligned) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
@@ -585,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//
+// TODO: Document-me
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -749,6 +757,7 @@ struct ConvertVectorMaskedLoad final
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//
+// TODO: Document-me
struct ConvertVectorTransferRead final
: OpConversionPattern<vector::TransferReadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -777,7 +786,8 @@ struct ConvertVectorTransferRead final
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % elementsPerContainerType == 0;
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
adaptor.getPadding());
@@ -796,9 +806,8 @@ struct ConvertVectorTransferRead final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isUnalignedEmulation
- ? getConstantIntValue(linearizedInfo.intraDataOffset)
- : 0;
+ isFullyAligned ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
@@ -822,7 +831,7 @@ struct ConvertVectorTransferRead final
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isFullyAligned) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
@@ -1506,33 +1515,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// LLVM to scramble with peephole optimizations. Templated to choose between
/// signed and unsigned conversions.
///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
/// arith.extsi %in : vector<8xi4> to vector<8xi32>
-/// is rewriten as
-/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-/// %1 = arith.shli %0, 4 : vector<4xi8>
-/// %2 = arith.shrsi %1, 4 : vector<4xi8>
-/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.shli %0, 4 : vector<4xi8>
+/// %2 = arith.shrsi %1, 4 : vector<4xi8>
+/// %3 = arith.shrsi %0, 4 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
+/// EXAMPLE 2 (fp):
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
-/// is rewriten as
-/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-/// %1 = arith.shli %0, 4 : vector<4xi8>
-/// %2 = arith.shrsi %1, 4 : vector<4xi8>
-/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewritten as:
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.shli %0, 4 : vector<4xi8>
+/// %2 = arith.shrsi %1, 4 : vector<4xi8>
+/// %3 = arith.shrsi %0, 4 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
/// arith.extui %in : vector<8xi4> to vector<8xi32>
-/// is rewritten as
-/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-/// %1 = arith.andi %0, 15 : vector<4xi8>
-/// %2 = arith.shrui %0, 4 : vector<4xi8>
-/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.andi %0, 15 : vector<4xi8>
+/// %2 = arith.shrui %0, 4 : vector<4xi8>
+/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
///
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1542,8 +1552,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
PatternRewriter &rewriter) const override {
// Verify the preconditions.
Value srcValue = conversionOp.getIn();
- auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
- auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+ VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+ VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
if (failed(
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1583,15 +1593,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
///
/// For example:
/// arith.trunci %in : vector<8xi32> to vector<8xi4>
-/// is rewriten as
///
-/// %cst = arith.constant dense<15> : vector<4xi8>
-/// %cst_0 = arith.constant dense<4> : vector<4xi8>
-/// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-/// %2 = arith.andi %0, %cst : vector<4xi8>
-/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
-/// %4 = arith.ori %2, %3 : vector<4xi8>
-/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewritten as:
+///
+/// %cst = arith.constant dense<15> : vector<4xi8>
+/// %cst_0 = arith.constant dense<4> : vector<4xi8>
+/// %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+/// %2 = arith.andi %0, %cst : vector<4xi8>
+/// %3 = arith.shli %1, %cst_0 : vector<4xi8>
+/// %4 = arith.ori %2, %3 : vector<4xi8>
+/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1635,10 +1646,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
///
-/// is rewritten as:
+/// is rewritten as:
///
/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1686,6 +1698,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
// Public Interface Definition
//===----------------------------------------------------------------------===//
+// The emulated type is inferred from the converted memref type.
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
@@ -1698,22 +1711,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
void vector::populateVectorNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
+ // TODO: Document what the emulated type is.
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
benefit);
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
+ // The emulated type is always i8.
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
benefit.getBenefit() + 1);
+ // The emulated type is always i8.
patterns
.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
patterns.getContext(), benefit.getBenefit() + 1);
}
+// The emulated type is always i8.
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
From 5583441cb1a6844d618e5c05e9007606d76beadf Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 4/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major
functional changes are made.
1. Update `alignedConversionPrecondition` (1):
This method doesn't require a vector type for the "destination" argument - the
underlying element type is sufficient. The corresponding argument has been
renamed to `containerTy`; this is meant to be the multi-byte emulated
(container) type (`i8`, `i16`, `i32`, etc.), as shown in the signature sketch
below.
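The resulting signature change (extracted from the diff below):
```cpp
// BEFORE:
// static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
//                                                    VectorType subByteVecType,
//                                                    VectorType dstType,
//                                                    Operation *op);
// AFTER - the "destination" vector type is replaced with a scalar container
// type:
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op);
```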
2. Update `alignedConversionPrecondition` (2):
In #121298, we replaced `dstElemBitwidth` in this calculation:
```cpp
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```
with the hard-coded value of 8:
```cpp
const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```
That was correct for the patterns for which this hook was/is used:
* `RewriteAlignedSubByteIntExt`,
* `RewriteAlignedSubByteIntTrunc`.
The destination type (or, more precisely, the emulated type) was always
`i8`.
In this PR, I am switching back to a more generic approach - the calculation
should take into account the bit-width of the emulated type. Note that at the
call sites I am passing `i8` as the emulated type, so the end result is
effectively identical. However, the intent is clearer: the underlying value is
8 because the emulated type happens to be `i8`, as opposed to using a magic
number (see the sketch below).
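A minimal sketch of the generalized calculation (the names match those
introduced in the diff below):
```cpp
// BEFORE (#121298): the container bit-width was hard-coded to 8 (i.e. i8).
// const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
// AFTER: derive it from the bit-width of the emulated (container) type.
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
int elemsPerMultiByte = multiByteBits / subByteBits;
```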
3. Update `alignedConversionPrecondition` (3):
The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used elsewhere in the code
and hopefully will enable more code re-use moving forward (avoiding the need
to re-implement the same condition). Its use is sketched below.
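For reference, this is how the new helper is used in
`ConvertVectorTransferRead` (extracted from the diff below):
```cpp
// Note, per-element-alignment was already verified above.
bool isFullyAligned =
    isSubByteVecFittable(op.getVectorType(), newElementType);
```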
4. Update `alignedConversionPrecondition` (4):
NEXT STEPS:
We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.
For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:
```mlir
%0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```
However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:
```mlir
%bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```
I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as
the destination types and that should be clarified.
---
.../Transforms/VectorEmulateNarrowType.cpp | 134 +++++++++++++-----
1 file changed, 102 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b6c109cc742eca..373b8a8822318f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -753,6 +753,38 @@ struct ConvertVectorMaskedLoad final
}
};
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+/// * vector<4xi4> -> i16 - yes (N = 1)
+/// * vector<4xi4> -> i8 - yes (N = 2)
+/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
+/// * vector<3xi2> -> i16 - no (N would have to be 0.375)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+ Type multiByteScalarTy) {
+ assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+ int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+ int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+ assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+ assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+ assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
+
+ int elemsPerMultiByte = multiByteBits / subByteBits;
+
+ // TODO: This is a bit too restrictive for vectors of rank > 1.
+ return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//
@@ -787,7 +819,8 @@ struct ConvertVectorTransferRead final
auto origElements = op.getVectorType().getNumElements();
// Note, per-element-alignment was already verified above.
- bool isFullyAligned = origElements % elementsPerContainerType == 0;
+ bool isFullyAligned =
+ isSubByteVecFittable(op.getVectorType(), newElementType);
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
adaptor.getPadding());
@@ -1089,41 +1122,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
return commonConversionPrecondition(rewriter, preconditionType, op);
}
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-/// 1. The `dstType` element type is a multiple of the
-/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-/// is not supported). Let this multiple be `N`.
-/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-/// not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+/// 1. The bit-width of `containerTy` is a multiple of the
+/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+/// this multiple is 4.
+/// 2. The multiple from 1. above divides evenly the number of the (trailing)
+/// elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+/// `subByteVecTy = vector<4xi4>`, and
+/// `containerTy = i16`
+///
+/// 4 ( = 16 / 4) divides evenly 4, hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+/// `subByteVecTy = vector<3xi4>`, and
+/// `containerTy = i16`
+///
+/// 4 ( = 16 / 4) _does not_ divide evenly 3, hence condition 2 is _not met_.
+///
+/// EXAMPLE 3:
+/// `subByteVecTy = vector<3xi3>`, and
+/// `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
///
/// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
- VectorType subByteVecType,
- VectorType dstType,
+ VectorType subByteVecTy,
+ Type containerTy,
Operation *op) {
- if (!subByteVecType || !dstType)
- return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
- unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
- unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+ // TODO: This is validating the inputs rather than checking the conditions
+ // documented above. Replace with an assert.
+ if (!subByteVecTy)
+ return rewriter.notifyMatchFailure(op, "not a vector!");
- if (dstElemBitwidth < 8)
- return rewriter.notifyMatchFailure(
- op, "the bitwidth of dstType must be greater than or equal to 8");
- if (dstElemBitwidth % srcElemBitwidth != 0)
- return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
- if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+ // TODO: This is validating the inputs rather than checking the conditions
+ // documented above. Replace with an assert.
+ if (!containerTy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+ unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+ unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+ // Enforced by the common pre-conditions.
+ assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+ // TODO: Remove this condition - the assert above (and
+ // commonConversionPrecondition) takes care of that.
+ if (multiByteBits < 8)
+ return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+ // TODO: Add support for other widths (when/if needed)
+ if (subByteBits != 2 && subByteBits != 4)
return rewriter.notifyMatchFailure(
- op, "only src bitwidth of 2 or 4 is supported at this moment");
+ op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
+
+ // Condition 1.
+ if (multiByteBits % subByteBits != 0)
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
- const int numSrcElemsPerByte = 8 / srcElemBitwidth;
- if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+ // Condition 2.
+ if (!isSubByteVecFittable(subByteVecTy, containerTy))
return rewriter.notifyMatchFailure(
- op, "the trailing dimension of the input vector of sub-bytes must be a "
- "multiple of 8 / <sub-byte-width>");
+ op, "not possible to fit this sub-byte vector type into a vector of "
+ "the given multi-byte type");
return success();
}
@@ -1560,8 +1628,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();
// Check general alignment preconditions.
- if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
- conversionOp)))
+ Type containerType = rewriter.getI8Type();
+ if (failed(alignedConversionPrecondition(rewriter, srcVecType,
+ containerType, conversionOp)))
return failure();
// Perform the rewrite.
@@ -1625,8 +1694,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
- if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
- truncOp)))
+ Type containerType = rewriter.getI8Type();
+ if (failed(alignedConversionPrecondition(rewriter, dstVecType,
+ containerType, truncOp)))
return failure();
// Create a new iX -> i8 truncation op.