[Mlir-commits] [mlir] 2c31325 - [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (#113411)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 29 20:04:52 PDT 2024


Author: lialan
Date: 2024-10-29T20:04:48-07:00
New Revision: 2c313259c65317f097d57ab4c6684b25db98f2e4

URL: https://github.com/llvm/llvm-project/commit/2c313259c65317f097d57ab4c6684b25db98f2e4
DIFF: https://github.com/llvm/llvm-project/commit/2c313259c65317f097d57ab4c6684b25db98f2e4.diff

LOG: [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (#113411)

Previously, the pass only supported emulating loads of vectors whose
total size is a multiple of the emulated (wider) data type's size. This
patch extends support to vectors whose size is not a multiple of a byte.
In such cases, the element values are packed back-to-back to save
memory space.

To give a concrete example: if an input has type `memref<3x3xi2>`, it
actually occupies 3 bytes in memory, with the first 18 bits storing the
values and the last 6 bits as padding. The slice of `vector<3xi2>` at
index `[2, 0]` is stored in memory in bits 12 through 17. To properly
load these elements from memory, first load the second and third bytes
(0-based byte indices 1 and 2) and convert them to a vector of `i2`
type; then extract bits 4 through 9 of the loaded data (elements 2
through 4) to form a `vector<3xi2>`.

A limitation of this patch is that the linearized index of the unaligned
vector has to be known at compile time. Extra code needs to be emitted
to handle it if the condition does not hold.

The following ops are updated:
* `vector::LoadOp`
* `vector::TransferReadOp`
* `vector::MaskedLoadOp`

Added: 
    mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
    mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index ca3326dbbef519..a761a77a407e87 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -32,7 +32,8 @@ namespace memref {
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
 /// For a `memref` with `offset`, `sizes` and `strides`, returns the
-/// offset and size to use for the linearized `memref`.
+/// offset, size, and potentially the size padded at the front to use for the
+/// linearized `memref`.
 /// - If the linearization is done for emulating load/stores of
 ///   element type with bitwidth `srcBits` using element type with
 ///   bitwidth `dstBits`, the linearized offset and size are
@@ -42,9 +43,14 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 ///   index to use in the linearized `memref`. The linearized index
 ///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
 ///   0, is returned for the linearized index.
+/// - If the size of the load/store is smaller than the linearized memref
+///   load/store, the memory region emulated is larger than the actual memory
+///   region needed. `intraDataOffset` returns the offset, in source elements,
+///   of the relevant data from the beginning of the emulated region.
 struct LinearizedMemRefInfo {
   OpFoldResult linearizedOffset;
   OpFoldResult linearizedSize;
+  OpFoldResult intraDataOffset;
 };
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,

diff  --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 7321b19068016c..6de744a7f75244 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -81,11 +81,10 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
 
   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
-      builder, loc, addMulMap, offsetValues);
+      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
   OpFoldResult linearizedSize =
       affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
@@ -95,7 +94,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
       builder, loc, s0.floorDiv(scaler), {offset});
 
-  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap % scaler, offsetValues);
+
+  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
+          linearizedIndices};
 }
 
 LinearizedMemRefInfo

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 66362d3ca70fb6..1d6f8a991d9b5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -22,8 +23,10 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <optional>
 
 using namespace mlir;
 
@@ -33,17 +36,22 @@ using namespace mlir;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
-/// in the scale range. E.g., if `scale` equals to 2, the following mask:
+/// in the scale range. E.g., if `scale` equals 2 and `intraDataOffset`
+/// equals 2, the following mask:
 ///
 ///   %mask = [1, 1, 1, 0, 0, 0]
 ///
-/// will return the following new compressed mask:
+/// will first be padded at the front with `intraDataOffset` zeros:
+///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
 ///
-///   %mask = [1, 1, 0]
+/// then it will return the following new compressed mask:
+///
+///   %mask = [0, 1, 1, 0]
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
-                                                  int origElements, int scale) {
-  auto numElements = (origElements + scale - 1) / scale;
+                                                  int origElements, int scale,
+                                                  int intraDataOffset = 0) {
+  auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
@@ -67,6 +75,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
+    // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
+    if (intraDataOffset != 0)
+      return failure();
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
@@ -86,11 +97,27 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
     int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t maskIndex = (origIndex + scale - 1) / scale;
+    int64_t startIndex = intraDataOffset / scale;
+    int64_t maskIndex = llvm::divideCeil(intraDataOffset + origIndex, scale);
+
+    // TODO: we only want the mask in the range [startIndex, maskIndex) to be
+    // true; the rest are false.
+    if (intraDataOffset != 0 && maskDimSizes.size() > 1)
+      return failure();
+
     SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
     newMaskDimSizes.push_back(maskIndex);
-    newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                      newMaskDimSizes);
+
+    if (intraDataOffset == 0) {
+      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                        newMaskDimSizes);
+    } else {
+      SmallVector<bool> newMaskValues;
+      for (int64_t i = 0; i < numElements; ++i)
+        newMaskValues.push_back(i >= startIndex && i < maskIndex);
+      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
+    }
   }
 
   while (!extractOps.empty()) {
@@ -102,6 +129,26 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
+static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
+                                  VectorType extractType, Value vector,
+                                  int64_t frontOffset, int64_t subvecSize) {
+  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
+  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
+  auto strides = rewriter.getI64ArrayAttr({1});
+  return rewriter
+      .create<vector::ExtractStridedSliceOp>(loc, extractType, vector, offsets,
+                                             sizes, strides)
+      ->getResult(0);
+}
+
+static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
+                                 Value src, Value dest, int64_t offset) {
+  auto offsets = rewriter.getI64ArrayAttr({offset});
+  auto strides = rewriter.getI64ArrayAttr({1});
+  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
+                                                       dest, offsets, strides);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -201,7 +248,8 @@ struct ConvertVectorMaskedStore final
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndicesOfr;
-    std::tie(std::ignore, linearizedIndicesOfr) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -214,19 +262,19 @@ struct ConvertVectorMaskedStore final
     // Load the whole data and use arith.select to handle the corner cases.
     // E.g., given these input values:
     //
-    //   %mask = [1, 1, 1, 0, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
-    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
+    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
     //
     // we'll have
     //
-    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
     //
-    //    %new_mask = [1, 1, 0]
-    //    %maskedload = [0x12, 0x34, 0x0]
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
-    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
-    //    %packed_data = [0x78, 0x94, 0x00]
+    //    %new_mask = [1, 1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x56, 0x00]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
+    //    %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
+    //    %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
     //
     // Using the new mask to store %packed_data results in expected output.
     FailureOr<Operation *> newMask =
@@ -243,8 +291,9 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    Value valueToStore = rewriter.create<vector::BitCastOp>(
-        loc, op.getValueToStore().getType(), newLoad);
+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    Value valueToStore =
+        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
         loc, op.getMask(), op.getValueToStore(), valueToStore);
     valueToStore =
@@ -294,19 +343,31 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
     // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
     //
-    // TODO: Currently, only the even number of elements loading is supported.
-    // To deal with the odd number of elements, one has to extract the
-    // subvector at the proper offset after bit-casting.
+    // There are cases where the number of elements to load is not byte-aligned,
+    // for example:
+    //
+    // %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    //
+    // we will have to load extra bytes and extract the exact slice in between.
+    //
+    // %1 = vector.load %0[%c1] : memref<3xi8>, vector<2xi8>
+    // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
+    // %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides
+    // = [1]}
+    //        : vector<8xi2> to vector<3xi2>
+    //
+    // TODO: Currently the extract_strided_slice's attributes must be known at
+    // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -314,15 +375,31 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    std::optional<int64_t> foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra vector offset
+      return failure();
+    }
+
+    auto numElements =
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
     auto newLoad = rewriter.create<vector::LoadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+    Value result = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
 
-    rewriter.replaceOp(op, bitCast->getResult(0));
+    if (isUnalignedEmulation) {
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedIntraVectorOffset, origElements);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -396,13 +473,13 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -410,29 +487,68 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    std::optional<int64_t> foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra vector offset
+      return failure();
+    }
+
     FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
+                            *foldedIntraVectorOffset);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
+    auto numElements =
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+    auto loadType = VectorType::get(numElements, newElementType);
+    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+
+    Value passthru = op.getPassThru();
+    if (isUnalignedEmulation) {
+      // create a zero-filled vector of the new type
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+      passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
+                                     *foldedIntraVectorOffset);
+    }
     auto newPassThru =
-        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
 
     // Generating the new masked load.
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
-        loc, newType, adaptor.getBase(),
+        loc, loadType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newMask.value()->getResult(0), newPassThru);
 
     // Setting the part that originally was not effectively loaded from memory
     // to pass through.
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
-    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
-                                                   op.getPassThru());
-    rewriter.replaceOp(op, select->getResult(0));
+        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
+
+    Value mask = op.getMask();
+    if (isUnalignedEmulation) {
+      auto newSelectMaskType =
+          VectorType::get(numElements * scale, rewriter.getI1Type());
+      // TODO: can fold if op's mask is constant
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+      mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
+                                 *foldedIntraVectorOffset);
+    }
+
+    Value result =
+        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
+
+    if (isUnalignedEmulation) {
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedIntraVectorOffset, origElements);
+    }
+    rewriter.replaceOp(op, result);
 
     return success();
   }
@@ -464,8 +580,8 @@ struct ConvertVectorTransferRead final
     int scale = dstBits / srcBits;
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -474,7 +590,8 @@ struct ConvertVectorTransferRead final
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -482,18 +599,34 @@ struct ConvertVectorTransferRead final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
-    auto newReadType = VectorType::get(numElements, newElementType);
+    std::optional<int64_t> foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic intra-vector offset
+      return failure();
+    }
+
+    auto numElements =
+        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, newReadType, adaptor.getSource(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
-    auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+
+    Value result = bitCast->getResult(0);
+    if (isUnalignedEmulation) {
+      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                    *foldedIntraVectorOffset, origElements);
+    }
+    rewriter.replaceOp(op, result);
 
-    rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
new file mode 100644
index 00000000000000..7ecbad7968225d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+//-----
+
+func.func @vector_transfer_read_i2() -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 1 : index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
+
+//-----
+
+func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+    %0 = memref.alloc() : memref<3x5xi2>
+    %cst = arith.constant dense<0> : vector<3x5xi2>
+    %mask = vector.constant_mask [3] : vector<5xi1>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
+      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+    return %2 : vector<3x5xi2>
+}
+
+// CHECK: func @vector_cst_maskedload_i2
+// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
+// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C2]]], %[[NEWMASK:.+]], %[[BITCAST1]]
+// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2> 


        


More information about the Mlir-commits mailing list