[Mlir-commits] [mlir] 0f8bab8 - [mlir] Revamp implementation of sub-byte load/store emulation.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Aug 17 13:28:06 PDT 2023


Author: Mahesh Ravishankar
Date: 2023-08-17T20:27:53Z
New Revision: 0f8bab8d590ed0cb5402f72009b32cbad115f013

URL: https://github.com/llvm/llvm-project/commit/0f8bab8d590ed0cb5402f72009b32cbad115f013
DIFF: https://github.com/llvm/llvm-project/commit/0f8bab8d590ed0cb5402f72009b32cbad115f013.diff

LOG: [mlir] Revamp implementation of sub-byte load/store emulation.

When handling sub-byte emulation, the sizes of the converted `memref`s
also need to be updated (this was not done in the current
implementation). This adds the additional complexity of having to
linearize the `memref`s as well. Consider a `memref<3x3xi4>` where the
`i4` elements are packed. This has an overall size of 5 bytes (9 x 4
bits = 36 bits, rounded up to a whole number of bytes). This can only
be represented by a `memref<5xi8>`. A `memref<3x2xi8>` would imply an
implicit padding of 4 bits at the end of each row. So, incorporate
linearization into the sub-byte load-store emulation.

This patch also updates some of the utility functions to make better
use of statically available information using `OpFoldResult` and
`affine::makeComposedFoldedAffineApply`.

Reviewed By: hanchung, yzhang93

Differential Revision: https://reviews.llvm.org/D158125

Added: 
    mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
    mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
    mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Removed: 
    mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
    mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index e1504f030defa6..9ff009b8c27bf6 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -28,36 +28,37 @@ namespace memref {
 /// contiguous chunk of memory.
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
-/// Returns the flattened 1-D memref and linearized offset for narrow type
-/// emulation.
-///
-/// The emulation only works on 1D memref types. To make this work on N-D
-/// memref, we need to linearize the offset.
-///
-/// For example, to emulate i4 to i8, the following op:
-///
-/// %0 = memref.load %arg0[%v0, %v1] :
-///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
-///
-/// can be replaced with
-///
-/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
-///
-/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
-/// %linearized_size = %size0 * %size1
-/// %scaled_linear_offset = %linearized_offset / 8 * 4
-/// %scaled_base_offset = %offset / 8 * 4
-///
-/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
-///                      sizes = [%linearized_size], strides = [%stride#1]
-///
-/// %new_load = memref.load %linearized[%scaled_linear_offset] :
-///                         memref<?xi8, strided<[?], offset: ?>>
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
-                            int dstBits, SmallVector<Value> indices,
-                            memref::ExtractStridedMetadataOp stridedMetadata,
-                            OpBuilder &builder);
+/// For a `memref` with `offset`, `sizes` and `strides`, returns the
+/// offset and size to use for the linearized `memref`.
+/// - If the linearization is done for emulating load/stores of
+///   element type with bitwidth `srcBits` using element type with
+///   bitwidth `dstBits`, the linearized offset and size are
+///   scaled down by `dstBits`/`srcBits`.
+/// - If `indices` is provided, it represents the position in the
+///   original `memref` being accessed. The method then returns the
+///   index to use in the linearized `memref`. The linearized index
+///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided
+///   0, is returned for the linearized index.
+struct LinearizedMemRefInfo {
+  OpFoldResult linearizedOffset;
+  OpFoldResult linearizedSize;
+};
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+    OpBuilder &builder, Location loc, int srcBits, int dstBits,
+    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {});
+
+/// For a `memref` with `offset` and `sizes`, returns the
+/// offset and size to use for the linearized `memref`, assuming that
+/// the strides are computed from a row-major ordering of the sizes;
+/// - If the linearization is done for emulating load/stores of
+///   element type with bitwidth `srcBits` using element type with
+///   bitwidth `dstBits`, the linearized offset and size are
+///   scaled down by `dstBits`/`srcBits`.
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+                                 int dstBits, OpFoldResult offset,
+                                 ArrayRef<OpFoldResult> sizes);
 
 } // namespace memref
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index fa5c90c24551da..2a524ceb9db887 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -35,18 +35,18 @@ using namespace mlir;
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
 /// element has 4 bits.
-static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
-                                  int targetBits, OpBuilder &builder) {
+static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+                                  int sourceBits, int targetBits,
+                                  OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
-  IntegerType targetType = builder.getIntegerType(targetBits);
-  IntegerAttr idxAttr =
-      builder.getIntegerAttr(targetType, targetBits / sourceBits);
-  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
-  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
-  auto srcBitsValue =
-      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
-  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
-  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int scaleFactor = targetBits / sourceBits;
+  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+  IntegerType dstType = builder.getIntegerType(targetBits);
+  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
 namespace {
@@ -61,15 +61,43 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
   LogicalResult
   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getType());
-    if (!newTy) {
+    auto currentType = op.getMemref().getType().cast<MemRefType>();
+    auto newResultType =
+        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
           llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }
 
+    // Special case zero-rank memrefs.
+    if (currentType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<memref::AllocOp>(
+          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
+          adaptor.getAlignmentAttr());
+      return success();
+    }
+
+    Location loc = op.getLoc();
+    OpFoldResult zero = rewriter.getIndexAttr(0);
+    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
+
+    // Get linearized type.
+    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
+    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    memref::LinearizedMemRefInfo linearizedMemRefInfo =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
+    SmallVector<Value> dynamicLinearizedSize;
+    if (!newResultType.hasStaticShape()) {
+      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, loc, linearizedMemRefInfo.linearizedSize));
+    }
+
     rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
         adaptor.getAlignmentAttr());
     return success();
   }
@@ -109,62 +137,57 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
-    if (!newTy) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
-                                      op.getMemRefType()));
-    }
-
-    if (op.getMemRefType() == newTy)
-      return failure();
-
-    auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
-    unsigned sourceRank = sourceType.getRank();
-    SmallVector<Value> indices = adaptor.getIndices();
-    assert(indices.size() == sourceRank);
-
-    auto srcElementType = sourceType.getElementType();
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    auto convertedElementType = convertedType.getElementType();
     auto oldElementType = op.getMemRefType().getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
     if (dstBits % srcBits != 0) {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
 
-    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, adaptor.getMemref());
-
-    Value newLoad, lastIdx;
-    if (sourceRank == 0) {
-      newLoad = rewriter.create<memref::LoadOp>(
-          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
-
-      lastIdx = stridedMetadata.getOffset();
+    Location loc = op.getLoc();
+    // Special case 0-rank memref loads.
+    Value bitsLoad;
+    if (convertedType.getRank() == 0) {
+      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
+                                                 ValueRange{});
     } else {
-      auto [reinterpret, linearizedOffset] =
-          memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
-                                              adaptor.getIndices(),
-                                              stridedMetadata, rewriter);
-
-      newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
-                                                reinterpret, linearizedOffset);
-
-      lastIdx = adaptor.getIndices().back();
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(adaptor.getIndices());
+
+      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+          loc, op.getMemRef());
+
+      // Linearize the indices of the original load instruction. Do not account
+      // for the scaling yet. This will be accounted for later.
+      OpFoldResult linearizedIndices;
+      std::tie(std::ignore, linearizedIndices) =
+          memref::getLinearizedMemRefOffsetAndSize(
+              rewriter, loc, srcBits, srcBits,
+              stridedMetadata.getConstifiedMixedOffset(),
+              stridedMetadata.getConstifiedMixedSizes(),
+              stridedMetadata.getConstifiedMixedStrides(), indices);
+
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      int64_t scaler = dstBits / srcBits;
+      OpFoldResult scaledLinearizedIndices =
+          affine::makeComposedFoldedAffineApply(
+              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      Value newLoad = rewriter.create<memref::LoadOp>(
+          loc, adaptor.getMemref(),
+          getValueOrCreateConstantIndexOp(rewriter, loc,
+                                          scaledLinearizedIndices));
+
+      // Get the offset and shift the bits to the rightmost.
+      // Note, currently only the big-endian is supported.
+      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
+                                                  srcBits, dstBits, rewriter);
+      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
     }
 
-    // Get the offset and shift the bits to the rightmost.
-    // Note, currently only the big-endian is supported.
-    auto castLastIdx =
-        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
-
-    Value BitwidthOffset =
-        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
-    auto bitsLoad =
-        rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
-
     // Get the corresponding bits. If the arith computation bitwidth equals
     // to the emulated bitwidth, we apply a mask to extract the low bits.
     // It is not clear if this case actually happens in practice, but we keep
@@ -172,10 +195,10 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     // is different from the emulated bitwidth we truncate the result.
     Operation *result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == srcElementType) {
+    if (resultTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
-          loc, srcElementType,
-          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+          loc, convertedElementType,
+          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
@@ -200,6 +223,25 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   patterns
       .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
           typeConverter, patterns.getContext());
+  memref::populateResolveExtractStridedMetadataPatterns(patterns);
+}
+
+static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
+                                               int dstBits) {
+  if (ty.getRank() == 0)
+    return {};
+
+  int64_t linearizedShape = 1;
+  for (auto shape : ty.getShape()) {
+    if (shape == ShapedType::kDynamic)
+      return {ShapedType::kDynamic};
+    linearizedShape *= shape;
+  }
+  int scale = dstBits / srcBits;
+  // Scale the size to the ceilDiv(linearizedShape, scale)
+  // to accomodate all the values.
+  linearizedShape = (linearizedShape + scale - 1) / scale;
+  return {linearizedShape};
 }
 
 void memref::populateMemRefNarrowTypeEmulationConversions(
@@ -215,11 +257,26 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         if (width >= loadStoreWidth)
           return ty;
 
+        // Currently only handle innermost stride being 1, checking
+        SmallVector<int64_t> strides;
+        int64_t offset;
+        if (failed(getStridesAndOffset(ty, strides, offset)))
+          return std::nullopt;
+        if (!strides.empty() && strides.back() != 1)
+          return std::nullopt;
+
         auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                           intTy.getSignedness());
         if (!newElemTy)
           return std::nullopt;
 
-        return ty.cloneWith(std::nullopt, newElemTy);
+        StridedLayoutAttr layoutAttr;
+        if (offset != 0) {
+          layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                              ArrayRef<int64_t>{1});
+        }
+
+        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
+                               newElemTy, layoutAttr, ty.getMemorySpace());
       });
 }

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index ff2c4107ee46dc..672ef3eb4cd50f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -687,13 +687,17 @@ struct ExtractStridedMetadataOpAllocFolder
 
     auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
     int64_t offset = 0;
-    if (allocLikeOp.getType() == baseBufferType)
-      results.push_back(allocLikeOp);
-    else
-      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
-          loc, baseBufferType, allocLikeOp, offset,
-          /*sizes=*/ArrayRef<int64_t>(),
-          /*strides=*/ArrayRef<int64_t>()));
+    if (op.getBaseBuffer().use_empty()) {
+      results.push_back(nullptr);
+    } else {
+      if (allocLikeOp.getType() == baseBufferType)
+        results.push_back(allocLikeOp);
+      else
+        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+            loc, baseBufferType, allocLikeOp, offset,
+            /*sizes=*/ArrayRef<int64_t>(),
+            /*strides=*/ArrayRef<int64_t>()));
+    }
 
     // Offset.
     results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

diff  --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index f6e3f97f455bee..e640248af6e499 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -46,79 +46,78 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
   return curDim < 0;
 }
 
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
-                            int dstBits, SmallVector<Value> indices,
-                            memref::ExtractStridedMetadataOp stridedMetadata,
-                            OpBuilder &builder) {
-  auto srcElementType = sourceType.getElementType();
-  unsigned sourceRank = indices.size();
-
-  Value baseBuffer = stridedMetadata.getBaseBuffer();
-  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
-  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
-  Value baseOffset = stridedMetadata.getOffset();
-  assert(indices.size() == baseStrides.size());
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+    OpBuilder &builder, Location loc, int srcBits, int dstBits,
+    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
+  unsigned sourceRank = sizes.size();
+  assert(sizes.size() == strides.size() &&
+         "expected as many sizes as strides for a memref");
+  SmallVector indicesVec = llvm::to_vector(indices);
+  if (indices.empty())
+    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
+  assert(indicesVec.size() == strides.size() &&
+         "expected as many indices as rank of memref");
 
   // Create the affine symbols and values for linearization.
-  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  SmallVector<AffineExpr> symbols(2 * sourceRank);
   bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
-  symbols[0] = builder.getAffineSymbolExpr(0);
-  AffineExpr addMulMap = symbols.front();
-  AffineExpr mulMap = symbols.front();
+  AffineExpr addMulMap = builder.getAffineConstantExpr(0);
+  AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
-  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
-  offsetValues[0] = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
-  sizeValues[0] = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
+  SmallVector<OpFoldResult> sizeValues(sourceRank);
 
   for (unsigned i = 0; i < sourceRank; ++i) {
-    unsigned offsetIdx = 2 * i + 1;
+    unsigned offsetIdx = 2 * i;
     addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
-    offsetValues[offsetIdx] = indices[i];
-    offsetValues[offsetIdx + 1] = baseStrides[i];
+    offsetValues[offsetIdx] = indicesVec[i];
+    offsetValues[offsetIdx + 1] = strides[i];
 
-    unsigned sizeIdx = i + 1;
-    mulMap = mulMap * symbols[sizeIdx];
-    sizeValues[sizeIdx] = baseSizes[i];
+    mulMap = mulMap * symbols[i];
   }
 
-  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
-  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
-  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
-  offsetValues.back() = scaler;
+  // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
+  // srcBits).
+  int64_t scaler = dstBits / srcBits;
+  addMulMap = addMulMap.floorDiv(scaler);
+  mulMap = mulMap.floorDiv(scaler);
 
-  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap, offsetValues);
   OpFoldResult linearizedSize =
-      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
   // Adjust baseOffset by the scale factor (dstBits / srcBits).
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
-
-  // Flatten n-D MemRef to 1-D MemRef.
-  std::optional<int64_t> stride =
-      getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
-  auto layoutAttr =
-      StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
-                             {stride ? stride.value() : ShapedType::kDynamic});
-  int64_t staticShape = sourceType.hasStaticShape()
-                            ? sourceType.getNumElements()
-                            : ShapedType::kDynamic;
-  auto flattenMemrefType = MemRefType::get(
-      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
-
-  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
-      loc, flattenMemrefType, baseBuffer,
-      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
-      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
-      baseStrides.back());
-
-  return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
-                                         builder, loc, linearizedOffset));
+      builder, loc, s0.floorDiv(scaler), {offset});
+
+  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+}
+
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+                                 int dstBits, OpFoldResult offset,
+                                 ArrayRef<OpFoldResult> sizes) {
+  SmallVector<OpFoldResult> strides(sizes.size());
+  if (sizes.size() > 0) {
+    strides.back() = builder.getIndexAttr(1);
+    AffineExpr s0, s1;
+    bindSymbols(builder.getContext(), s0, s1);
+    for (int index = sizes.size() - 1; index > 0; --index) {
+      strides[index - 1] = affine::makeComposedFoldedAffineApply(
+          builder, loc, s0 * s1,
+          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
+    }
+  }
+
+  LinearizedMemRefInfo linearizedMemRefInfo;
+  std::tie(linearizedMemRefInfo, std::ignore) =
+      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
+                                       sizes, strides);
+  return linearizedMemRefInfo;
 }
 
 } // namespace memref

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7e747aaa450ab5..01ee43354a711e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -69,19 +69,24 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (origElements % scale != 0)
       return failure();
 
-    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, adaptor.getBase());
-
-    auto [reinterpret, linearizedOffset] = memref::getLinearizeMemRefAndOffset(
-        loc, sourceType, srcBits, dstBits, adaptor.getIndices(),
-        stridedMetadata, rewriter);
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
 
     auto srcElementType = sourceType.getElementType();
     auto numElements =
         static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, srcElementType), reinterpret,
-        linearizedOffset);
+        loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
     numElements *= scale;
     auto castType = VectorType::get(numElements, oldElementType);

diff  --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
deleted file mode 100644
index f2db905370a588..00000000000000
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
+++ /dev/null
@@ -1,107 +0,0 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
-
-// Expect no conversions, i32 is supported.
-// CHECK-LABEL: func @memref_i32
-// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
-// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
-// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
-// CHECK-NEXT:    return
-func.func @memref_i32() {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : i32
-    %m = memref.alloc() : memref<4xi32, 1>
-    %v = memref.load %m[%c0] : memref<4xi32, 1>
-    memref.store %c1, %m[%c0] : memref<4xi32, 1>
-    return
-}
-
-// -----
-
-// Expect no conversions, f32 is not an integer type.
-// CHECK-LABEL: func @memref_f32
-// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
-// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
-// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
-// CHECK-NEXT:    return
-func.func @memref_f32() {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1.0 : f32
-    %m = memref.alloc() : memref<4xf32, 1>
-    %v = memref.load %m[%c0] : memref<4xf32, 1>
-    memref.store %c1, %m[%c0] : memref<4xf32, 1>
-    return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4_zero_rank
-// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<i8>
-// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
-// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
-// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
-// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
-// CHECK-NEXT:    return
-func.func @memref_load_i4_zero_rank() {
-    %0 = memref.alloc() : memref<i4>
-    %1 = memref.load %0[] : memref<i4>
-    return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4
-// CHECK-SAME:    (%[[ARG:.*]]: index)
-// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<4xi8>
-// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
-// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
-// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
-// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
-// CHECK-NEXT:    return
-func.func @memref_load_i4(%arg0: index) {
-    %0 = memref.alloc() : memref<4xi4>
-    %1 = memref.load %0[%arg0] : memref<4xi4>
-    return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4_rank2
-// CHECK-SAME:    (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT:    memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
-// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
-// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK-NEXT:    %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
-// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
-// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
-// CHECK-NEXT:    return
-func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
-    memref.assume_alignment %0, 64 : memref<4x128xi4>
-    %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
-    return
-}

diff  --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
deleted file mode 100644
index 19625e2f2beb17..00000000000000
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
+++ /dev/null
@@ -1,72 +0,0 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
-
-// Expect no conversions.
-// CHECK-LABEL: func @memref_i8
-// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
-// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
-// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
-// CHECK-NEXT:    return
-func.func @memref_i8() {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : i8
-    %m = memref.alloc() : memref<4xi8, 1>
-    %v = memref.load %m[%c0] : memref<4xi8, 1>
-    memref.store %c1, %m[%c0] : memref<4xi8, 1>
-    return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4
-// CHECK-SAME:    (%[[ARG:.*]]: index)
-// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<4xi8>
-// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
-// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
-// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
-// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT:    %[[MASK:.*]] = arith.constant 15 : i8
-// CHECK-NEXT:    %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
-// CHECK-NEXT:    return
-func.func @memref_load_i4(%arg0: index) {
-    %0 = memref.alloc() : memref<4xi4>
-    %1 = memref.load %0[%arg0] : memref<4xi4>
-    return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4_rank2
-// CHECK-SAME:    (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT:    memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
-// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
-// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK-NEXT:    %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
-// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
-// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT:    %[[MASK:.*]] = arith.constant 15 : i8
-// CHECK-NEXT:    %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
-// CHECK-NEXT:    return
-func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
-    memref.assume_alignment %0, 64 : memref<4x128xi4>
-    %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
-    return
-}

diff  --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
new file mode 100644
index 00000000000000..c0050d8c510d53
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+
+// Expect no conversions.
+func.func @memref_i8() -> i8 {
+    %c3 = arith.constant 3 : index
+    %m = memref.alloc() : memref<4xi8, 1>
+    %v = memref.load %m[%c3] : memref<4xi8, 1>
+    return %v : i8
+}
+// CHECK-LABEL: func @memref_i8()
+//       CHECK:   %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
+//  CHECK-NEXT:   %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
+//  CHECK-NEXT:   return %[[V]]
+
+// CHECK32-LABEL: func @memref_i8()
+//       CHECK32:   %[[M:.+]] = memref.alloc() : memref<1xi32, 1>
+//       CHECK32:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK32:   %[[V:.+]] = memref.load %[[M]][%[[C0]]] : memref<1xi32, 1>
+//       CHECK32:   %[[C24:.+]] = arith.constant 24 : index
+//       CHECK32:   %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
+//       CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
+//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
+//  CHECK32-NEXT:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_load_i4(%arg0: index) -> i4 {
+    %0 = memref.alloc() : memref<5xi4>
+    %1 = memref.load %0[%arg0] : memref<5xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_load_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_load_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
+    %0 = memref.alloc() : memref<3x125xi4>
+    memref.assume_alignment %0, 64 : memref<3x125xi4>
+    %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)
+//      CHECK: func @memref_load_i4_rank2(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+//      CHECK:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)
+//      CHECK32: func @memref_load_i4_rank2(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+//      CHECK32:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index) -> i4 {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  %1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
+  return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+//      CHECK: func @memref_load_i4_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+//      CHECK32: func @memref_load_i4_dynamic(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @rank_zero_memref() -> i4 {
+  %0 = memref.alloc() : memref<i4>
+  %1 = memref.load %0[] : memref<i4>
+  return %1 : i4
+}
+// CHECK-LABEL: func @rank_zero_memref()
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+//       CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i8>
+//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
+//       CHECK:   return %[[TRUNC]]
+
+// CHECK32-LABEL: func @rank_zero_memref()
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+//       CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
+//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
+//       CHECK32:   return %[[TRUNC]]

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 71c133778087d6..e3c6b098b70ba4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,36 +1,81 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
 
+func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
+    return %1 : vector<4xi8>
+}
 // Expect no conversions, i8 is supported.
-// CHECK-LABEL: func @vector_load_i8
-// CHECK-SAME:  (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
-// CHECK-NEXT:  [[L:%.+]] = vector.load %[[ARG]][%[[IDX0]], %[[IDX1]]] : memref<3x4xi8>, vector<4xi8>
-// CHECK-NEXT:  return
-func.func @vector_load_i8(%arg0: memref<3x4xi8>, %arg1: index, %arg2: index) {
-    %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
-    return
+//      CHECK: func @vector_load_i8(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
+// CHECK-NEXT:   [[L:%.+]] = vector.load %[[ALLOC]][%[[ARG0]], %[[ARG1]]] : memref<3x4xi8>, vector<4xi8>
+// CHECK-NEXT:   return
+
+//      CHECK32: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
+//      CHECK32: func @vector_load_i8(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VECLOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VECLOAD]] : vector<1xi32> to vector<4xi8>
+//      CHECK32:   return %[[VEC_I4]]
+
+// -----
+
+func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x8xi4>, vector<8xi4>
+    %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
+    return %2 : vector<3x8xi4>
 }
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_load_i4
+// CHECK-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_load_i4
+// CHECK32-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
 
 // -----
 
-// CHECK-LABEL: func @vector_load_i4
-// CHECK-SAME:  (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
-// CHECK-NEXT:  %[[CST:.*]] = arith.constant dense<0> : vector<3x4xi4>
-// CHECK-NEXT:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x4xi8> -> memref<i8>, index, index, index, index, index
-// CHECK-NEXT:  %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[IDX0]], %[[STRIDES]]#0, %[[IDX1]], %[[STRIDES]]#1]
-// CHECK-NEXT:  %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
-// CHECK-NEXT:  %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<12xi8, strided<[1], offset: ?>>
-// CHECK-NEXT:  %[[LOAD:.*]] = vector.load %[[CAST]][%[[INDEX]]] : memref<12xi8, strided<[1], offset: ?>>, vector<2xi8>
-// CHECK-NEXT:  %[[BITCAST:.*]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<4xi4>
-// CHECK-NEXT:  %[[INSERT:.*]] = vector.insert %[[BITCAST]], %[[CST]] [0] : vector<4xi4> into vector<3x4xi4>
-// CHECK-NEXT:  return
-func.func @vector_load_i4(%arg0: memref<3x4xi4>, %arg1: index, %arg2: index) {
-    %cst = arith.constant dense<0> : vector<3x4xi4>
-    %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi4>, vector<4xi4>
-    %1 = vector.insert %0, %cst [0] : vector<4xi4> into vector<3x4xi4>
-    return
+func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
+  return %1 : vector<8xi4>
 }
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//      CHECK: func.func @vector_load_i4_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
+//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+//      CHECK32: func.func @vector_load_i4_dynamic(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
+//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>

diff  --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 64646b01b9f515..eeb26d1876c1db 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -89,8 +89,7 @@ struct TestEmulateNarrowTypePass
     target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
     target.addDynamicallyLegalDialect<
         arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
-        affine::AffineDialect>(
-        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+        affine::AffineDialect>(opLegalCallback);
 
     RewritePatternSet patterns(ctx);
 


        


More information about the Mlir-commits mailing list