[Mlir-commits] [mlir] [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) (PR #123526)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 24 04:35:10 PST 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/123526
>From 5edc3423c1da0c62c8a7bb5fa3e9d54855bdf3bf Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:45:51 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)
This is PR 1 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no
major functional changes are made.
This PR renames `srcBits/dstBits` to `oldBits/newBits` to improve
consistency in naming within the file. This is illustrated below:
```cpp
// Extracted from VectorEmulateNarrowType.cpp
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
// BEFORE (mixing old/new and src/dst):
// int srcBits = oldElementType.getIntOrFloatBitWidth();
// int dstBits = newElementType.getIntOrFloatBitWidth();
// AFTER (consistently using old/new):
int oldBits = oldElementType.getIntOrFloatBitWidth();
int newBits = newElementType.getIntOrFloatBitWidth();
```
Also adds some comments and unifies related "rewriter notification"
messages.
---
.../Transforms/VectorEmulateNarrowType.cpp | 70 +++++++++----------
1 file changed, 35 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..70d50e1d48040c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -314,14 +314,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -385,15 +385,15 @@ struct ConvertVectorMaskedStore final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
int origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -493,14 +493,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -596,14 +596,14 @@ struct ConvertVectorMaskedLoad final
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -758,14 +758,14 @@ struct ConvertVectorTransferRead final
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = newElementType.getIntOrFloatBitWidth();
+ int oldBits = oldElementType.getIntOrFloatBitWidth();
+ int newBits = newElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
+ // Check per-element alignment.
+ if (newBits % oldBits != 0) {
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = dstBits / srcBits;
+ int scale = newBits / oldBits;
auto origElements = op.getVectorType().getNumElements();
@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, dstBits,
+ rewriter, loc, oldBits, newBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
>From cda648efccd3153ef52fe6cef3d6d58b8258237c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 24 Jan 2025 12:34:06 +0000
Subject: [PATCH 2/2] fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp
(1/N)
Introduce emulatedElemTy/containerElemTy and emulatedBits/containerBits
---
.../Transforms/VectorEmulateNarrowType.cpp | 88 +++++++++----------
1 file changed, 44 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 70d50e1d48040c..1ca72c3ff01213 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -311,17 +311,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int oldBits = oldElementType.getIntOrFloatBitWidth();
- int newBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy = cast<MemRefType>(adaptor.getBase().getType());
+ Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+ Type newElementType = containerElemTy.getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = newElementType.getIntOrFloatBitWidth();
// Check per-element alignment.
- if (newBits % oldBits != 0) {
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = newBits / oldBits;
+ int scale = containerBits / emulatedBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, oldBits, newBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -382,18 +382,18 @@ struct ConvertVectorMaskedStore final
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int oldBits = oldElementType.getIntOrFloatBitWidth();
- int newBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy = cast<MemRefType>(adaptor.getBase().getType());
+ Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+ Type newElementType = containerElemTy.getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = newElementType.getIntOrFloatBitWidth();
// Check per-element alignment.
- if (newBits % oldBits != 0) {
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = newBits / oldBits;
+ int scale = containerBits / emulatedBits;
int origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, oldBits, newBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -458,7 +458,7 @@ struct ConvertVectorMaskedStore final
loc, newType, adaptor.getBase(), linearizedIndices,
newMask.value()->getResult(0), passThru);
- auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
Value valueToStore =
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
valueToStore = rewriter.create<arith::SelectOp>(
@@ -490,17 +490,17 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int oldBits = oldElementType.getIntOrFloatBitWidth();
- int newBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy = cast<MemRefType>(adaptor.getBase().getType());
+ Type emulatedElemTy = op.getType().getElementType();
+ Type newElementType = containerElemTy.getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = newElementType.getIntOrFloatBitWidth();
// Check per-element alignment.
- if (newBits % oldBits != 0) {
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = newBits / oldBits;
+ int scale = containerBits / emulatedBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, oldBits, newBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -558,7 +558,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
llvm::divideCeil(maxintraDataOffset + origElements, scale);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
- numElements, oldElementType, newElementType);
+ numElements, emulatedElemTy, newElementType);
if (!foldedIntraVectorOffset) {
auto resultVector = rewriter.create<arith::ConstantOp>(
@@ -593,17 +593,17 @@ struct ConvertVectorMaskedLoad final
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int oldBits = oldElementType.getIntOrFloatBitWidth();
- int newBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy = cast<MemRefType>(adaptor.getBase().getType());
+ Type emulatedElemTy = op.getType().getElementType();
+ Type newElementType = containerElemTy.getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = newElementType.getIntOrFloatBitWidth();
// Check per-element alignment.
- if (newBits % oldBits != 0) {
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = newBits / oldBits;
+ int scale = containerBits / emulatedBits;
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, oldBits, newBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -679,7 +679,7 @@ struct ConvertVectorMaskedLoad final
auto numElements =
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto loadType = VectorType::get(numElements, newElementType);
- auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+ auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -755,17 +755,17 @@ struct ConvertVectorTransferRead final
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
- Type oldElementType = op.getType().getElementType();
- Type newElementType = convertedType.getElementType();
- int oldBits = oldElementType.getIntOrFloatBitWidth();
- int newBits = newElementType.getIntOrFloatBitWidth();
+ auto containerElemTy = cast<MemRefType>(adaptor.getSource().getType());
+ Type emulatedElemTy = op.getType().getElementType();
+ Type newElementType = containerElemTy.getElementType();
+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+ int containerBits = newElementType.getIntOrFloatBitWidth();
// Check per-element alignment.
- if (newBits % oldBits != 0) {
+ if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(op, "unalagined element types");
}
- int scale = newBits / oldBits;
+ int scale = containerBits / emulatedBits;
auto origElements = op.getVectorType().getNumElements();
@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, oldBits, newBits,
+ rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -802,7 +802,7 @@ struct ConvertVectorTransferRead final
newPadding);
auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements * scale, oldElementType), newRead);
+ loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
Value result = bitCast->getResult(0);
if (!foldedIntraVectorOffset) {
More information about the Mlir-commits
mailing list