[Mlir-commits] [mlir] andrzej/refactor narrow type 2 (PR #123527)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Jan 19 12:50:19 PST 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/123527

- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)**
- **[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)**


>From 5edc3423c1da0c62c8a7bb5fa3e9d54855bdf3bf Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:45:51 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)

This is PR 1 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no
major functional changes are made or added.

This PR renames `srcBits/dstBits` to `oldBits/newBits` to improve
consistency in naming within the file. This is illustrated below:

```cpp
  // Extracted from VectorEmulateNarrowType.cpp
  Type oldElementType = op.getType().getElementType();
  Type newElementType = convertedType.getElementType();

  // BEFORE (mixing old/new and src/dst):
  // int srcBits = oldElementType.getIntOrFloatBitWidth();
  // int dstBits = newElementType.getIntOrFloatBitWidth();

  // AFTER (consistently using old/new):
  int oldBits = oldElementType.getIntOrFloatBitWidth();
  int newBits = newElementType.getIntOrFloatBitWidth();
```

Also adds some comments and unifies related "rewriter notification"
messages.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 70 +++++++++----------
 1 file changed, 35 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..70d50e1d48040c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -314,14 +314,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -385,15 +385,15 @@ struct ConvertVectorMaskedStore final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
 
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
     if (origElements % scale != 0)
       return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -493,14 +493,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -596,14 +596,14 @@ struct ConvertVectorMaskedLoad final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -758,14 +758,14 @@ struct ConvertVectorTransferRead final
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),

>From d40b31bb098e874be488182050c68b887e8d091a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:59:34 +0000
Subject: [PATCH 2/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)

This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no
major functional changes are made or added.

This PR renames the variable "scale". Note, "scale" could mean either:

  * "original-elements-per-emulated-type", or
  * "emulated-elements-per-original-type".

While from the context it is clear that it's always the former (original
type is always a sub-byte type and the emulated type is usually `i8`),
this PR reduces the cognitive load by making this clear.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 77 +++++++++++--------
 1 file changed, 44 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 70d50e1d48040c..4e0be258954496 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    OpFoldResult linearizedIndices,
                    int64_t numEmultedElementsToLoad, Type origElemType,
                    Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
       newLoad);
 }
 
@@ -321,7 +323,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +339,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -352,7 +354,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements, newElementType),
         op.getValueToStore());
@@ -393,9 +395,9 @@ struct ConvertVectorMaskedStore final
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
 
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -444,12 +446,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
     auto newType = VectorType::get(numElements, newElementType);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +461,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -500,7 +504,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +536,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -553,9 +557,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : 0;
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
@@ -603,7 +608,7 @@ struct ConvertVectorMaskedLoad final
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +654,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -668,18 +673,21 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
     auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +714,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -765,11 +773,11 @@ struct ConvertVectorTransferRead final
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -792,9 +800,10 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +811,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc,
+        VectorType::get(numElements * elementsPerContainerType, oldElementType),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {



More information about the Mlir-commits mailing list