[Mlir-commits] [mlir] [mlir][Vector] Refactor VectorEmulateNarrowType.cpp (PR #123529)

Andrzej Warzyński llvmlistbot at llvm.org
Sat Mar 15 12:22:32 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/123529

>From 696d7d2560dda52ac11422312ffde853363c8598 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 1/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)

This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no
major functional changes are made.

1. Update `alignedConversionPrecondition` (1):

This method didn't require the vector type for the "destination"
argument. The underlying element type is sufficient. The corresponding
argument has been renamed to `containerTy` - this is the multi-byte
scalar type used for emulation (`i8`, `i16`, `i32`, etc).

2. Update `alignedConversionPrecondition` (2):

In #121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns in which this hook was/is used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).
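
The arithmetic can be sketched outside of MLIR with plain integers. The
helper name `numSrcElemsPerContainerElem` and its parameters are
hypothetical (they do not appear in the patch); this is only an
illustration of why passing the container bit-width is equivalent to the
hard-coded 8 when the container happens to be `i8`:

```cpp
#include <cassert>

// Hypothetical helper (not part of the patch): number of sub-byte source
// elements that fit into one element of the multi-byte type.
inline int numSrcElemsPerContainerElem(int containerBits, int srcElemBits) {
  assert(srcElemBits > 0 && "element bit-width must be positive");
  assert(containerBits % srcElemBits == 0 && "unaligned element types");
  return containerBits / srcElemBits;
}
```

With `containerBits = 8` this reproduces `8 / srcElemBitwidth`; a wider
type such as `i16` only changes the first argument.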

3. Update `alignedConversionPrecondition` (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This helper is also re-used elsewhere in the file
(in `ConvertVectorTransferRead`) and should allow more code re-use going
forward, so that the same condition is not re-implemented.
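
As a rough, stand-alone sketch of the condition that the helper checks,
the logic can be written with plain integers instead of MLIR types. The
function name `fitsInContainer` and its parameters are assumptions made
for this illustration only; the real helper takes a `VectorType` and a
`Type`:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, MLIR-free sketch of the "fit" check.
//   trailingDim   - size of the trailing dimension of the sub-byte vector
//   subByteBits   - bit-width of the sub-byte element type (e.g. 4 for i4)
//   containerBits - bit-width of the multi-byte scalar type (e.g. 16 for i16)
inline bool fitsInContainer(int64_t trailingDim, int subByteBits,
                            int containerBits) {
  assert(subByteBits > 0 && subByteBits < 8 && "not a sub-byte type");
  assert(containerBits >= 8 && containerBits % 8 == 0 &&
         "not a multi-byte type");
  assert(containerBits % subByteBits == 0 && "unaligned element types");

  // Number of sub-byte elements packed into one multi-byte element.
  const int elemsPerContainer = containerBits / subByteBits;

  // The trailing dimension must be a whole number of such groups.
  return trailingDim % elemsPerContainer == 0;
}
```

For instance, `fitsInContainer(4, 4, 16)` and `fitsInContainer(4, 4, 8)`
hold, whereas `fitsInContainer(3, 4, 8)` and `fitsInContainer(3, 2, 16)`
do not.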

NEXT STEPS (1):

We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.

For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:

```mlir
  %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```

However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:

```mlir
  %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```

I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as
the destination types and that should be clarified.
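
To make the distinction concrete, the shape of the intermediate
`vector.bitcast` result follows from the total bit count alone, not from
the result type of the `arith.extsi` Op. A minimal sketch, assuming a 1-D
vector and using a hypothetical helper name (`bitcastNumElems` is not part
of the patch):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: number of elements of the bitcast result when a 1-D
// vector of `numElems` elements, each `elemBits` wide, is reinterpreted as
// a vector of `containerBits`-wide elements.
inline int64_t bitcastNumElems(int64_t numElems, int elemBits,
                               int containerBits) {
  const int64_t totalBits = numElems * elemBits;
  assert(totalBits % containerBits == 0 && "bitcast would not be valid");
  return totalBits / containerBits;
}
```

For `vector<8xi2>` and `i8` this gives 2, i.e. `vector<2xi8>`, which is
unrelated to the `vector<8xi32>` result of the extension itself.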

NEXT STEPS (2):

With this PR, I am introducing explicit references to "sub-byte" as
that is effectively what this logic is used for (i.e. for emulating
"sub-byte" types). We should either generalise (which would include
increasing test coverage) or restrict everything to "sub-byte" type
emulation.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 134 +++++++++++++-----
 1 file changed, 102 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 51e72753ff162..59ed3b5521470 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
   }
 };
 
+/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy`
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+///   * vector<4xi4> -> i16 - yes (N = 1)
+///   * vector<4xi4> -> i8 - yes (N = 2)
+///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
+///   * vector<3xi2> -> i16 - no (N would have to be 0.5)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+                                 Type multiByteScalarTy) {
+  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
+
+  int elemsPerMultiByte = multiByteBits / subByteBits;
+
+  // TODO: This is a bit too restrictive for vectors rank > 1.
+  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
     auto origElements = op.getVectorType().getNumElements();
 
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isFullyAligned =
+        isSubByteVecFittable(op.getVectorType(), containerElemTy);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1428,41 +1461,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-///   1. The `dstType` element type is a multiple of the
-///   `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-///   is not supported). Let this multiple be `N`.
-///   2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-///   multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-///   not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+///   1. The bit-width of `containerTy` is a multiple of the
+///      bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+///      this multiple is 4.
+///   2. The multiple from 1. above divides evenly the number of the (trailing)
+///      elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+///   `subByteVecTy = vector<2xi4>`, and
+///   `containerTy = i16`
+///
+/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+///   `subByteVecTy = vector<3xi4>`, and
+///   `containerTy = i16`
+///
+/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
+///
+/// EXAMPLE 3:
+///   `subByteVecTy = vector<3xi3>`, and
+///   `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
 ///
 /// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType subByteVecType,
-                                                   VectorType dstType,
+                                                   VectorType subByteVecTy,
+                                                   Type containerTy,
                                                    Operation *op) {
-  if (!subByteVecType || !dstType)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
-  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!subByteVecTy)
+    return rewriter.notifyMatchFailure(op, "not a vector!");
 
-  if (dstElemBitwidth < 8)
-    return rewriter.notifyMatchFailure(
-        op, "the bitwidth of dstType must be greater than or equal to 8");
-  if (dstElemBitwidth % srcElemBitwidth != 0)
-    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
-  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!containerTy.isIntOrFloat())
+    return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+  unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+  // Enforced by the common pre-conditions.
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+  // TODO: Remove this condition - the assert above (and
+  // commonConversionPrecondtion) takes care of that.
+  if (multiByteBits < 8)
+    return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+  // TODO: Add support other widths (when/if needed)
+  if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
-        op, "only src bitwidth of 2 or 4 is supported at this moment");
+        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
+
+  // Condition 1.
+  if (multiByteBits % subByteBits != 0)
+    return rewriter.notifyMatchFailure(op, "unalagined element types");
 
-  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
-  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+  // Condition 2.
+  if (!isSubByteVecFittable(subByteVecTy, containerTy))
     return rewriter.notifyMatchFailure(
-        op, "the trailing dimension of the input vector of sub-bytes must be a "
-            "multiple of 8 / <sub-byte-width>");
+        op, "not possible to fit this sub-byte vector type into a vector of "
+            "the given multi-byte type");
 
   return success();
 }
@@ -1899,8 +1967,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType,
+                                             containerType, conversionOp)))
       return failure();
 
     // Perform the rewrite.
@@ -1964,8 +2033,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
-    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
-                                             truncOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType,
+                                             containerType, truncOp)))
       return failure();
 
     // Create a new iX -> i8 truncation op.

>From aa17f5eb50c387638841af9df632544c13912d09 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 10 Mar 2025 16:21:12 +0000
Subject: [PATCH 2/4] fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp
 (4/N)

Address comments from Alan
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 37 ++++++++-----------
 1 file changed, 15 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 59ed3b5521470..649d73a0a460f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1106,8 +1106,8 @@ struct ConvertVectorMaskedLoad final
 ///   * vector<4xi4> -> i8 - yes (N = 2)
 ///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
 ///   * vector<3xi2> -> i16 - no (N would have to be 0.5)
-static bool isSubByteVecFittable(VectorType subByteVecTy,
-                                 Type multiByteScalarTy) {
+static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
+                                       Type multiByteScalarTy) {
   assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
 
   int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
@@ -1160,7 +1160,7 @@ struct ConvertVectorTransferRead final
 
     // Note, per-element-alignment was already verified above.
     bool isFullyAligned =
-        isSubByteVecFittable(op.getVectorType(), containerElemTy);
+        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1496,38 +1496,31 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                    VectorType subByteVecTy,
                                                    Type containerTy,
                                                    Operation *op) {
+  assert(containerTy.isIntOrFloat() &&
+         "container element type is not a scalar");
+
   // TODO: This is validating the inputs rather than checking the conditions
   // documented above. Replace with an assert.
   if (!subByteVecTy)
     return rewriter.notifyMatchFailure(op, "not a vector!");
 
-  // TODO: This is validating the inputs rather than checking the conditions
-  // documented above. Replace with an assert.
-  if (!containerTy.isIntOrFloat())
-    return rewriter.notifyMatchFailure(op, "not a scalar!");
-
   unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
   unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
 
   // Enforced by the common pre-conditions.
   assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
 
-  // TODO: Remove this condition - the assert above (and
-  // commonConversionPrecondtion) takes care of that.
-  if (multiByteBits < 8)
-    return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
-
   // TODO: Add support other widths (when/if needed)
   if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
         op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
 
-  // Condition 1.
+  // Condition 1 ("per-element" alignment)
   if (multiByteBits % subByteBits != 0)
     return rewriter.notifyMatchFailure(op, "unalagined element types");
 
-  // Condition 2.
-  if (!isSubByteVecFittable(subByteVecTy, containerTy))
+  // Condition 2 ("full" alignment)
+  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
     return rewriter.notifyMatchFailure(
         op, "not possible to fit this sub-byte vector type into a vector of "
             "the given multi-byte type");
@@ -1967,9 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Check general alignment preconditions.
-    Type containerType = rewriter.getI8Type();
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType,
-                                             containerType, conversionOp)))
+    if (failed(alignedConversionPrecondition(
+            rewriter, srcVecType,
+            /*containerTy=*/rewriter.getI8Type(), conversionOp)))
       return failure();
 
     // Perform the rewrite.
@@ -2033,9 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
-    Type containerType = rewriter.getI8Type();
-    if (failed(alignedConversionPrecondition(rewriter, dstVecType,
-                                             containerType, truncOp)))
+    if (failed(alignedConversionPrecondition(
+            rewriter, dstVecType,
+            /*containerTy=*/rewriter.getI8Type(), truncOp)))
       return failure();
 
     // Create a new iX -> i8 truncation op.

>From 3abe26ca6a3c7579190e91be807e1d07ecc2c234 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 11 Mar 2025 19:57:08 +0000
Subject: [PATCH 3/4] fixup! fixup! [mlir][Vector] Update
 VectorEmulateNarrowType.cpp (4/N)

isFullyAligned -> isDivisibleInSize
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 649d73a0a460f..1b274095dc625 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isFullyAligned || *foldedNumFrontPadElems != 0;
+        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
@@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     auto origElements = op.getVectorType().getNumElements();
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1159,7 +1159,7 @@ struct ConvertVectorTransferRead final
     auto origElements = op.getVectorType().getNumElements();
 
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned =
+    bool isDivisibleInSize =
         fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
@@ -1179,8 +1179,8 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1204,7 +1204,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }

>From fa0163977523841183b2d9167ed004035f11675a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 15 Mar 2025 19:22:03 +0000
Subject: [PATCH 4/4] fixup! fixup! fixup! [mlir][Vector] Update
 VectorEmulateNarrowType.cpp (4/N)

Add minor missing re-naming (otherwise there are inconsistent names remaining)
---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp   | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1b274095dc625..cf6efaa04ae44 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1505,10 +1505,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(op, "not a vector!");
 
   unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
-  unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+  unsigned containerBits = containerTy.getIntOrFloatBitWidth();
 
   // Enforced by the common pre-conditions.
-  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
 
   // TODO: Add support other widths (when/if needed)
   if (subByteBits != 2 && subByteBits != 4)
@@ -1516,7 +1516,7 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
         op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
 
   // Condition 1 ("per-element" alignment)
-  if (multiByteBits % subByteBits != 0)
+  if (containerBits % subByteBits != 0)
     return rewriter.notifyMatchFailure(op, "unalagined element types");
 
   // Condition 2 ("full" alignment)
