[Mlir-commits] [mlir] [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (PR #113411)

llvmlistbot at llvm.org
Tue Oct 29 18:14:25 PDT 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/113411

From cfc1e1d8ad20bcb7d549e5b479f8eb994d9d16ea Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sun, 20 Oct 2024 14:54:57 +0000
Subject: [PATCH 01/10] [MLIR] VectorEmulateNarrowType to support unaligned cases

Previously the pass only supported emulation of vectors whose bit width
is a multiple of that of the emulated data type (i8). This patch extends
the emulation to vectors whose size is not a multiple of a byte, such as
`vector<3xi2>`.

A limitation of this patch is that the linearized index of the unaligned
vector has to be known at compile time. Handling a dynamic index would
require emitting extra code, which is not implemented yet.

The following ops are updated (a sketch of the resulting lowering is
shown after this list):
* `vector::LoadOp`
* `vector::StoreOp`
* `vector::TransferReadOp`
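
For example, loading `vector<3xi2>` from `memref<3x3xi2>` at indices
[2, 0] linearizes to element index 2 * 3 + 0 = 6; with four i2 elements
per i8 byte this becomes byte index 6 / 4 = 1 with an intra-byte front
offset of 6 % 4 = 2 elements. A minimal sketch of the emitted IR for
this case (SSA names are illustrative; see the new test file for the
exact output):

  // Load enough whole bytes to cover the requested elements, bitcast
  // back to the narrow type, then slice out the 3 elements starting
  // at the intra-byte offset of 2.
  %bytes = vector.load %alloc[%c1] : memref<3xi8>, vector<2xi8>
  %bits = vector.bitcast %bytes : vector<2xi8> to vector<8xi2>
  %res = vector.extract_strided_slice %bits
           {offsets = [2], sizes = [3], strides = [1]}
           : vector<8xi2> to vector<3xi2>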
---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   |   9 +-
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp |   9 +-
 .../Transforms/VectorEmulateNarrowType.cpp    | 162 +++++++++++++++---
 .../vector-emulate-narrow-type-unaligned.mlir |  55 ++++++
 4 files changed, 204 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index ca3326dbbef519..db32543162b781 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -32,7 +32,8 @@ namespace memref {
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
 /// For a `memref` with `offset`, `sizes` and `strides`, returns the
-/// offset and size to use for the linearized `memref`.
+/// offset, size, and potentially the size padded at the front to use for the
+/// linearized `memref`.
 /// - If the linearization is done for emulating load/stores of
 ///   element type with bitwidth `srcBits` using element type with
 ///   bitwidth `dstBits`, the linearized offset and size are
@@ -42,9 +43,15 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 ///   index to use in the linearized `memref`. The linearized index
 ///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
 ///   0, is returned for the linearized index.
+/// - If the size of the load/store is smaller than the linearized memref
+/// load/store,
+///   the memory region emulated is larger than the actual memory region needed.
+///   `frontPaddingSize` returns the size of the irrelevant offset at the
+///   beginning.
 struct LinearizedMemRefInfo {
   OpFoldResult linearizedOffset;
   OpFoldResult linearizedSize;
+  OpFoldResult frontPaddingSize;
 };
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 7321b19068016c..69724bec248827 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -81,11 +81,10 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
 
   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
-      builder, loc, addMulMap, offsetValues);
+      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
   OpFoldResult linearizedSize =
       affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
@@ -95,7 +94,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
       builder, loc, s0.floorDiv(scaler), {offset});
 
-  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+  OpFoldResult frontPaddingSize = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap % scaler, offsetValues);
+
+  return {{adjustBaseOffset, linearizedSize, frontPaddingSize},
+          linearizedIndices};
 }
 
 LinearizedMemRefInfo
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 66362d3ca70fb6..42a9a2ab12196a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <optional>
 
 using namespace mlir;
 
@@ -102,6 +103,23 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
+///
+static std::optional<int64_t>
+getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
+                    const memref::LinearizedMemRefInfo linearizedInfo,
+                    bool isUnalignedEmulation) {
+  if (!isUnalignedEmulation)
+    return 0;
+  auto foldedFrontPaddingSize = getValueOrCreateConstantIndexOp(
+      rewriter, loc, linearizedInfo.frontPaddingSize);
+  // try to fold the front padding size into a constant
+  if (auto frontPadding = dyn_cast_or_null<arith::ConstantIndexOp>(
+          foldedFrontPaddingSize.getDefiningOp())) {
+    return frontPadding.value();
+  }
+  return std::nullopt;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -142,14 +160,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    // if the size of vector we are loading is not byte-aligned, extra handling
+    // is needed
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -157,14 +178,48 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto newVectorType = VectorType::get(numElements, newElementType);
+
+    if (isUnalignedEmulation) {
+      auto insertedVectorType =
+          VectorType::get(numElements * scale, oldElementType);
+
+      auto linearizedIndicesValue =
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+      auto passThru =
+          rewriter.create<vector::LoadOp>(loc, newVectorType, adaptor.getBase(),
+                                          ValueRange{linearizedIndicesValue});
+      auto bitcastedPassThru =
+          rewriter.create<vector::BitCastOp>(loc, insertedVectorType, passThru);
+
+      // just extract it and use it for the strided slice offset
+      auto insertStridedSlice = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, insertedVectorType, op.getValueToStore(), bitcastedPassThru,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({1}));
+      // bit cast the vector to the original type
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
+                                                        insertStridedSlice);
+
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(), linearizedIndicesValue);
+    } else {
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
+                                                        op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    }
     return success();
   }
 };
@@ -294,19 +349,31 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
     // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
     //
-    // TODO: Currently, only the even number of elements loading is supported.
-    // To deal with the odd number of elements, one has to extract the
-    // subvector at the proper offset after bit-casting.
+    // There are cases where the number of elements to load is not byte-aligned,
+    // for example:
+    //
+    // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
+    //
+    // we will have to load extra bytes and extract the exact slice in between.
+    //
+    // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
+    // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
+    // %3 = vector.extract_strided_slice %1 {offsets = [2], sizes = [3], strides
+    // = [1]}
+    //        : vector<8xi2> to vector<3xi2>
+    //
+    // TODO: Currently the extract_strided_slice's attributes must be known at
+    // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -314,15 +381,35 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto loadVectorType = VectorType::get(numElements, newElementType);
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
+        loc, loadVectorType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
-
-    rewriter.replaceOp(op, bitCast->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
+
+    if (newBitCastType.getNumElements() != origElements) {
+      auto extractStridedSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+      rewriter.replaceOp(op, extractStridedSlice.getResult());
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }
     return success();
   }
 };
@@ -464,8 +551,8 @@ struct ConvertVectorTransferRead final
     int scale = dstBits / srcBits;
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -474,7 +561,8 @@ struct ConvertVectorTransferRead final
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -482,7 +570,16 @@ struct ConvertVectorTransferRead final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
     auto newReadType = VectorType::get(numElements, newElementType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -490,10 +587,21 @@ struct ConvertVectorTransferRead final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
+    auto bitCastType = VectorType::get(numElements * scale, oldElementType);
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+        rewriter.create<vector::BitCastOp>(loc, bitCastType, newRead);
+
+    if (isUnalignedEmulation) {
+      // we only extract a portion of the vector.
+      rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
+          op, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }
 
-    rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
new file mode 100644
index 00000000000000..eebd7c74f44766
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+//-----
+
+func.func @vector_store_i2(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// CHECK: func @vector_store_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %arg0, %[[BITCAST1]] {offsets = [2], strides = [1]} : vector<3xi2> into vector<8xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: vector.store %[[BITCAST2]], %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8> 
+
+//-----
+
+func.func @vector_transfer_read_i2() -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>

From a855adf6b03ef178c9495733d5d9cb123e7c2142 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 23 Oct 2024 15:49:41 +0000
Subject: [PATCH 02/10] Remove StoreOp

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 57 ++++---------------
 .../vector-emulate-narrow-type-unaligned.mlir | 19 -------
 2 files changed, 10 insertions(+), 66 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 42a9a2ab12196a..096f30648f49e7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -160,17 +160,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-
-    // if the size of vector we are loading is not byte-aligned, extra handling
-    // is needed
-    bool isUnalignedEmulation = origElements % scale != 0;
+    if (origElements % scale != 0)
+      return failure();
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    memref::LinearizedMemRefInfo linearizedInfo;
-    std::tie(linearizedInfo, linearizedIndices) =
+    std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -178,48 +175,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedFrontPaddingSize = getFrontPaddingSize(
-        rewriter, loc, linearizedInfo, isUnalignedEmulation);
-
-    if (!foldedFrontPaddingSize) {
-      // unimplemented case for dynamic front padding size
-      return failure();
-    }
-
-    auto numElements =
-        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
-    auto newVectorType = VectorType::get(numElements, newElementType);
-
-    if (isUnalignedEmulation) {
-      auto insertedVectorType =
-          VectorType::get(numElements * scale, oldElementType);
-
-      auto linearizedIndicesValue =
-          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
-      auto passThru =
-          rewriter.create<vector::LoadOp>(loc, newVectorType, adaptor.getBase(),
-                                          ValueRange{linearizedIndicesValue});
-      auto bitcastedPassThru =
-          rewriter.create<vector::BitCastOp>(loc, insertedVectorType, passThru);
-
-      // just extract it and use it for the strided slice offset
-      auto insertStridedSlice = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, insertedVectorType, op.getValueToStore(), bitcastedPassThru,
-          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
-          rewriter.getI64ArrayAttr({1}));
-      // bit cast the vector to the original type
-      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
-                                                        insertStridedSlice);
+    auto numElements = origElements / scale;
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements, newElementType),
+        op.getValueToStore());
 
-      rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          op, bitCast.getResult(), adaptor.getBase(), linearizedIndicesValue);
-    } else {
-      auto bitCast = rewriter.create<vector::BitCastOp>(loc, newVectorType,
-                                                        op.getValueToStore());
-      rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          op, bitCast.getResult(), adaptor.getBase(),
-          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-    }
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        op, bitCast.getResult(), adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index eebd7c74f44766..329ab2164c9b5c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -19,25 +19,6 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
 
 //-----
 
-func.func @vector_store_i2(%arg0: vector<3xi2>) {
-    %0 = memref.alloc() : memref<3x3xi2>
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
-    return
-}
-
-// CHECK: func @vector_store_i2
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
-// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
-// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
-// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<8xi2>
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %arg0, %[[BITCAST1]] {offsets = [2], strides = [1]} : vector<3xi2> into vector<8xi2>
-// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
-// CHECK: vector.store %[[BITCAST2]], %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8> 
-
-//-----
-
 func.func @vector_transfer_read_i2() -> vector<3xi2> {
  %0 = memref.alloc() : memref<3x3xi2>
  %c0i2 = arith.constant 0 : i2

From b00a45a960c38bb397481022b1a809b366406cdd Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Thu, 24 Oct 2024 00:01:30 +0000
Subject: [PATCH 03/10] update and refactor

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 64 ++++++++++---------
 1 file changed, 33 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 096f30648f49e7..1b868ea9a8c705 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -103,7 +103,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
-///
 static std::optional<int64_t>
 getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
                     const memref::LinearizedMemRefInfo linearizedInfo,
@@ -120,6 +119,17 @@ getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
   return std::nullopt;
 }
 
+static OpResult extractSubvector(ConversionPatternRewriter &rewriter,
+                                 Location loc, VectorType extractType,
+                                 Value vector, int64_t frontOffset,
+                                 int64_t subvecSize) {
+  return rewriter
+      .create<vector::ExtractStridedSliceOp>(
+          loc, extractType, vector, rewriter.getI64ArrayAttr({frontOffset}),
+          rewriter.getI64ArrayAttr({subvecSize}), rewriter.getI64ArrayAttr({1}))
+      ->getResult(0);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -353,26 +363,24 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     }
 
     auto numElements =
-        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
-    auto loadVectorType = VectorType::get(numElements, newElementType);
+        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, loadVectorType, adaptor.getBase(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
-    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
-
-    if (newBitCastType.getNumElements() != origElements) {
-      auto extractStridedSlice = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getType(), bitCast,
-          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
-          rewriter.getI64ArrayAttr({origElements}),
-          rewriter.getI64ArrayAttr({1}));
-      rewriter.replaceOp(op, extractStridedSlice.getResult());
-    } else {
-      rewriter.replaceOp(op, bitCast->getResult(0));
+    OpResult castedResult =
+        rewriter
+            .create<vector::BitCastOp>(
+                loc, VectorType::get(numElements * scale, oldElementType),
+                newLoad)
+            ->getResult(0);
+
+    if (isUnalignedEmulation) {
+      castedResult = extractSubvector(rewriter, loc, op.getType(), castedResult,
+                                      *foldedFrontPaddingSize, origElements);
     }
+
+    rewriter.replaceOp(op, castedResult);
     return success();
   }
 };
@@ -542,28 +550,22 @@ struct ConvertVectorTransferRead final
     }
 
     auto numElements =
-        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
-    auto newReadType = VectorType::get(numElements, newElementType);
+        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, newReadType, adaptor.getSource(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
-    auto bitCastType = VectorType::get(numElements * scale, oldElementType);
-    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, bitCastType, newRead);
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
+    auto bitCastResult = bitCast->getResult(0);
     if (isUnalignedEmulation) {
-      // we only extract a portion of the vector.
-      rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
-          op, op.getType(), bitCast,
-          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
-          rewriter.getI64ArrayAttr({origElements}),
-          rewriter.getI64ArrayAttr({1}));
-    } else {
-      rewriter.replaceOp(op, bitCast->getResult(0));
+      bitCastResult = extractSubvector(rewriter, loc, op.getType(), bitCast,
+                                       *foldedFrontPaddingSize, origElements);
     }
+    rewriter.replaceOp(op, bitCastResult);
 
     return success();
   }

From 6cf80dc0d90bcd43a3053a9cbb179f08df188748 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 23 Oct 2024 20:44:47 +0000
Subject: [PATCH 04/10] Implement mask load

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 120 ++++++++++++++----
 .../vector-emulate-narrow-type-unaligned.mlir |  31 +++++
 2 files changed, 124 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1b868ea9a8c705..e1a6d6ed10c523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -43,8 +43,9 @@ using namespace mlir;
 ///   %mask = [1, 1, 0]
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
-                                                  int origElements, int scale) {
-  auto numElements = (origElements + scale - 1) / scale;
+                                                  int origElements, int scale,
+                                                  int frontOffset = 0) {
+  auto numElements = (frontOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -68,6 +69,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
+    if (frontOffset != 0) {
+      assert(false && "unimplemented case for frontOffset != 0");
+      return failure();
+    }
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
@@ -87,11 +92,27 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t maskIndex = (origIndex + scale - 1) / scale;
+    int64_t startIndex = frontOffset / scale;
+    int64_t maskIndex = llvm::alignTo(frontOffset + origIndex, scale) / scale;
+
+    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
+    // the rest are false.
+    if (frontOffset != 0 && maskDimSizes.size() > 1)
+      return failure();
+
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
-    newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                      newMaskDimSizes);
+
+    if (frontOffset == 0) {
+      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                        newMaskDimSizes);
+    } else {
+      SmallVector<bool> newMaskValues;
+      for (int64_t i = 0; i < numElements; ++i)
+        newMaskValues.push_back(i >= startIndex && i < maskIndex);
+      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
+    }
   }
 
   while (!extractOps.empty()) {
@@ -229,7 +250,8 @@ struct ConvertVectorMaskedStore final
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndicesOfr;
-    std::tie(std::ignore, linearizedIndicesOfr) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -242,19 +264,19 @@ struct ConvertVectorMaskedStore final
     // Load the whole data and use arith.select to handle the corner cases.
     // E.g., given these input values:
     //
-    //   %mask = [1, 1, 1, 0, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
-    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
+    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
     //
     // we'll have
     //
-    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
     //
-    //    %new_mask = [1, 1, 0]
-    //    %maskedload = [0x12, 0x34, 0x0]
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
-    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
-    //    %packed_data = [0x78, 0x94, 0x00]
+    //    %new_mask = [1, 1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x56, 0x00]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
+    //    %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
+    //    %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
     //
     // Using the new mask to store %packed_data results in expected output.
     FailureOr<Operation *> newMask =
@@ -271,8 +293,9 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    Value valueToStore = rewriter.create<vector::BitCastOp>(
-        loc, op.getValueToStore().getType(), newLoad);
+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    Value valueToStore =
+        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
         loc, op.getMask(), op.getValueToStore(), valueToStore);
     valueToStore =
@@ -454,13 +477,13 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -468,15 +491,37 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
     FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
+                            *foldedFrontPaddingSize);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements =
+        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
     auto newType = VectorType::get(numElements, newElementType);
+
+    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+
+    Value passthru = op.getPassThru();
+    if (isUnalignedEmulation) {
+      // create an empty vector of the new type
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+      passthru = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, newBitcastType, op.getPassThru(), emptyVector,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({1}));
+    }
     auto newPassThru =
-        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+        rewriter.create<vector::BitCastOp>(loc, newType, passthru);
 
     // Generating the new masked load.
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
@@ -487,10 +532,31 @@ struct ConvertVectorMaskedLoad final
     // Setting the part that originally was not effectively loaded from memory
     // to pass through.
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
-    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
-                                                   op.getPassThru());
-    rewriter.replaceOp(op, select->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
+
+    auto mask = op.getMask();
+    if (isUnalignedEmulation) {
+      auto newSelectMaskType =
+          VectorType::get(numElements * scale, rewriter.getI1Type());
+      // TODO: can fold if op's mask is constant
+      mask = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, newSelectMaskType, op.getMask(),
+          rewriter.create<arith::ConstantOp>(
+              loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)),
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({1}));
+    }
+
+    auto select =
+        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
+
+    if (isUnalignedEmulation) {
+      auto extract = extractSubvector(rewriter, loc, op.getType(), select,
+                                      *foldedFrontPaddingSize, origElements);
+      rewriter.replaceOp(op, extract);
+    } else {
+      rewriter.replaceOp(op, select->getResult(0));
+    }
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 329ab2164c9b5c..7ecbad7968225d 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -34,3 +34,34 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
 // CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+//-----
+
+func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+    %0 = memref.alloc() : memref<3x5xi2>
+    %cst = arith.constant dense<0> : vector<3x5xi2>
+    %mask = vector.constant_mask [3] : vector<5xi1>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
+      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+    return %2 : vector<3x5xi2>
+}
+
+// CHECK: func @vector_cst_maskedload_i2
+// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
+// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C2]]], %[[NEWMASK:.+]], %[[BITCAST1]]
+// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2> 

From 6039b914675642a3b8b8330222d7acf344ebffeb Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Thu, 24 Oct 2024 18:24:39 +0000
Subject: [PATCH 05/10] Refactor

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 53 +++++++++++--------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e1a6d6ed10c523..ceef267cc100be 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -140,10 +140,10 @@ getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
   return std::nullopt;
 }
 
-static OpResult extractSubvector(ConversionPatternRewriter &rewriter,
-                                 Location loc, VectorType extractType,
-                                 Value vector, int64_t frontOffset,
-                                 int64_t subvecSize) {
+static Value extractSubvectorFrom(ConversionPatternRewriter &rewriter,
+                                  Location loc, VectorType extractType,
+                                  Value vector, int64_t frontOffset,
+                                  int64_t subvecSize) {
   return rewriter
       .create<vector::ExtractStridedSliceOp>(
           loc, extractType, vector, rewriter.getI64ArrayAttr({frontOffset}),
@@ -151,6 +151,14 @@ static OpResult extractSubvector(ConversionPatternRewriter &rewriter,
       ->getResult(0);
 }
 
+static Value insertSubvectorInto(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value src, Value dest,
+                                 int64_t offset) {
+  return rewriter.create<vector::InsertStridedSliceOp>(
+      loc, dest.getType(), src, dest, rewriter.getI64ArrayAttr({offset}),
+      rewriter.getI64ArrayAttr({1}));
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -391,7 +399,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    OpResult castedResult =
+    Value castedResult =
         rewriter
             .create<vector::BitCastOp>(
                 loc, VectorType::get(numElements * scale, oldElementType),
@@ -399,8 +407,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             ->getResult(0);
 
     if (isUnalignedEmulation) {
-      castedResult = extractSubvector(rewriter, loc, op.getType(), castedResult,
-                                      *foldedFrontPaddingSize, origElements);
+      castedResult =
+          extractSubvectorFrom(rewriter, loc, op.getType(), castedResult,
+                               *foldedFrontPaddingSize, origElements);
     }
 
     rewriter.replaceOp(op, castedResult);
@@ -515,10 +524,8 @@ struct ConvertVectorMaskedLoad final
       // create an empty vector of the new type
       auto emptyVector = rewriter.create<arith::ConstantOp>(
           loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
-      passthru = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, newBitcastType, op.getPassThru(), emptyVector,
-          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
-          rewriter.getI64ArrayAttr({1}));
+      passthru = insertSubvectorInto(rewriter, loc, op.getPassThru(),
+                                     emptyVector, *foldedFrontPaddingSize);
     }
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, newType, passthru);
@@ -534,25 +541,24 @@ struct ConvertVectorMaskedLoad final
     auto bitCast =
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
-    auto mask = op.getMask();
+    Value mask = op.getMask();
     if (isUnalignedEmulation) {
       auto newSelectMaskType =
           VectorType::get(numElements * scale, rewriter.getI1Type());
       // TODO: can fold if op's mask is constant
-      mask = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, newSelectMaskType, op.getMask(),
-          rewriter.create<arith::ConstantOp>(
-              loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)),
-          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
-          rewriter.getI64ArrayAttr({1}));
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+      mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
+                                 *foldedFrontPaddingSize);
     }
 
     auto select =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
 
     if (isUnalignedEmulation) {
-      auto extract = extractSubvector(rewriter, loc, op.getType(), select,
-                                      *foldedFrontPaddingSize, origElements);
+      auto extract =
+          extractSubvectorFrom(rewriter, loc, op.getType(), select,
+                               *foldedFrontPaddingSize, origElements);
       rewriter.replaceOp(op, extract);
     } else {
       rewriter.replaceOp(op, select->getResult(0));
@@ -626,10 +632,11 @@ struct ConvertVectorTransferRead final
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
-    auto bitCastResult = bitCast->getResult(0);
+    Value bitCastResult = bitCast->getResult(0);
     if (isUnalignedEmulation) {
-      bitCastResult = extractSubvector(rewriter, loc, op.getType(), bitCast,
-                                       *foldedFrontPaddingSize, origElements);
+      bitCastResult =
+          extractSubvectorFrom(rewriter, loc, op.getType(), bitCastResult,
+                               *foldedFrontPaddingSize, origElements);
     }
     rewriter.replaceOp(op, bitCastResult);
 

From b669a0dfe0a14baa9ba9d76817b0a3260b79991a Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 25 Oct 2024 15:37:41 +0000
Subject: [PATCH 06/10] Updates

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 79 ++++++++-----------
 1 file changed, 34 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ceef267cc100be..e782291308a6c1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -22,6 +23,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 #include <optional>
@@ -70,7 +72,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
     if (frontOffset != 0) {
-      assert(false && "unimplemented case for frontOffset != 0");
       return failure();
     }
     OperandRange maskOperands = createMaskOp.getOperands();
@@ -93,7 +94,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
     int64_t startIndex = frontOffset / scale;
-    int64_t maskIndex = llvm::alignTo(frontOffset + origIndex, scale) / scale;
+    int64_t maskIndex = llvm::divideCeil(frontOffset + origIndex, scale);
 
     // TODO: we only want the mask between [startIndex, maskIndex] to be true,
     // the rest are false.
@@ -132,31 +133,29 @@ getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
     return 0;
   auto foldedFrontPaddingSize = getValueOrCreateConstantIndexOp(
       rewriter, loc, linearizedInfo.frontPaddingSize);
-  // try to fold the front padding size into a constant
-  if (auto frontPadding = dyn_cast_or_null<arith::ConstantIndexOp>(
-          foldedFrontPaddingSize.getDefiningOp())) {
-    return frontPadding.value();
-  }
-  return std::nullopt;
+  return getConstantIntValue(foldedFrontPaddingSize);
 }
 
 static Value extractSubvectorFrom(ConversionPatternRewriter &rewriter,
                                   Location loc, VectorType extractType,
                                   Value vector, int64_t frontOffset,
                                   int64_t subvecSize) {
+  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
+  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
+  auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(
-          loc, extractType, vector, rewriter.getI64ArrayAttr({frontOffset}),
-          rewriter.getI64ArrayAttr({subvecSize}), rewriter.getI64ArrayAttr({1}))
+      .create<vector::ExtractStridedSliceOp>(loc, extractType, vector, offsets,
+                                             sizes, strides)
       ->getResult(0);
 }
 
 static Value insertSubvectorInto(ConversionPatternRewriter &rewriter,
                                  Location loc, Value src, Value dest,
                                  int64_t offset) {
-  return rewriter.create<vector::InsertStridedSliceOp>(
-      loc, dest.getType(), src, dest, rewriter.getI64ArrayAttr({offset}),
-      rewriter.getI64ArrayAttr({1}));
+  auto offsets = rewriter.getI64ArrayAttr({offset});
+  auto strides = rewriter.getI64ArrayAttr({1});
+  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
+                                                       dest, offsets, strides);
 }
 
 namespace {
@@ -394,25 +393,20 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     }
 
     auto numElements =
-        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
+        llvm::divideCeil(*foldedFrontPaddingSize + origElements, scale);
     auto newLoad = rewriter.create<vector::LoadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    Value castedResult =
-        rewriter
-            .create<vector::BitCastOp>(
-                loc, VectorType::get(numElements * scale, oldElementType),
-                newLoad)
-            ->getResult(0);
+    Value result = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
 
     if (isUnalignedEmulation) {
-      castedResult =
-          extractSubvectorFrom(rewriter, loc, op.getType(), castedResult,
-                               *foldedFrontPaddingSize, origElements);
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedFrontPaddingSize, origElements);
     }
 
-    rewriter.replaceOp(op, castedResult);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -514,9 +508,8 @@ struct ConvertVectorMaskedLoad final
       return failure();
 
     auto numElements =
-        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
-
+        llvm::divideCeil(*foldedFrontPaddingSize + origElements, scale);
+    auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
     Value passthru = op.getPassThru();
@@ -524,15 +517,15 @@ struct ConvertVectorMaskedLoad final
       // create an empty vector of the new type
       auto emptyVector = rewriter.create<arith::ConstantOp>(
           loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
-      passthru = insertSubvectorInto(rewriter, loc, op.getPassThru(),
-                                     emptyVector, *foldedFrontPaddingSize);
+      passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
+                                     *foldedFrontPaddingSize);
     }
     auto newPassThru =
-        rewriter.create<vector::BitCastOp>(loc, newType, passthru);
+        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
 
     // Generating the new masked load.
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
-        loc, newType, adaptor.getBase(),
+        loc, loadType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newMask.value()->getResult(0), newPassThru);
 
@@ -552,17 +545,14 @@ struct ConvertVectorMaskedLoad final
                                  *foldedFrontPaddingSize);
     }
 
-    auto select =
+    Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
 
     if (isUnalignedEmulation) {
-      auto extract =
-          extractSubvectorFrom(rewriter, loc, op.getType(), select,
-                               *foldedFrontPaddingSize, origElements);
-      rewriter.replaceOp(op, extract);
-    } else {
-      rewriter.replaceOp(op, select->getResult(0));
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedFrontPaddingSize, origElements);
     }
+    rewriter.replaceOp(op, result);
 
     return success();
   }
@@ -622,7 +612,7 @@ struct ConvertVectorTransferRead final
     }
 
     auto numElements =
-        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
+        llvm::divideCeil(*foldedFrontPaddingSize + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -632,13 +622,12 @@ struct ConvertVectorTransferRead final
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
-    Value bitCastResult = bitCast->getResult(0);
+    Value result = bitCast->getResult(0);
     if (isUnalignedEmulation) {
-      bitCastResult =
-          extractSubvectorFrom(rewriter, loc, op.getType(), bitCastResult,
-                               *foldedFrontPaddingSize, origElements);
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedFrontPaddingSize, origElements);
     }
-    rewriter.replaceOp(op, bitCastResult);
+    rewriter.replaceOp(op, result);
 
     return success();
   }

From cce4aae34e01f4788be9cdf6da8714e06d2fe033 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 25 Oct 2024 19:31:08 +0000
Subject: [PATCH 07/10] updates

---
 .../Vector/Transforms/VectorEmulateNarrowType.cpp      | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e782291308a6c1..495464c31598ed 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,13 +36,17 @@ using namespace mlir;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, the following mask:
+/// in the scale range. E.g., if `scale` equals 2, and `frontOffset` equals 2,
+/// the following mask:
 ///
 ///   %mask = [1, 1, 1, 0, 0, 0]
 ///
-/// will return the following new compressed mask:
+/// will first be padded with frontOffset zeros:
+///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
 ///
-///   %mask = [1, 1, 0]
+/// then it will return the following new compressed mask:
+///
+///   %mask = [0, 1, 1, 0]
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
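
As a rough illustration of the padding-and-compression described in the doc
comment above, the following standalone C++ sketch reproduces the example on
plain booleans. It only models the constant-mask case, and the helper name
compressMask is made up for the sketch rather than taken from the patch:

  #include <cassert>
  #include <vector>

  // Pad the mask with `frontOffset` leading zeros, then OR-reduce every group
  // of `scale` lanes into a single compressed lane.
  static std::vector<bool> compressMask(const std::vector<bool> &mask,
                                        int scale, int frontOffset) {
    int origElements = static_cast<int>(mask.size());
    int numElements = (frontOffset + origElements + scale - 1) / scale;
    std::vector<bool> compressed(numElements, false);
    for (int i = 0; i < origElements; ++i)
      if (mask[i])
        compressed[(frontOffset + i) / scale] = true;
    return compressed;
  }

  int main() {
    // scale = 2, frontOffset = 2:
    //   [1, 1, 1, 0, 0, 0] -> padded [0, 0, 1, 1, 1, 0, 0, 0] -> [0, 1, 1, 0]
    std::vector<bool> compressed =
        compressMask({1, 1, 1, 0, 0, 0}, /*scale=*/2, /*frontOffset=*/2);
    assert((compressed == std::vector<bool>{0, 1, 1, 0}));
    return 0;
  }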

>From 6a0b7c263295466352387016393ef51dcb1b0a81 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Mon, 28 Oct 2024 15:48:03 +0000
Subject: [PATCH 08/10] Fix according to comments

---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   | 11 ++-
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp |  4 +-
 .../Transforms/VectorEmulateNarrowType.cpp    | 82 ++++++++++---------
 3 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index db32543162b781..6315e9020cf400 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -44,14 +44,13 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 ///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
 ///   0, is returned for the linearized index.
 /// - If the size of the load/store is smaller than the linearized memref
-/// load/store,
-///   the memory region emulated is larger than the actual memory region needed.
-///   `frontPaddingSize` returns the size of the irrelevant offset at the
-///   beginning.
+/// load/store, the memory region emulated is larger than the actual memory
+/// region needed. `intraVectorOffset` returns the element offset of the data
+/// relevant at the beginning.
 struct LinearizedMemRefInfo {
   OpFoldResult linearizedOffset;
   OpFoldResult linearizedSize;
-  OpFoldResult frontPaddingSize;
+  OpFoldResult intraVectorOffset;
 };
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,
@@ -120,4 +119,4 @@ MemrefValue skipViewLikeOps(MemrefValue source);
 } // namespace memref
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
+#endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
\ No newline at end of file
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 69724bec248827..6de744a7f75244 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -94,10 +94,10 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
       builder, loc, s0.floorDiv(scaler), {offset});
 
-  OpFoldResult frontPaddingSize = affine::makeComposedFoldedAffineApply(
+  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
       builder, loc, addMulMap % scaler, offsetValues);
 
-  return {{adjustBaseOffset, linearizedSize, frontPaddingSize},
+  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
           linearizedIndices};
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 495464c31598ed..48672a695d9db8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,12 +36,12 @@ using namespace mlir;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals 2, and `frontOffset` equals 2,
-/// the following mask:
+/// in the scale range. E.g., if `scale` equals 2, and `intraVectorOffset`
+/// equals 2, the following mask:
 ///
 ///   %mask = [1, 1, 1, 0, 0, 0]
 ///
-/// will first be padded with frontOffset zeros:
+/// will first be padded with `intraVectorOffset` zeros:
 ///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
@@ -50,8 +50,8 @@ using namespace mlir;
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
-                                                  int frontOffset = 0) {
-  auto numElements = (frontOffset + origElements + scale - 1) / scale;
+                                                  int intraVectorOffset = 0) {
+  auto numElements = (intraVectorOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -75,7 +75,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
-    if (frontOffset != 0) {
+    if (intraVectorOffset != 0) {
       return failure();
     }
     OperandRange maskOperands = createMaskOp.getOperands();
@@ -97,18 +97,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = frontOffset / scale;
-    int64_t maskIndex = llvm::divideCeil(frontOffset + origIndex, scale);
+    int64_t startIndex = intraVectorOffset / scale;
+    int64_t maskIndex = llvm::divideCeil(intraVectorOffset + origIndex, scale);
 
     // TODO: we only want the mask between [startIndex, maskIndex] to be true,
     // the rest are false.
-    if (frontOffset != 0 && maskDimSizes.size() > 1)
+    if (intraVectorOffset != 0 && maskDimSizes.size() > 1)
       return failure();
 
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
 
-    if (frontOffset == 0) {
+    if (intraVectorOffset == 0) {
       newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                         newMaskDimSizes);
     } else {
@@ -130,14 +130,11 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 }
 
 static std::optional<int64_t>
-getFrontPaddingSize(ConversionPatternRewriter &rewriter, Location loc,
-                    const memref::LinearizedMemRefInfo linearizedInfo,
-                    bool isUnalignedEmulation) {
-  if (!isUnalignedEmulation)
-    return 0;
-  auto foldedFrontPaddingSize = getValueOrCreateConstantIndexOp(
-      rewriter, loc, linearizedInfo.frontPaddingSize);
-  return getConstantIntValue(foldedFrontPaddingSize);
+getIntraVectorOffset(ConversionPatternRewriter &rewriter, Location loc,
+                     const memref::LinearizedMemRefInfo linearizedInfo) {
+  auto foldedIntraVectorOffset = getValueOrCreateConstantIndexOp(
+      rewriter, loc, linearizedInfo.intraVectorOffset);
+  return getConstantIntValue(foldedIntraVectorOffset);
 }
 
 static Value extractSubvectorFrom(ConversionPatternRewriter &rewriter,
@@ -388,16 +385,18 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedFrontPaddingSize = getFrontPaddingSize(
-        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getIntraVectorOffset(rewriter, loc, linearizedInfo)
+            : 0;
 
-    if (!foldedFrontPaddingSize) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra vector offset
       return failure();
     }
 
     auto numElements =
-        llvm::divideCeil(*foldedFrontPaddingSize + origElements, scale);
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
     auto newLoad = rewriter.create<vector::LoadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
@@ -407,7 +406,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     if (isUnalignedEmulation) {
       result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedFrontPaddingSize, origElements);
+                                    *foldedIntraVectorOffset, origElements);
     }
 
     rewriter.replaceOp(op, result);
@@ -498,21 +497,24 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedFrontPaddingSize = getFrontPaddingSize(
-        rewriter, loc, linearizedInfo, isUnalignedEmulation);
-    if (!foldedFrontPaddingSize) {
-      // unimplemented case for dynamic front padding size
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getIntraVectorOffset(rewriter, loc, linearizedInfo)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra vector offset
       return failure();
     }
 
     FailureOr<Operation *> newMask =
         getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
-                            *foldedFrontPaddingSize);
+                            *foldedIntraVectorOffset);
     if (failed(newMask))
       return failure();
 
     auto numElements =
-        llvm::divideCeil(*foldedFrontPaddingSize + origElements, scale);
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
     auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
@@ -522,7 +524,7 @@ struct ConvertVectorMaskedLoad final
       auto emptyVector = rewriter.create<arith::ConstantOp>(
           loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
       passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
-                                     *foldedFrontPaddingSize);
+                                     *foldedIntraVectorOffset);
     }
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -546,7 +548,7 @@ struct ConvertVectorMaskedLoad final
       auto emptyVector = rewriter.create<arith::ConstantOp>(
           loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
       mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
-                                 *foldedFrontPaddingSize);
+                                 *foldedIntraVectorOffset);
     }
 
     Value result =
@@ -554,7 +556,7 @@ struct ConvertVectorMaskedLoad final
 
     if (isUnalignedEmulation) {
       result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedFrontPaddingSize, origElements);
+                                    *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -607,16 +609,18 @@ struct ConvertVectorTransferRead final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedFrontPaddingSize = getFrontPaddingSize(
-        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getIntraVectorOffset(rewriter, loc, linearizedInfo)
+            : 0;
 
-    if (!foldedFrontPaddingSize) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra vector offset
       return failure();
     }
 
     auto numElements =
-        llvm::divideCeil(*foldedFrontPaddingSize + origElements, scale);
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -629,7 +633,7 @@ struct ConvertVectorTransferRead final
     Value result = bitCast->getResult(0);
     if (isUnalignedEmulation) {
       result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedFrontPaddingSize, origElements);
+                                    *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 

>From 9624cde26276ad2533b5269eecf20cdfa6ef60c6 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Mon, 28 Oct 2024 21:55:46 +0000
Subject: [PATCH 09/10] updates

---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   |  2 +-
 .../Transforms/VectorEmulateNarrowType.cpp    | 36 +++++++------------
 2 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 6315e9020cf400..2ddc53b760d11f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -119,4 +119,4 @@ MemrefValue skipViewLikeOps(MemrefValue source);
 } // namespace memref
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
\ No newline at end of file
+#endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 48672a695d9db8..8f938943c768e7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -75,9 +75,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
-    if (intraVectorOffset != 0) {
+    // TODO: handle the case with non-zero intraVectorOffset for CreateMaskOp.
+    if (intraVectorOffset != 0)
       return failure();
-    }
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
@@ -129,18 +129,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
-static std::optional<int64_t>
-getIntraVectorOffset(ConversionPatternRewriter &rewriter, Location loc,
-                     const memref::LinearizedMemRefInfo linearizedInfo) {
-  auto foldedIntraVectorOffset = getValueOrCreateConstantIndexOp(
-      rewriter, loc, linearizedInfo.intraVectorOffset);
-  return getConstantIntValue(foldedIntraVectorOffset);
-}
-
-static Value extractSubvectorFrom(ConversionPatternRewriter &rewriter,
-                                  Location loc, VectorType extractType,
-                                  Value vector, int64_t frontOffset,
-                                  int64_t subvecSize) {
+static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
+                                  VectorType extractType, Value vector,
+                                  int64_t frontOffset, int64_t subvecSize) {
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
@@ -150,9 +141,8 @@ static Value extractSubvectorFrom(ConversionPatternRewriter &rewriter,
       ->getResult(0);
 }
 
-static Value insertSubvectorInto(ConversionPatternRewriter &rewriter,
-                                 Location loc, Value src, Value dest,
-                                 int64_t offset) {
+static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
+                                 Value src, Value dest, int64_t offset) {
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -385,9 +375,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    std::optional<int64_t> foldedIntraVectorOffset =
         isUnalignedEmulation
-            ? getIntraVectorOffset(rewriter, loc, linearizedInfo)
+            ? getConstantIntValue(linearizedInfo.intraVectorOffset)
             : 0;
 
     if (!foldedIntraVectorOffset) {
@@ -497,9 +487,9 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    std::optional<int64_t> foldedIntraVectorOffset =
         isUnalignedEmulation
-            ? getIntraVectorOffset(rewriter, loc, linearizedInfo)
+            ? getConstantIntValue(linearizedInfo.intraVectorOffset)
             : 0;
 
     if (!foldedIntraVectorOffset) {
@@ -609,9 +599,9 @@ struct ConvertVectorTransferRead final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    std::optional<int64_t> foldedIntraVectorOffset =
         isUnalignedEmulation
-            ? getIntraVectorOffset(rewriter, loc, linearizedInfo)
+            ? getConstantIntValue(linearizedInfo.intraVectorOffset)
             : 0;
 
     if (!foldedIntraVectorOffset) {
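
For reference, here is a rough element-level model, on plain C++ vectors rather
than MLIR values, of what the insertSubvectorInto / extractSubvectorFrom
helpers above express via vector::InsertStridedSliceOp and
vector::ExtractStridedSliceOp. The 8-lane width, the payload values, and the
offset 2 are illustrative assumptions:

  #include <cassert>
  #include <cstddef>
  #include <vector>

  // Model of insertSubvectorInto: write `src` into `dest` starting at `offset`.
  static std::vector<int> insertAt(std::vector<int> dest,
                                   const std::vector<int> &src, size_t offset) {
    for (size_t i = 0; i < src.size(); ++i)
      dest[offset + i] = src[i];
    return dest;
  }

  // Model of extractSubvectorFrom: read `size` elements starting at `offset`.
  static std::vector<int> extractAt(const std::vector<int> &vec, size_t offset,
                                    size_t size) {
    return std::vector<int>(vec.begin() + offset, vec.begin() + offset + size);
  }

  int main() {
    // Place a 3-element payload at intra-vector offset 2 of an 8-lane vector
    // (e.g. the passthru widened before the masked load), then read it back,
    // mirroring the unaligned emulation flow.
    std::vector<int> widened = insertAt(std::vector<int>(8, 0), {7, 8, 9}, 2);
    assert((extractAt(widened, 2, 3) == std::vector<int>{7, 8, 9}));
    return 0;
  }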

>From fffe88af99fe7faa5ed36a735531c2a82dbc365a Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 30 Oct 2024 01:13:48 +0000
Subject: [PATCH 10/10] updates again

---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   |  4 +--
 .../Transforms/VectorEmulateNarrowType.cpp    | 26 +++++++++----------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 2ddc53b760d11f..a761a77a407e87 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -45,12 +45,12 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 ///   0, is returned for the linearized index.
 /// - If the size of the load/store is smaller than the linearized memref
 /// load/store, the memory region emulated is larger than the actual memory
-/// region needed. `intraVectorOffset` returns the element offset of the data
+/// region needed. `intraDataOffset` returns the element offset of the data
 /// relevant at the beginning.
 struct LinearizedMemRefInfo {
   OpFoldResult linearizedOffset;
   OpFoldResult linearizedSize;
-  OpFoldResult intraVectorOffset;
+  OpFoldResult intraDataOffset;
 };
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 8f938943c768e7..1d6f8a991d9b5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,12 +36,12 @@ using namespace mlir;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals 2, and `intraVectorOffset`
+/// in the scale range. E.g., if `scale` equals 2, and `intraDataOffset`
 /// equals 2, the following mask:
 ///
 ///   %mask = [1, 1, 1, 0, 0, 0]
 ///
-/// will first be padded with `intraVectorOffset` zeros:
+/// will first be padded with `intraDataOffset` zeros:
 ///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
@@ -50,8 +50,8 @@ using namespace mlir;
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
-                                                  int intraVectorOffset = 0) {
-  auto numElements = (intraVectorOffset + origElements + scale - 1) / scale;
+                                                  int intraDataOffset = 0) {
+  auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -75,8 +75,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
-    // TODO: handle the case with non-zero intraVectorOffset for CreateMaskOp.
-    if (intraVectorOffset != 0)
+    // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
+    if (intraDataOffset != 0)
       return failure();
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
@@ -97,18 +97,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = intraVectorOffset / scale;
-    int64_t maskIndex = llvm::divideCeil(intraVectorOffset + origIndex, scale);
+    int64_t startIndex = intraDataOffset / scale;
+    int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
 
     // TODO: we only want the mask between [startIndex, maskIndex] to be true,
     // the rest are false.
-    if (intraVectorOffset != 0 && maskDimSizes.size() > 1)
+    if (intraDataOffset != 0 && maskDimSizes.size() > 1)
       return failure();
 
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
 
-    if (intraVectorOffset == 0) {
+    if (intraDataOffset == 0) {
       newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                         newMaskDimSizes);
     } else {
@@ -377,7 +377,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     std::optional<int64_t> foldedIntraVectorOffset =
         isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraVectorOffset)
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
     if (!foldedIntraVectorOffset) {
@@ -489,7 +489,7 @@ struct ConvertVectorMaskedLoad final
 
     std::optional<int64_t> foldedIntraVectorOffset =
         isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraVectorOffset)
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
     if (!foldedIntraVectorOffset) {
@@ -601,7 +601,7 @@ struct ConvertVectorTransferRead final
 
     std::optional<int64_t> foldedIntraVectorOffset =
         isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraVectorOffset)
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
     if (!foldedIntraVectorOffset) {


