[Mlir-commits] [mlir] [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) (PR #123526)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Feb 2 06:49:27 PST 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/123526

>From 2b7a6110fe51290e55f3356ce7563ff2d922066a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 18 Jan 2025 21:45:51 +0000
Subject: [PATCH] [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)

This is PR 1 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no
major functional changes are made.

This PR renames srcBits/dstBits and oldElementType/newElementType to
emulatedBits/containerBits and emulatedElemTy/containerElemTy,
respectively, to improve naming consistency within the file. This is
illustrated below:

```cpp
  // Extracted from VectorEmulateNarrowType.cpp

  // BEFORE (mixing old/new and src/dst):
  // Type oldElementType = op.getType().getElementType();
  // Type newElementType = convertedType.getElementType();

  // int srcBits = oldElementType.getIntOrFloatBitWidth();
  // int dstBits = newElementType.getIntOrFloatBitWidth();

  // AFTER (consistently using emulated/container):
  Type emulatedElemTy = op.getType().getElementType();
  Type containerElemTy = convertedType.getElementType();

  int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
  int containerBits = containerElemTy.getIntOrFloatBitWidth();
```

Also adds some comments and unifies the related "rewriter notification"
(`notifyMatchFailure`) messages.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 120 ++++++++++--------
 1 file changed, 66 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7ca88f1e0a0df9..63365cb5446124 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
+
     auto valueToStore = cast<VectorValue>(op.getValueToStore());
-    auto oldElementType = valueToStore.getType().getElementType();
-    auto newElementType =
+    auto containerElemTy =
         cast<MemRefType>(adaptor.getBase().getType()).getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = dstBits / srcBits;
+    int numSrcElemsPerDest = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       // Basic case: storing full bytes.
       auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
-          loc, VectorType::get(numElements, newElementType),
+          loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           op, bitCast.getResult(), memrefBase,
@@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getValueToStore().getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
 
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
     if (origElements % scale != 0)
       return failure();
@@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -706,7 +711,7 @@ struct ConvertVectorMaskedStore final
       return failure();
 
     auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
+    auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
 
@@ -714,7 +719,7 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         llvm::divideCeil(maxintraDataOffset + origElements, scale);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
-                           numElements, oldElementType, newElementType);
+                           numElements, emulatedElemTy, containerElemTy);
 
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
@@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final
 
     auto numElements =
         llvm::divideCeil(maxIntraDataOffset + origElements, scale);
-    auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto loadType = VectorType::get(numElements, containerElemTy);
+    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final
                                          "only 1-D vectors are supported ATM");
 
     auto loc = op.getLoc();
-    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
-    Type oldElementType = op.getType().getElementType();
-    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
-
-    if (dstBits % srcBits != 0) {
+    auto containerElemTy =
+        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
+    Type emulatedElemTy = op.getType().getElementType();
+    int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
+    int containerBits = containerElemTy.getIntOrFloatBitWidth();
+
+    // Check per-element alignment.
+    if (containerBits % emulatedBits != 0) {
       return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+          op, "impossible to pack emulated elements into container elements "
+              "(bit-wise misalignment)");
     }
-    int scale = dstBits / srcBits;
+    int scale = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
     bool isAlignedEmulation = origElements % scale == 0;
 
-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
 
     auto stridedMetadata =
@@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, emulatedBits, containerBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final
         llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
+        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {



More information about the Mlir-commits mailing list