[Mlir-commits] [mlir] [mlir][Vector] Refactor VectorEmulateNarrowType.cpp (PR #123529)
Andrzej Warzyński
llvmlistbot at llvm.org
Sat Mar 15 12:22:32 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/123529
From 696d7d2560dda52ac11422312ffde853363c8598 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 1/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no
major functional changes are made.
1. Update `alignedConversionPrecondition` (1):
This method did not need the vector type for the "destination"
argument; the underlying element type is sufficient. The corresponding
argument has been renamed to `containerTy` - this is meant as the
multi-byte emulated type (`i8`, `i16`, `i32`, etc.).
2. Update `alignedConversionPrecondition` (2):
In #121298, we replaced `dstElemBitwidth` in this calculation:
```cpp
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```
with the hard-coded value of 8:
```cpp
const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```
That was correct for the patterns that use this hook:
* `RewriteAlignedSubByteIntExt`,
* `RewriteAlignedSubByteIntTrunc`.
The destination type (or, more precisely, the emulated type) was always
`i8`.
In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.
Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).
3. Update `alignedConversionPrecondition` (3):
The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This helper is also reused elsewhere in the
file and will hopefully enable more code reuse going forward (avoiding
re-implementing the same condition).
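To make points 2 and 3 concrete, here is a minimal sketch that models the arithmetic with plain integers instead of MLIR types. `numSubByteElemsPerContainerElem` and `subByteVecFits` are hypothetical names used only for illustration; they are not part of the patch or the MLIR API. With an `i8` container the generic ratio reproduces the old hard-coded `8 / srcElemBitwidth`, and the fittability check follows the helper's logic of inspecting only the trailing dimension. Note that for `vector<3xi2>` and `i16` the vector holds 3 * 2 = 6 bits, so N would have to be 6/16 = 0.375; what matters is that N is not an integer.

```cpp
#include <cstdint>

// Plain-integer model (not MLIR code). `containerBits` stands in for the
// bit-width of the emulated/container type (8 for i8) and `subByteBits`
// for the sub-byte element bit-width (2 or 4).

// Point 2: the ratio is derived from the container bit-width instead of a
// hard-coded 8. Assumes containerBits % subByteBits == 0 (asserted in the
// patch), so the division is exact.
constexpr int numSubByteElemsPerContainerElem(int containerBits,
                                              int subByteBits) {
  return containerBits / subByteBits;
}

// Point 3: model of the fittability check. As in the patch, only the
// trailing dimension of the sub-byte vector is inspected.
constexpr bool subByteVecFits(std::int64_t trailingDim, int subByteBits,
                              int containerBits) {
  return trailingDim %
             numSubByteElemsPerContainerElem(containerBits, subByteBits) ==
         0;
}

// With an i8 container the generic ratio matches `8 / srcElemBitwidth`.
static_assert(numSubByteElemsPerContainerElem(8, 4) == 8 / 4, "i4 in i8");
static_assert(numSubByteElemsPerContainerElem(8, 2) == 8 / 2, "i2 in i8");

// The examples listed in the helper's doc comment.
static_assert(subByteVecFits(4, 4, 16), "vector<4xi4> -> i16 (N = 1)");
static_assert(subByteVecFits(4, 4, 8), "vector<4xi4> -> i8 (N = 2)");
static_assert(!subByteVecFits(3, 4, 8), "vector<3xi4> -> i8 (N = 1.5)");
static_assert(!subByteVecFits(3, 2, 16), "vector<3xi2> -> i16 (N = 6/16)");
```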
NEXT STEPS (1):
We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.
For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:
```mlir
%0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```
However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:
```mlir
%bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```
I've noticed that we tend to use both `vector<2xi8>` and `vector<8xi32>`
as the "destination" type, and that should be clarified.
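A small, hypothetical model of the three types involved may help pin down the terminology. `VecTy` and `bitcastToContainer` are illustrative names, not MLIR APIs; only 1-D shapes are modelled, and the sub-byte vector is assumed to fit exactly into container elements (its total bit count is divisible by the container width).

```cpp
#include <cstdint>

// Illustrative 1-D vector type: element count and element bit-width.
struct VecTy {
  std::int64_t numElems;
  int elemBits;
};

// Result type of a `vector.bitcast` that packs `src` into elements of
// `containerBits` bits. The total number of bits is preserved; assumes
// src.numElems * src.elemBits is divisible by containerBits.
constexpr VecTy bitcastToContainer(VecTy src, int containerBits) {
  return VecTy{src.numElems * src.elemBits / containerBits, containerBits};
}

// arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
constexpr VecTy opSrc{8, 2};  // vector<8xi2>: "source" of arith.extsi
constexpr VecTy opDst{8, 32}; // vector<8xi32>: "destination" of arith.extsi

// vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
constexpr VecTy castDst = bitcastToContainer(opSrc, 8);

static_assert(castDst.numElems == 2 && castDst.elemBits == 8,
              "the bitcast produces vector<2xi8>");
// Both could be called the "destination" type, yet they differ.
static_assert(castDst.numElems != opDst.numElems &&
                  castDst.elemBits != opDst.elemBits,
              "vector<2xi8> vs vector<8xi32>");
```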
NEXT STEPS (2):
With this PR, I am introducing explicit references to "sub-byte", as
that is effectively what this logic is used for (i.e. emulating
"sub-byte" types). We should either generalise it (which would include
increasing test coverage) or restrict everything to "sub-byte" type
emulation.
---
.../Transforms/VectorEmulateNarrowType.cpp | 134 +++++++++++++-----
1 file changed, 102 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 51e72753ff162..59ed3b5521470 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
}
};
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+/// * vector<4xi4> -> i16 - yes (N = 1)
+/// * vector<4xi4> -> i8 - yes (N = 2)
+/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
+/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+ Type multiByteScalarTy) {
+ assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+ int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+ int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+ assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+ assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+ assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
+
+ int elemsPerMultiByte = multiByteBits / subByteBits;
+
+ // TODO: This is a bit too restrictive for vectors rank > 1.
+ return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
auto origElements = op.getVectorType().getNumElements();
// Note, per-element-alignment was already verified above.
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+ bool isFullyAligned =
+ isSubByteVecFittable(op.getVectorType(), containerElemTy);
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
adaptor.getPadding());
@@ -1428,41 +1461,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
return commonConversionPrecondition(rewriter, preconditionType, op);
}
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-/// 1. The `dstType` element type is a multiple of the
-/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-/// is not supported). Let this multiple be `N`.
-/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-/// not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+/// 1. The bit-width of `containerTy` is a multiple of the
+/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+/// this multiple is 4.
+/// 2. The multiple from 1. above evenly divides the number of the (trailing)
+/// elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+/// `subByteVecTy = vector<8xi4>`, and
+/// `containerTy = i16`
+///
+/// 4 (= 16 / 4) evenly divides 8, hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+/// `subByteVecTy = vector<3xi4>`, and
+/// `containerTy = i16`
+///
+/// 4 (= 16 / 4) _does not_ evenly divide 3, hence the conditions are not met.
+///
+/// EXAMPLE 3:
+/// `subByteVecTy = vector<3xi3>`, and
+/// `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
///
/// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
- VectorType subByteVecType,
- VectorType dstType,
+ VectorType subByteVecTy,
+ Type containerTy,
Operation *op) {
- if (!subByteVecType || !dstType)
- return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
- unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
- unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+ // TODO: This is validating the inputs rather than checking the conditions
+ // documented above. Replace with an assert.
+ if (!subByteVecTy)
+ return rewriter.notifyMatchFailure(op, "not a vector!");
- if (dstElemBitwidth < 8)
- return rewriter.notifyMatchFailure(
- op, "the bitwidth of dstType must be greater than or equal to 8");
- if (dstElemBitwidth % srcElemBitwidth != 0)
- return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
- if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+ // TODO: This is validating the inputs rather than checking the conditions
+ // documented above. Replace with an assert.
+ if (!containerTy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+ unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+ unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+ // Enforced by the common pre-conditions.
+ assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+ // TODO: Remove this condition - the assert above (and
+ // commonConversionPrecondtion) takes care of that.
+ if (multiByteBits < 8)
+ return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+ // TODO: Add support other widths (when/if needed)
+ if (subByteBits != 2 && subByteBits != 4)
return rewriter.notifyMatchFailure(
- op, "only src bitwidth of 2 or 4 is supported at this moment");
+ op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
+
+ // Condition 1.
+ if (multiByteBits % subByteBits != 0)
+ return rewriter.notifyMatchFailure(op, "unalagined element types");
- const int numSrcElemsPerByte = 8 / srcElemBitwidth;
- if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+ // Condition 2.
+ if (!isSubByteVecFittable(subByteVecTy, containerTy))
return rewriter.notifyMatchFailure(
- op, "the trailing dimension of the input vector of sub-bytes must be a "
- "multiple of 8 / <sub-byte-width>");
+ op, "not possible to fit this sub-byte vector type into a vector of "
+ "the given multi-byte type");
return success();
}
@@ -1899,8 +1967,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();
// Check general alignment preconditions.
- if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
- conversionOp)))
+ Type containerType = rewriter.getI8Type();
+ if (failed(alignedConversionPrecondition(rewriter, srcVecType,
+ containerType, conversionOp)))
return failure();
// Perform the rewrite.
@@ -1964,8 +2033,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
- if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
- truncOp)))
+ Type containerType = rewriter.getI8Type();
+ if (failed(alignedConversionPrecondition(rewriter, dstVecType,
+ containerType, truncOp)))
return failure();
// Create a new iX -> i8 truncation op.
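A hedged, plain-integer sketch of the order in which the refactored `alignedConversionPrecondition` rejects its inputs. `AlignedCheck` and `alignedPrecondition` are hypothetical names, the null-vector check is omitted, and `containerBits` is assumed to be a multiple of 8 (which the patch asserts). One consequence of that ordering: since only 2- and 4-bit elements are accepted and the container is a multiple of 8 bits, condition 1 cannot fail once the width check has passed, so a `vector<3xi3>` input is rejected by the width check first.

```cpp
#include <cstdint>

// Possible outcomes, mirroring the early returns of the refactored
// `alignedConversionPrecondition` (hypothetical model, not MLIR code).
enum class AlignedCheck {
  Success,
  UnsupportedWidth,   // Only 2-bit and 4-bit sub-byte types are handled.
  UnalignedElemTypes, // Condition 1 ("per-element" alignment) failed.
  DoesNotFit          // Condition 2 ("full" alignment) failed.
};

// `trailingDim` is the trailing dimension of the sub-byte vector. Assumes
// containerBits % 8 == 0, as asserted in the patch.
constexpr AlignedCheck alignedPrecondition(std::int64_t trailingDim,
                                           int subByteBits,
                                           int containerBits) {
  if (subByteBits != 2 && subByteBits != 4)
    return AlignedCheck::UnsupportedWidth;
  if (containerBits % subByteBits != 0)
    return AlignedCheck::UnalignedElemTypes;
  if (trailingDim % (containerBits / subByteBits) != 0)
    return AlignedCheck::DoesNotFit;
  return AlignedCheck::Success;
}

static_assert(alignedPrecondition(2, 4, 8) == AlignedCheck::Success,
              "vector<2xi4>, i8");
static_assert(alignedPrecondition(8, 4, 16) == AlignedCheck::Success,
              "vector<8xi4>, i16");
static_assert(alignedPrecondition(2, 4, 16) == AlignedCheck::DoesNotFit,
              "vector<2xi4>, i16");
static_assert(alignedPrecondition(3, 4, 16) == AlignedCheck::DoesNotFit,
              "vector<3xi4>, i16");
static_assert(alignedPrecondition(3, 3, 16) == AlignedCheck::UnsupportedWidth,
              "vector<3xi3>, i16");
```

For instance, `vector<2xi4>` fits `i8` containers but not `i16` ones, since 16 / 4 = 4 does not divide 2.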
From aa17f5eb50c387638841af9df632544c13912d09 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 10 Mar 2025 16:21:12 +0000
Subject: [PATCH 2/4] fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp
(4/N)
Address comments from Alan
---
.../Transforms/VectorEmulateNarrowType.cpp | 37 ++++++++-----------
1 file changed, 15 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 59ed3b5521470..649d73a0a460f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1106,8 +1106,8 @@ struct ConvertVectorMaskedLoad final
/// * vector<4xi4> -> i8 - yes (N = 2)
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
-static bool isSubByteVecFittable(VectorType subByteVecTy,
- Type multiByteScalarTy) {
+static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
+ Type multiByteScalarTy) {
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
@@ -1160,7 +1160,7 @@ struct ConvertVectorTransferRead final
// Note, per-element-alignment was already verified above.
bool isFullyAligned =
- isSubByteVecFittable(op.getVectorType(), containerElemTy);
+ fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
adaptor.getPadding());
@@ -1496,38 +1496,31 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
VectorType subByteVecTy,
Type containerTy,
Operation *op) {
+ assert(containerTy.isIntOrFloat() &&
+ "container element type is not a scalar");
+
// TODO: This is validating the inputs rather than checking the conditions
// documented above. Replace with an assert.
if (!subByteVecTy)
return rewriter.notifyMatchFailure(op, "not a vector!");
- // TODO: This is validating the inputs rather than checking the conditions
- // documented above. Replace with an assert.
- if (!containerTy.isIntOrFloat())
- return rewriter.notifyMatchFailure(op, "not a scalar!");
-
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
// Enforced by the common pre-conditions.
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
- // TODO: Remove this condition - the assert above (and
- // commonConversionPrecondtion) takes care of that.
- if (multiByteBits < 8)
- return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
-
// TODO: Add support other widths (when/if needed)
if (subByteBits != 2 && subByteBits != 4)
return rewriter.notifyMatchFailure(
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
- // Condition 1.
+ // Condition 1 ("per-element" alignment)
if (multiByteBits % subByteBits != 0)
return rewriter.notifyMatchFailure(op, "unalagined element types");
- // Condition 2.
- if (!isSubByteVecFittable(subByteVecTy, containerTy))
+ // Condition 2 ("full" alignment)
+ if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
return rewriter.notifyMatchFailure(
op, "not possible to fit this sub-byte vector type into a vector of "
"the given multi-byte type");
@@ -1967,9 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();
// Check general alignment preconditions.
- Type containerType = rewriter.getI8Type();
- if (failed(alignedConversionPrecondition(rewriter, srcVecType,
- containerType, conversionOp)))
+ if (failed(alignedConversionPrecondition(
+ rewriter, srcVecType,
+ /*containerTy=*/rewriter.getI8Type(), conversionOp)))
return failure();
// Perform the rewrite.
@@ -2033,9 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
- Type containerType = rewriter.getI8Type();
- if (failed(alignedConversionPrecondition(rewriter, dstVecType,
- containerType, truncOp)))
+ if (failed(alignedConversionPrecondition(
+ rewriter, dstVecType,
+ /*containerTy=*/rewriter.getI8Type(), truncOp)))
return failure();
// Create a new iX -> i8 truncation op.
From 3abe26ca6a3c7579190e91be807e1d07ecc2c234 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 11 Mar 2025 19:57:08 +0000
Subject: [PATCH 3/4] fixup! fixup! [mlir][Vector] Update
VectorEmulateNarrowType.cpp (4/N)
isFullyAligned -> isDivisibleInSize
---
.../Transforms/VectorEmulateNarrowType.cpp | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 649d73a0a460f..1b274095dc625 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto origElements = valueToStore.getType().getNumElements();
// Note, per-element-alignment was already verified above.
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isFullyAligned ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ isDivisibleInSize ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
if (!foldedNumFrontPadElems) {
return rewriter.notifyMatchFailure(
@@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
- !isFullyAligned || *foldedNumFrontPadElems != 0;
+ !isDivisibleInSize || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / emulatedPerContainerElem;
@@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto origElements = op.getVectorType().getNumElements();
// Note, per-element-alignment was already verified above.
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isFullyAligned ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ isDivisibleInSize ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
// Always load enough elements which can cover the original elements.
int64_t maxintraDataOffset =
@@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
- } else if (!isFullyAligned) {
+ } else if (!isDivisibleInSize) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
@@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
// Note, per-element-alignment was already verified above.
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isFullyAligned ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ isDivisibleInSize ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
- } else if (!isFullyAligned) {
+ } else if (!isDivisibleInSize) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
- } else if (!isFullyAligned) {
+ } else if (!isDivisibleInSize) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
- } else if (!isFullyAligned) {
+ } else if (!isDivisibleInSize) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
@@ -1159,7 +1159,7 @@ struct ConvertVectorTransferRead final
auto origElements = op.getVectorType().getNumElements();
// Note, per-element-alignment was already verified above.
- bool isFullyAligned =
+ bool isDivisibleInSize =
fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
@@ -1179,8 +1179,8 @@ struct ConvertVectorTransferRead final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isFullyAligned ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ isDivisibleInSize ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1204,7 +1204,7 @@ struct ConvertVectorTransferRead final
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
- } else if (!isFullyAligned) {
+ } else if (!isDivisibleInSize) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
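The renamed flag plays the same role in the load, masked-load, transfer-read and store patterns. As a hedged illustration, the sketch below models only the store-side decision visible in the `ConvertVectorStore` hunks, using plain integers. `requiresPartialStores` is a hypothetical name, and `constIntraDataOffset` stands in for `getConstantIntValue(linearizedInfo.intraDataOffset)`, with an empty optional modelling a non-constant offset. It requires C++17 for `constexpr` use of `std::optional`.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical model of the `ConvertVectorStore` decision logic.
// Returns std::nullopt where the pattern reports a match failure,
// otherwise whether the emulation requires partial stores.
// Assumes emulatedPerContainerElem > 0.
constexpr std::optional<bool>
requiresPartialStores(std::int64_t origElements,
                      std::int64_t emulatedPerContainerElem,
                      std::optional<std::int64_t> constIntraDataOffset) {
  // The flag renamed from `isFullyAligned` in this patch.
  bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

  // As in the patch: a size-divisible vector folds the front padding to 0,
  // otherwise the (possibly non-constant) intra-data offset is used.
  std::optional<std::int64_t> foldedNumFrontPadElems =
      isDivisibleInSize ? std::optional<std::int64_t>(0)
                        : constIntraDataOffset;
  if (!foldedNumFrontPadElems)
    return std::nullopt; // Match failure in the real pattern.

  return !isDivisibleInSize || *foldedNumFrontPadElems != 0;
}
```

For example, storing `vector<8xi4>` with two `i4` elements per `i8` (`emulatedPerContainerElem = 2`) needs no partial stores, whereas `vector<3xi4>` at a constant offset of 1 does.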
From fa0163977523841183b2d9167ed004035f11675a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 15 Mar 2025 19:22:03 +0000
Subject: [PATCH 4/4] fixup! fixup! fixup! [mlir][Vector] Update
VectorEmulateNarrowType.cpp (4/N)
Add minor missing renaming (otherwise some inconsistent names would remain)
---
.../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1b274095dc625..cf6efaa04ae44 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1505,10 +1505,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(op, "not a vector!");
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
- unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+ unsigned containerBits = containerTy.getIntOrFloatBitWidth();
// Enforced by the common pre-conditions.
- assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+ assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
// TODO: Add support other widths (when/if needed)
if (subByteBits != 2 && subByteBits != 4)
@@ -1516,7 +1516,7 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
// Condition 1 ("per-element" alignment)
- if (multiByteBits % subByteBits != 0)
+ if (containerBits % subByteBits != 0)
return rewriter.notifyMatchFailure(op, "unalagined element types");
// Condition 2 ("full" alignment)