[Mlir-commits] [mlir] [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (PR #113411)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 29 18:19:31 PDT 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/113411

From 355c3281fbcb58963dfc1bc1bee2a0e1f2e9e33b Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sun, 20 Oct 2024 14:54:57 +0000
Subject: [PATCH] [MLIR] VectorEmulateNarrowType: support unaligned cases

Previously the pass only supported emulation of vectors whose size in bits
is a multiple of the emulated data type's width (i8). This patch extends
the pass to also handle vectors whose sizes are not a multiple of a byte,
such as `vector<3xi2>`.

A limitation of this patch is that the linearized index of the unaligned
vector has to be known at compile time; handling a dynamic index would
require emitting extra code, which this patch does not do yet.
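
For example (a hypothetical case written for illustration, not taken from the
patch), a load whose index is a function argument cannot be folded to a
constant intra-byte offset, so the new patterns leave such an op unconverted:

    // Not handled yet: %i is unknown at compile time, so the intra-byte
    // offset of the 3 i2 elements cannot be folded to a constant.
    func.func @unaligned_dynamic(%m: memref<3x3xi2>, %i: index) -> vector<3xi2> {
      %c0 = arith.constant 0 : index
      %v = vector.load %m[%i, %c0] : memref<3x3xi2>, vector<3xi2>
      return %v : vector<3xi2>
    }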

The following ops are updated:
* `vector::LoadOp`
* `vector::StoreOp`
* `vector::TransferReadOp`
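
As an illustration, an unaligned load is emulated by loading enough whole
bytes, bitcasting back to the narrow type, and extracting the requested
slice. This is a sketch mirroring the example in the updated pass comments
and the new test; the exact constants are computed by the pass:

    // Before: load 3 i2 elements starting at linear element index 6.
    %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>

    // After: load 2 emulated i8 elements starting at byte 1, bitcast back
    // to i2, and extract the 3 relevant elements at intra-byte offset 2.
    %1 = vector.load %0[%c1] : memref<3xi8>, vector<2xi8>
    %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
    %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides = [1]}
           : vector<8xi2> to vector<3xi2>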
---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   |   8 +-
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp |   9 +-
 .../Transforms/VectorEmulateNarrowType.cpp    | 235 ++++++++++++++----
 .../vector-emulate-narrow-type-unaligned.mlir |  67 +++++
 4 files changed, 264 insertions(+), 55 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index ca3326dbbef519..a761a77a407e87 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -32,7 +32,8 @@ namespace memref {
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
 /// For a `memref` with `offset`, `sizes` and `strides`, returns the
-/// offset and size to use for the linearized `memref`.
+/// offset, size, and potentially the intra-data offset (i.e. the amount of
+/// front padding, in elements) to use for the linearized `memref`.
 /// - If the linearization is done for emulating load/stores of
 ///   element type with bitwidth `srcBits` using element type with
 ///   bitwidth `dstBits`, the linearized offset and size are
@@ -42,9 +43,14 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 ///   index to use in the linearized `memref`. The linearized index
 ///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
 ///   0, is returned for the linearized index.
+/// - If the size of the load/store is smaller than the linearized memref
+///   load/store, the emulated memory region is larger than the memory region
+///   actually needed. In that case, `intraDataOffset` holds the offset, in
+///   source elements, of the relevant data within the emulated region.
 struct LinearizedMemRefInfo {
   OpFoldResult linearizedOffset;
   OpFoldResult linearizedSize;
+  OpFoldResult intraDataOffset;
 };
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 7321b19068016c..6de744a7f75244 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -81,11 +81,10 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
 
   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
-      builder, loc, addMulMap, offsetValues);
+      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
   OpFoldResult linearizedSize =
       affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
@@ -95,7 +94,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
       builder, loc, s0.floorDiv(scaler), {offset});
 
-  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap % scaler, offsetValues);
+
+  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
+          linearizedIndices};
 }
 
 LinearizedMemRefInfo
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 66362d3ca70fb6..1d6f8a991d9b5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -22,8 +23,10 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <optional>
 
 using namespace mlir;
 
@@ -33,17 +36,22 @@ using namespace mlir;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, the following mask:
+/// in the scale range. E.g., if `scale` equals 2 and `intraDataOffset`
+/// equals 2, the following mask:
 ///
 ///   %mask = [1, 1, 1, 0, 0, 0]
 ///
-/// will return the following new compressed mask:
+/// will first be padded with `intraDataOffset` leading zeros:
+///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
 ///
-///   %mask = [1, 1, 0]
+/// then it will return the following new compressed mask:
+///
+///   %mask = [0, 1, 1, 0]
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
-                                                  int origElements, int scale) {
-  auto numElements = (origElements + scale - 1) / scale;
+                                                  int origElements, int scale,
+                                                  int intraDataOffset = 0) {
+  auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -67,6 +75,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
+    // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
+    if (intraDataOffset != 0)
+      return failure();
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
@@ -86,11 +97,27 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t maskIndex = (origIndex + scale - 1) / scale;
+    int64_t startIndex = intraDataOffset / scale;
+    int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
+
+    // TODO: we only want the mask within [startIndex, maskIndex) to be true;
+    // the rest should be false.
+    if (intraDataOffset != 0 && maskDimSizes.size() > 1)
+      return failure();
+
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
-    newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                      newMaskDimSizes);
+
+    if (intraDataOffset == 0) {
+      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                        newMaskDimSizes);
+    } else {
+      SmallVector<bool> newMaskValues;
+      for (int64_t i = 0; i < numElements; ++i)
+        newMaskValues.push_back(i >= startIndex && i < maskIndex);
+      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
+    }
   }
 
   while (!extractOps.empty()) {
@@ -102,6 +129,26 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
+static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
+                                  VectorType extractType, Value vector,
+                                  int64_t frontOffset, int64_t subvecSize) {
+  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
+  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
+  auto strides = rewriter.getI64ArrayAttr({1});
+  return rewriter
+      .create<vector::ExtractStridedSliceOp>(loc, extractType, vector, offsets,
+                                             sizes, strides)
+      ->getResult(0);
+}
+
+static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
+                                 Value src, Value dest, int64_t offset) {
+  auto offsets = rewriter.getI64ArrayAttr({offset});
+  auto strides = rewriter.getI64ArrayAttr({1});
+  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
+                                                       dest, offsets, strides);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -201,7 +248,8 @@ struct ConvertVectorMaskedStore final
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndicesOfr;
-    std::tie(std::ignore, linearizedIndicesOfr) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -214,19 +262,19 @@ struct ConvertVectorMaskedStore final
     // Load the whole data and use arith.select to handle the corner cases.
     // E.g., given these input values:
     //
-    //   %mask = [1, 1, 1, 0, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
-    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
+    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
     //
     // we'll have
     //
-    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
     //
-    //    %new_mask = [1, 1, 0]
-    //    %maskedload = [0x12, 0x34, 0x0]
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
-    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
-    //    %packed_data = [0x78, 0x94, 0x00]
+    //    %new_mask = [1, 1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x56, 0x00]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
+    //    %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
+    //    %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
     //
     // Using the new mask to store %packed_data results in expected output.
     FailureOr<Operation *> newMask =
@@ -243,8 +291,9 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    Value valueToStore = rewriter.create<vector::BitCastOp>(
-        loc, op.getValueToStore().getType(), newLoad);
+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    Value valueToStore =
+        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
         loc, op.getMask(), op.getValueToStore(), valueToStore);
     valueToStore =
@@ -294,19 +343,31 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
     // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
     //
-    // TODO: Currently, only the even number of elements loading is supported.
-    // To deal with the odd number of elements, one has to extract the
-    // subvector at the proper offset after bit-casting.
+    // There are cases where the number of elements to load is not byte-aligned,
+    // for example:
+    //
+    // %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    //
+    // we have to load extra bytes and then extract the exact slice we need:
+    //
+    // %1 = vector.load %0[%c1] : memref<3xi8>, vector<2xi8>
+    // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
+    // %3 = vector.extract_strided_slice %2
+    //        {offsets = [2], sizes = [3], strides = [1]}
+    //        : vector<8xi2> to vector<3xi2>
+    //
+    // TODO: Currently the extract_strided_slice's attributes must be known at
+    // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -314,15 +375,31 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    std::optional<int64_t> foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra-vector offset
+      return failure();
+    }
+
+    auto numElements =
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
     auto newLoad = rewriter.create<vector::LoadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+    Value result = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
 
-    rewriter.replaceOp(op, bitCast->getResult(0));
+    if (isUnalignedEmulation) {
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedIntraVectorOffset, origElements);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -396,13 +473,13 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -410,29 +487,68 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    std::optional<int64_t> foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra-vector offset
+      return failure();
+    }
+
     FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
+                            *foldedIntraVectorOffset);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
+    auto numElements =
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+    auto loadType = VectorType::get(numElements, newElementType);
+    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+
+    Value passthru = op.getPassThru();
+    if (isUnalignedEmulation) {
+      // create an empty vector of the new type
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+      passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
+                                     *foldedIntraVectorOffset);
+    }
     auto newPassThru =
-        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
 
     // Generating the new masked load.
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
-        loc, newType, adaptor.getBase(),
+        loc, loadType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newMask.value()->getResult(0), newPassThru);
 
     // Setting the part that originally was not effectively loaded from memory
     // to pass through.
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
-    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
-                                                   op.getPassThru());
-    rewriter.replaceOp(op, select->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
+
+    Value mask = op.getMask();
+    if (isUnalignedEmulation) {
+      auto newSelectMaskType =
+          VectorType::get(numElements * scale, rewriter.getI1Type());
+      // TODO: can fold if op's mask is constant
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+      mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
+                                 *foldedIntraVectorOffset);
+    }
+
+    Value result =
+        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
+
+    if (isUnalignedEmulation) {
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedIntraVectorOffset, origElements);
+    }
+    rewriter.replaceOp(op, result);
 
     return success();
   }
@@ -464,8 +580,8 @@ struct ConvertVectorTransferRead final
     int scale = dstBits / srcBits;
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -474,7 +590,8 @@ struct ConvertVectorTransferRead final
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -482,18 +599,34 @@ struct ConvertVectorTransferRead final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
-    auto newReadType = VectorType::get(numElements, newElementType);
+    std::optional<int64_t> foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra-vector offset
+      return failure();
+    }
+
+    auto numElements =
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, newReadType, adaptor.getSource(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
-    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+
+    Value result = bitCast->getResult(0);
+    if (isUnalignedEmulation) {
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedIntraVectorOffset, origElements);
+    }
+    rewriter.replaceOp(op, result);
 
-    rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
new file mode 100644
index 00000000000000..7ecbad7968225d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+// -----
+
+func.func @vector_transfer_read_i2() -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+// -----
+
+func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+    %0 = memref.alloc() : memref<3x5xi2>
+    %cst = arith.constant dense<0> : vector<3x5xi2>
+    %mask = vector.constant_mask [3] : vector<5xi1>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
+      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+    return %2 : vector<3x5xi2>
+}
+
+// CHECK: func @vector_cst_maskedload_i2
+// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
+// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C2]]], %[[NEWMASK:.+]], %[[BITCAST1]]
+// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2> 


