[Mlir-commits] [mlir] [mlir][Vector] Update VectorEmulateNarrowType.cpp (PR #123633)

Andrzej WarzyƄski llvmlistbot at llvm.org
Mon Jan 20 07:38:48 PST 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/123633

This PR aims at improving "VectorEmulateNarrowType.cpp". It is mainly
minor refactoring; no major functional changes are made.
Implements #123630.

**CHANGE 1**
Renames `srcBits/dstBits` to `oldBits/newBits` to improve consistency in
naming within the file. This is illustrated below:

```cpp
  // Extracted from VectorEmulateNarrowType.cpp
  Type oldElementType = op.getType().getElementType();
  Type newElementType = convertedType.getElementType();

  // BEFORE (mixing old/new and src/dst):
  // int srcBits = oldElementType.getIntOrFloatBitWidth();
  // int dstBits = newElementType.getIntOrFloatBitWidth();

  // AFTER (consistently using old/new):
  int oldBits = oldElementType.getIntOrFloatBitWidth();
  int newBits = newElementType.getIntOrFloatBitWidth();
```

Also adds some comments and unifies related "rewriter notification"
messages.
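
For illustration, here is the unified rewriter notification for the
per-element alignment check, together with the message it replaces (both
taken from the diff below):

```cpp
// BEFORE (slightly different wording in each conversion pattern):
//   return rewriter.notifyMatchFailure(
//       op, "only dstBits % srcBits == 0 supported");

// AFTER (the same message in every conversion pattern):
return rewriter.notifyMatchFailure(op, "unaligned element types");
```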

**CHANGE 2**
Renames the variable "scale". Note that "scale" could mean either:

  * "original-elements-per-emulated-type", or
  * "emulated-elements-per-original-type".

While it is clear from the context that it is always the former (the
original type is always a sub-byte type and the emulated type is usually
`i8`), this PR reduces the cognitive load by renaming the variable to
`elementsPerContainerType`, which makes this explicit.
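
For example, in `ConvertVectorStore` the rename looks as follows (excerpt
adapted from the diff below):

```cpp
// BEFORE:
//   int scale = dstBits / srcBits;

// AFTER:
int elementsPerContainerType = newBits / oldBits;
```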

**CHANGE 3**
Replaces `isUnalignedEmulation` with `isFullyAligned`.

Note that `isUnalignedEmulation` is only ever computed after the
"per-element-alignment" check has already succeeded:
```cpp
// Check per-element alignment.
if (newBits % oldBits != 0) {
  return rewriter.notifyMatchFailure(op, "unaligned element types");
}

// (...)

bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
```

Given that `isUnalignedEmulation` captures only one of the two conditions
required for "full alignment", a more accurate name would be
`isPartiallyUnalignedEmulation`. Instead, I have flipped the condition and
renamed the variable to `isFullyAligned`:

```cpp
bool isFullyAligned = origElements % elementsPerContainerType == 0;
```

**CHANGE 4**
Unifies various comments throughout the file (for consistency).

**CHANGE 5**
Adds new comments throughout the file and adds TODOs where high-level
comments are missing.

**CHANGE 6**
Update `alignedConversionPrecondition` (1):

This method never actually needed the full vector type for the
"destination" argument - the underlying element type is sufficient. The
corresponding argument has been renamed to `containerTy`; it is meant to be
the multi-byte emulated type (`i8`, `i16`, `i32`, etc.).
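
For reference, here is the updated signature (copied from the diff below):

```cpp
// BEFORE:
//   static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
//                                                      VectorType subByteVecType,
//                                                      VectorType dstType,
//                                                      Operation *op);

// AFTER (the "destination" is now just a scalar container type):
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op);
```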

**CHANGE 7**
Update `alignedConversionPrecondition` (2):

In #121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was (and still is) used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).
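
For reference, this is how the container (emulated) type is now passed at
the call site in `RewriteAlignedSubByteIntExt` (excerpt from the diff
below):

```cpp
// Check general alignment preconditions.
Type containerType = rewriter.getI8Type();
if (failed(alignedConversionPrecondition(rewriter, srcVecType,
                                         containerType, conversionOp)))
  return failure();
```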

**CHANGE 8**
Update `alignedConversionPrecondition` (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used elsewhere in the
file and should enable more code re-use going forward (i.e., avoid
re-implementing the same condition in multiple places).
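
For reference, here is the new helper, condensed from the diff below
(asserts and most comments omitted):

```cpp
/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`,
/// i.e. whether `vector.bitcast` from `subByteVecTy` into
/// `vector<N x multiByteScalarTy>` would be valid for some integer N.
static bool isSubByteVecFittable(VectorType subByteVecTy,
                                 Type multiByteScalarTy) {
  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();

  int elemsPerMultiByte = multiByteBits / subByteBits;

  // TODO: This is a bit too restrictive for vectors of rank > 1.
  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
```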


From 83411711cd1c42d1cfcff4bf17869235009dfceb Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:45:51 +0000
Subject: [PATCH] [mlir][Vector] Update VectorEmulateNarrowType.cpp

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 370 +++++++++++-------
 1 file changed, 234 insertions(+), 136 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..373b8a8822318f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -282,13 +286,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    OpFoldResult linearizedIndices,
                    int64_t numEmultedElementsToLoad, Type origElemType,
                    Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
       newLoad);
 }
 
@@ -298,6 +304,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -314,14 +321,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +344,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -346,13 +353,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements, newElementType),
         op.getValueToStore());
@@ -368,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -385,17 +393,17 @@ struct ConvertVectorMaskedStore final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
 
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -404,7 +412,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -444,12 +452,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
     auto newType = VectorType::get(numElements, newElementType);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +467,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -477,6 +487,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -493,14 +504,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -541,21 +553,21 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
@@ -566,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -580,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -596,14 +609,14 @@ struct ConvertVectorMaskedLoad final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +662,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -657,7 +670,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -668,18 +681,21 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
     auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +722,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -737,10 +753,43 @@ struct ConvertVectorMaskedLoad final
   }
 };
 
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+///   * vector<4xi4> -> i16 - yes (N = 1)
+///   * vector<4xi4> -> i8 - yes (N = 2)
+///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
+///   * vector<3xi2> -> i16 - no (N would have to be 0.375)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+                                 Type multiByteScalarTy) {
+  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
+
+  int elemsPerMultiByte = multiByteBits / subByteBits;
+
+  // TODO: This is a bit too restrictive for vectors of rank > 1.
+  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -758,18 +807,20 @@ struct ConvertVectorTransferRead final
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned =
+        isSubByteVecFittable(op.getVectorType(), newElementType);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -781,20 +832,20 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +853,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc,
+        VectorType::get(numElements * elementsPerContainerType, oldElementType),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
@@ -811,7 +864,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -1069,41 +1122,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-///   1. The `dstType` element type is a multiple of the
-///   `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-///   is not supported). Let this multiple be `N`.
-///   2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-///   multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-///   not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+///   1. The bit-width of `containerTy` is a multiple of the
+///      bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+///      this multiple is 4.
+///   2. The multiple from 1. above divides evenly the number of the (trailing)
+///      elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+///   `subByteVecTy = vector<8xi4>`, and
+///   `containerTy = i16`
+///
+/// 4 (= 16 / 4) divides evenly 8, hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+///   `subByteVecTy = vector<3xi4>`, and
+///   `containerTy = i16`
+///
+/// 4 (= 16 / 4) _does not_ divide evenly 3, hence condition 2 is _not met_.
+///
+/// EXAMPLE 3:
+///   `subByteVecTy = vector<3xi3>`, and
+///   `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
 ///
 /// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType subByteVecType,
-                                                   VectorType dstType,
+                                                   VectorType subByteVecTy,
+                                                   Type containerTy,
                                                    Operation *op) {
-  if (!subByteVecType || !dstType)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
-  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!subByteVecTy)
+    return rewriter.notifyMatchFailure(op, "not a vector!");
 
-  if (dstElemBitwidth < 8)
-    return rewriter.notifyMatchFailure(
-        op, "the bitwidth of dstType must be greater than or equal to 8");
-  if (dstElemBitwidth % srcElemBitwidth != 0)
-    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
-  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!containerTy.isIntOrFloat())
+    return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+  unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+  // Enforced by the common pre-conditions.
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+  // TODO: Remove this condition - the assert above (and
+  // commonConversionPrecondition) takes care of that.
+  if (multiByteBits < 8)
+    return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+  // TODO: Add support for other widths (when/if needed)
+  if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
-        op, "only src bitwidth of 2 or 4 is supported at this moment");
+        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
+
+  // Condition 1.
+  if (multiByteBits % subByteBits != 0)
+    return rewriter.notifyMatchFailure(op, "unaligned element types");
 
-  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
-  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+  // Condition 2.
+  if (!isSubByteVecFittable(subByteVecTy, containerTy))
     return rewriter.notifyMatchFailure(
-        op, "the trailing dimension of the input vector of sub-bytes must be a "
-            "multiple of 8 / <sub-byte-width>");
+        op, "not possible to fit this sub-byte vector type into a vector of "
+            "the given multi-byte type");
 
   return success();
 }
@@ -1495,33 +1583,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-///      is rewriten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.shli %0, 4 : vector<4xi8>
-///        %2 = arith.shrsi %1, 4 : vector<4xi8>
-///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-///      is rewriten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.shli %0, 4 : vector<4xi8>
-///        %2 = arith.shrsi %1, 4 : vector<4xi8>
-///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-///      is rewritten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.andi %0, 15 : vector<4xi8>
-///        %2 = arith.shrui %0, 4 : vector<4xi8>
-///        %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1531,16 +1620,17 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
 
     // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType,
+                                             containerType, conversionOp)))
       return failure();
 
     // Perform the rewrite.
@@ -1572,15 +1662,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///    arith.trunci %in : vector<8xi32> to vector<8xi4>
-///      is rewriten as
 ///
-///        %cst = arith.constant dense<15> : vector<4xi8>
-///        %cst_0 = arith.constant dense<4> : vector<4xi8>
-///        %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///        %2 = arith.andi %0, %cst : vector<4xi8>
-///        %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///        %4 = arith.ori %2, %3 : vector<4xi8>
-///        %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewritten as:
+///
+///   %cst = arith.constant dense<15> : vector<4xi8>
+///   %cst_0 = arith.constant dense<4> : vector<4xi8>
+///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///   %2 = arith.andi %0, %cst : vector<4xi8>
+///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///   %4 = arith.ori %2, %3 : vector<4xi8>
+///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1603,8 +1694,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
-    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
-                                             truncOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType,
+                                             containerType, truncOp)))
       return failure();
 
     // Create a new iX -> i8 truncation op.
@@ -1624,10 +1716,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-///   is rewritten as:
+/// is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1675,6 +1768,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1687,22 +1781,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
           patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);


