[Mlir-commits] [mlir] [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N) (PR #123527)

Andrzej WarzyƄski llvmlistbot at llvm.org
Tue Feb 4 10:02:24 PST 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/123527

From aaeb0fb646105b5af5b9d1841a49120c39f03b9d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 2 Feb 2025 15:36:33 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)

This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no
major functional changes are made.

This PR renames the variable "scale". Note that "scale" could mean either:

  * "original-elements-per-emulated-type", or
  * "emulated-elements-per-original-type".

While it is clear from the context that it is always the former (the
original type is always a sub-byte type and the emulated type is usually
`i8`), this PR reduces the cognitive load by making that explicit.
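To make the ambiguity concrete, here is a minimal standalone sketch (not
the actual MLIR code; the `i4`-via-`i8` combination and the variable names
are illustrative, following the naming used in the patch):

```cpp
// Illustrates the two possible readings of "scale" for i4 emulated via i8.
#include <cassert>

int main() {
  int emulatedBits = 4;  // width of the original, sub-byte element type (i4)
  int containerBits = 8; // width of the container element type (i8)

  // Reading 1: how many original (sub-byte) elements fit into one container
  // element. This is what the code always means, hence the new name:
  int emulatedPerContainerElem = containerBits / emulatedBits; // == 2

  // Reading 2: how many container elements per original element. For
  // sub-byte types this integer ratio truncates to 0, which is clearly
  // never the intended meaning:
  int containerPerEmulatedElem = emulatedBits / containerBits; // == 0

  assert(emulatedPerContainerElem == 2 && containerPerEmulatedElem == 0);
  return 0;
}
```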

**DEPENDS ON:**
* #123526

Please only review the [top
commit](https://github.com/llvm/llvm-project/pull/123527/commits/d40b31bb098e874be488182050c68b887e8d091a).

**GitHub issue to track this work**:
https://github.com/llvm/llvm-project/issues/123630
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 78 +++++++++++--------
 1 file changed, 45 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0d310dc8be2fe9..831c1ab736105a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -290,13 +290,15 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                       int64_t numContainerElemsToLoad,
                                       Type emulatedElemTy,
                                       Type containerElemTy) {
-  auto scale = containerElemTy.getIntOrFloatBitWidth() /
-               emulatedElemTy.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
+                                  emulatedElemTy.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numContainerElemsToLoad * scale, emulatedElemTy),
+      loc,
+      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
+                      emulatedElemTy),
       newLoad);
 }
 
@@ -388,10 +390,11 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
       "sliceNumElements * vector element size must be less than or equal to 8");
   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
          "vector element must be a valid sub-byte type");
-  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
   auto emptyByteVector = rewriter.create<arith::ConstantOp>(
-      loc, VectorType::get({scale}, vectorElementType),
-      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
+      rewriter.getZeroAttr(
+          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                               extractOffset, sliceNumElements);
   return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
@@ -656,9 +659,9 @@ struct ConvertVectorMaskedStore final
               "(bit-wise misalignment)");
     }
 
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % emulatedPerContainerElem != 0)
       return failure();
 
     auto stridedMetadata =
@@ -707,12 +710,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + emulatedPerContainerElem - 1) /
+                       emulatedPerContainerElem;
     auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -721,7 +725,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitCastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -765,7 +770,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +802,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +823,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, emulatedElemTy, containerElemTy);
@@ -870,7 +876,7 @@ struct ConvertVectorMaskedLoad final
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -916,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -935,18 +941,21 @@ struct ConvertVectorMaskedLoad final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            emulatedPerContainerElem, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     auto loadType = VectorType::get(numElements, containerElemTy);
-    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitcastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -973,8 +982,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * emulatedPerContainerElem, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -1033,11 +1042,11 @@ struct ConvertVectorTransferRead final
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1060,9 +1069,10 @@ struct ConvertVectorTransferRead final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
@@ -1070,7 +1080,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
+        loc,
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {

From 95f8ad113145083846177a599f3d1e4b6fcaeab1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 2/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)

This is PR 3 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no
major functional changes are made.

1. Replaces `isUnalignedEmulation` with `isFullyAligned`

Note that `isUnalignedEmulation` is always computed after the following
"per-element-alignment" check:
```cpp
// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
  return rewriter.notifyMatchFailure(
    op, "impossible to pack emulated elements into container elements "
    "(bit-wise misalignment)");
}

// (...)

bool isUnalignedEmulation = origElements % emulatedPerContainerElem != 0;
```

Given that `isUnalignedEmulation` captures only one of the two conditions
required for "full alignment", a more accurate name would be
`isPartiallyUnalignedEmulation`. Instead, I've flipped the condition and
renamed the variable `isFullyAligned`:

```cpp
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
```
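To see why both conditions matter, here is a minimal standalone sketch
(the `i2`-via-`i8` combination is a hypothetical example, not taken from
the patch):

```cpp
// Illustrates the two alignment conditions for i2 emulated via i8.
#include <cassert>

int main() {
  int containerBits = 8; // container element type, e.g. i8
  int emulatedBits = 2;  // emulated element type, e.g. i2

  // Condition 1: per-element alignment. The container width must be a
  // multiple of the emulated width (the pattern bails out otherwise).
  assert(containerBits % emulatedBits == 0);
  int emulatedPerContainerElem = containerBits / emulatedBits; // == 4

  // Condition 2: the total element count must also be a multiple. A
  // vector<8xi2> is fully aligned, while a vector<6xi2> satisfies only
  // the per-element condition.
  assert(8 % emulatedPerContainerElem == 0); // fully aligned
  assert(6 % emulatedPerContainerElem != 0); // not fully aligned
  return 0;
}
```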

2. In addition:
  * Unifies various comments throughout the file (for consistency).
  * Adds new comments and TODOs where high-level comments are missing.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 111 ++++++++++--------
 1 file changed, 64 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 831c1ab736105a..28ccbfbb6962e9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -48,6 +48,10 @@ using namespace mlir;
 using VectorValue = TypedValue<VectorType>;
 using MemRefValue = TypedValue<MemRefType>;
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -407,6 +411,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -632,6 +637,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -745,6 +751,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -802,7 +809,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +826,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -834,10 +841,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
-      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
-                                           linearizedInfo.intraDataOffset,
-                                           origElements);
-    } else if (!isAlignedEmulation) {
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -850,6 +857,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1016,6 +1024,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1046,7 +1055,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1065,9 +1075,8 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1091,7 +1100,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1774,33 +1783,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-///      is rewriten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.shli %0, 4 : vector<4xi8>
-///        %2 = arith.shrsi %1, 4 : vector<4xi8>
-///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-///      is rewriten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.shli %0, 4 : vector<4xi8>
-///        %2 = arith.shrsi %1, 4 : vector<4xi8>
-///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-///      is rewritten as
-///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///        %1 = arith.andi %0, 15 : vector<4xi8>
-///        %2 = arith.shrui %0, 4 : vector<4xi8>
-///        %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1810,8 +1820,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1851,15 +1861,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///    arith.trunci %in : vector<8xi32> to vector<8xi4>
-///      is rewriten as
 ///
-///        %cst = arith.constant dense<15> : vector<4xi8>
-///        %cst_0 = arith.constant dense<4> : vector<4xi8>
-///        %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///        %2 = arith.andi %0, %cst : vector<4xi8>
-///        %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///        %4 = arith.ori %2, %3 : vector<4xi8>
-///        %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewritten as:
+///
+///   %cst = arith.constant dense<15> : vector<4xi8>
+///   %cst_0 = arith.constant dense<4> : vector<4xi8>
+///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///   %2 = arith.andi %0, %cst : vector<4xi8>
+///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///   %4 = arith.ori %2, %3 : vector<4xi8>
+///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
@@ -1903,10 +1914,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-///   is rewritten as:
+/// is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1954,6 +1966,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1966,22 +1979,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
            RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
           patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);


