[Mlir-commits] [mlir] 5a1cdcb - [mlir] Narrow bitwidth emulation for MemRef load

Hanhan Wang llvmlistbot at llvm.org
Mon Jun 26 14:21:44 PDT 2023


Author: yzhang93
Date: 2023-06-26T14:18:30-07:00
New Revision: 5a1cdcbd8698cd263696b38e2672fccac9ec793c

URL: https://github.com/llvm/llvm-project/commit/5a1cdcbd8698cd263696b38e2672fccac9ec793c
DIFF: https://github.com/llvm/llvm-project/commit/5a1cdcbd8698cd263696b38e2672fccac9ec793c.diff

LOG: [mlir] Narrow bitwidth emulation for MemRef load

This patch adds support for narrow-bitwidth storage emulation. The goal is to support sub-byte type
codegen for the LLVM CPU backend. Specifically, a type converter is added to convert memrefs of a
narrow bitwidth (e.g., i4) into memrefs of a supported wider bitwidth (e.g., i8). The patch also adds
the conversion pattern for i4 memref.load; the memref.store pattern will be added in a separate patch.

Reviewed By: hanchung, mravishankar

Differential Revision: https://reviews.llvm.org/D151519
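
For context, a minimal sketch of how the new entry points fit together inside a
conversion pass (modeled on the TestEmulateNarrowType.cpp test pass added below;
the surrounding helper, the fixed bitwidth of 8, and the error message are
illustrative only, not part of this commit):

    // Illustrative helper (sketch): run narrow type emulation with an 8-bit
    // load/store bitwidth, mirroring the test pass added in this patch.
    void runNarrowTypeEmulation(Operation *rootOp, MLIRContext *ctx) {
      // Widen i4 (and other sub-byte) memrefs to 8-bit storage.
      arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
      memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

      // Ops are legal once all of their types are legal for the converter.
      ConversionTarget target(*ctx);
      target.addDynamicallyLegalOp<func::FuncOp>([&](Operation *op) {
        return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
      });
      target.addDynamicallyLegalDialect<arith::ArithDialect,
                                        memref::MemRefDialect>(
          [&](Operation *op) { return typeConverter.isLegal(op); });

      // Function-boundary patterns plus the new memref conversion patterns.
      RewritePatternSet patterns(ctx);
      arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
      memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

      if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
        rootOp->emitError("narrow type emulation failed");
    }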

Added: 
    mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
    mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
    mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
    mlir/test/Dialect/Arith/emulate-narrow-type.mlir
    mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
    mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
    mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
    mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/MemRef/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
new file mode 100644
index 00000000000000..528bb51b214a57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
@@ -0,0 +1,31 @@
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, we only
+/// handle power-of-two integer types and convert them to wider
+/// integer types that are at least 8 bits wide.
+class NarrowTypeEmulationConverter : public TypeConverter {
+public:
+  explicit NarrowTypeEmulationConverter(unsigned targetBitwidth);
+
+  unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }
+
+private:
+  unsigned loadStoreBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index c4010b7c0b57a1..de36cb48e6d024 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -22,6 +22,7 @@ namespace arith {
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 
 /// Create a pass to bufferize Arith ops.
 std::unique_ptr<Pass> createArithBufferizePass();
@@ -35,6 +36,12 @@ std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
 void populateArithWideIntEmulationPatterns(
     WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Adds patterns to emulate narrow Arith and Func ops over wider
+/// supported types. Users additionally need to register type conversions
+/// for the narrow types used in the computation domain.
+void populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 

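As the populateArithNarrowTypeEmulationPatterns comment above notes, conversions
for the computation domain are left to callers. A minimal sketch of such a
conversion, adapted from the test pass added in this patch (the compute
bitwidth of 8 is just an example):

    // Widen sub-byte integer scalars so arith ops compute on supported types.
    typeConverter.addConversion([](IntegerType ty) -> std::optional<Type> {
      if (ty.getWidth() >= 8)
        return ty;
      return IntegerType::get(ty.getContext(), 8);
    });
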
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 91ef1620fce643..0b1af47e1becec 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -25,6 +25,7 @@ class ValueRange;
 
 namespace arith {
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 } // namespace arith
 
 namespace memref {
@@ -73,6 +74,17 @@ void populateMemRefWideIntEmulationPatterns(
 void populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter);
 
+/// Appends patterns for emulating memref operations over narrow types with ops
+/// over wider types.
+void populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating memref operations over narrow types
+/// with ops over wider types.
+void populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered

diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 87d9bebfd2a7c2..b969389f223995 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExpandOps.cpp
   IntNarrowing.cpp
   IntRangeOptimizations.cpp

diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
new file mode 100644
index 00000000000000..e0e1385b6bc174
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,61 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
+    unsigned targetBitwidth)
+    : loadStoreBitwidth(targetBitwidth) {
+  assert(llvm::isPowerOf2_32(targetBitwidth) &&
+         "Only power-of-two integers are supported");
+
+  // Allow unknown types.
+  addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> std::optional<Type> {
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return std::nullopt;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return std::nullopt;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+void arith::populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}

diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index a16d8505b5c72c..10ca179a6c9e4c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
new file mode 100644
index 00000000000000..a876bc74801d1b
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,315 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// The emulation only works on 1D memref types.
+/// To make this work on N-D memref, we need to linearize the offset.
+///
+/// For example, to emulate i4 to i8, the following op:
+///
+/// %0 = memref.load %arg0[%v0, %v1] :
+///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+///
+/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+/// %linearized_size = %size0 * %size1
+/// %scaled_linear_offset = %linearized_offset / 8 * 4
+/// %scaled_base_offset = %offset / 8 * 4
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+///                      sizes = [%linearized_size], strides = [%stride#1]
+///
+/// %new_load = memref.load %linearized[%scaled_linear_offset] :
+///                         memref<?xi8, strided<[?], offset: ?>>
+
+static Value
+linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
+                    int dstBits, SmallVector<Value> indices,
+                    memref::ExtractStridedMetadataOp stridedMetadata,
+                    OpBuilder &builder) {
+  auto srcElementType = sourceType.getElementType();
+  unsigned sourceRank = indices.size();
+
+  Value baseBuffer = stridedMetadata.getBaseBuffer();
+  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+  Value baseOffset = stridedMetadata.getOffset();
+  assert(indices.size() == baseStrides.size());
+
+  // Create the affine symbols and values for linearization.
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  symbols[0] = builder.getAffineSymbolExpr(0);
+  AffineExpr addMulMap = symbols.front();
+  AffineExpr mulMap = symbols.front();
+
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+  offsetValues[0] = builder.getIndexAttr(0);
+  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+  sizeValues[0] = builder.getIndexAttr(1);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    unsigned offsetIdx = 2 * i + 1;
+    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+    offsetValues[offsetIdx] = indices[i];
+    offsetValues[offsetIdx + 1] = baseStrides[i];
+
+    unsigned sizeIdx = i + 1;
+    mulMap = mulMap * symbols[sizeIdx];
+    sizeValues[sizeIdx] = baseSizes[i];
+  }
+
+  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+  offsetValues.back() = scaler;
+
+  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedSize =
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+  // Adjust baseOffset by the scale factor (dstBits / srcBits).
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+  // Flatten n-D MemRef to 1-D MemRef.
+  auto layoutAttr = StridedLayoutAttr::get(
+      sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+  int64_t staticShape = sourceType.hasStaticShape()
+                            ? sourceType.getNumElements()
+                            : ShapedType::kDynamic;
+  auto flattenMemrefType = MemRefType::get(
+      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+      loc, flattenMemrefType, baseBuffer,
+      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+      baseStrides.back());
+
+  return builder.create<memref::LoadOp>(
+      loc, srcElementType, reinterpret.getResult(),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
+
+/// When data is loaded/stored with `targetBits` granularity but used with
+/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
+/// unit is treated as an array of elements of width `sourceBits`.
+/// Returns the bit offset of the value at position `srcIdx`. For example, if
+/// `sourceBits` is 4 and `targetBits` is 8, the x-th element is located at
+/// (x % 2) * 4, because one i8 holds two 4-bit elements.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr idxAttr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+  auto srcBitsValue =
+      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(
+        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAssumeAlignment
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAssumeAlignment final
+    : OpConversionPattern<memref::AssumeAlignmentOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemref().getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
+        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemRefType()));
+    }
+
+    if (op.getMemRefType() == newTy)
+      return failure();
+
+    auto loc = op.getLoc();
+    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Value> indices = adaptor.getIndices();
+    assert(indices.size() == sourceRank);
+
+    auto srcElementType = sourceType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, adaptor.getMemref());
+
+    Value newLoad, lastIdx;
+    if (sourceRank == 0) {
+      newLoad = rewriter.create<memref::LoadOp>(
+          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
+
+      lastIdx = stridedMetadata.getOffset();
+    } else {
+      newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
+                                    stridedMetadata, rewriter);
+
+      lastIdx = adaptor.getIndices().back();
+    }
+
+    // Compute the bit offset and shift the requested bits to the rightmost
+    // position. Note: currently only big-endian is supported.
+    auto castLastIdx =
+        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
+
+    Value BitwidthOffset =
+        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
+    auto bitsLoad =
+        rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
+
+    // Get the corresponding bits. If the arith computation bitwidth equals
+    // the emulated bitwidth, we apply a mask to extract the low bits. It is
+    // not clear whether this case actually happens in practice, but we keep
+    // the operations just in case. Otherwise, if the arith computation
+    // bitwidth differs from the emulated bitwidth, we truncate the result.
+    Operation *result;
+    auto resultTy = getTypeConverter()->convertType(oldElementType);
+    if (resultTy == srcElementType) {
+      auto mask = rewriter.create<arith::ConstantOp>(
+          loc, srcElementType,
+          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+
+      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+    } else {
+      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+    }
+
+    rewriter.replaceOp(op, result->getResult(0));
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+
+  // Populate `memref.*` conversion patterns.
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
+          typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType ty) -> std::optional<Type> {
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+        if (!intTy)
+          return ty;
+
+        unsigned width = intTy.getWidth();
+        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
+        if (width >= loadStoreWidth)
+          return ty;
+
+        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
+                                          intTy.getSignedness());
+        if (!newElemTy)
+          return std::nullopt;
+
+        return ty.cloneWith(std::nullopt, newElemTy);
+      });
+}
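
To make the generated bit manipulation easier to follow, here is a small
host-side C++ model of what ConvertMemRefLoad emits for an i4 load from i8
storage (illustrative only; the pass emits the equivalent affine/arith ops
shown above, and the function below is not part of this commit):

    #include <cstdint>

    // Extract the i4 element at linear position `idx` from i8 storage,
    // mirroring the shrsi + mask/trunc sequence produced by ConvertMemRefLoad.
    int8_t emulatedI4Load(const int8_t *bytes, unsigned idx) {
      constexpr unsigned srcBits = 4, dstBits = 8;
      constexpr unsigned scale = dstBits / srcBits; // two i4 elements per i8
      int8_t loaded = bytes[idx / scale];           // scaled, linearized offset
      unsigned bitOffset = (idx % scale) * srcBits; // getOffsetForBitwidth
      int8_t shifted = loaded >> bitOffset;         // arith.shrsi
      return shifted & ((1 << srcBits) - 1);        // arith.andi / arith.trunci
    }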

diff --git a/mlir/test/Dialect/Arith/emulate-narrow-type.mlir b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
new file mode 100644
index 00000000000000..7120882b0c07c0
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME:    ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT:    return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+    return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    return [[ARG]] : vector<2xi32>
+func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> {
+    return %a : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME:     ([[ARG:%.+]]: i8) -> i8
+// CHECK-NEXT:     return [[ARG]] : i8
+func.func @identity_scalar(%x : i4) -> i4 {
+    return %x : i4
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT:     return [[ARG]] : vector<4xi8>
+func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> {
+    return %x : vector<4xi4>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME:     ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8>
+// CHECK-NEXT:     return [[ARG]] : vector<3x4xi8>
+func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> {
+    return %x : vector<3x4xi4>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT:     [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT:     return [[RES]] : vector<4xi8>
+func.func @call(%a : vector<4xi4>) -> vector<4xi4> {
+    %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4>
+    return %res : vector<4xi4>
+}

diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
new file mode 100644
index 00000000000000..85d4cc18c0d38f
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT:    return
+func.func @memref_i32() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i32
+    %m = memref.alloc() : memref<4xi32, 1>
+    %v = memref.load %m[%c0] : memref<4xi32, 1>
+    memref.store %c1, %m[%c0] : memref<4xi32, 1>
+    return
+}
+
+// -----
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT:    return
+func.func @memref_f32() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1.0 : f32
+    %m = memref.alloc() : memref<4xf32, 1>
+    %v = memref.load %m[%c0] : memref<4xf32, 1>
+    memref.store %c1, %m[%c0] : memref<4xf32, 1>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_zero_rank
+// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<i8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT:    return
+func.func @memref_load_i4_zero_rank() {
+    %0 = memref.alloc() : memref<i4>
+    %1 = memref.load %0[] : memref<i4>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME:    (%[[ARG:.*]]: index)
+// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<4xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT:    return
+func.func @memref_load_i4(%arg0: index) {
+    %0 = memref.alloc() : memref<4xi4>
+    %1 = memref.load %0[%arg0] : memref<4xi4>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME:    (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:    memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT:    %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT:    return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+    memref.assume_alignment %0, 64 : memref<4x128xi4>
+    %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+    return
+}
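
For concreteness, the index arithmetic checked above works out as follows for
the static 4x128 case (strides [128, 1]); the numbers below are a worked
example, not additional test content:

    // For (%arg0, %arg1) = (1, 5):
    //   linearized i4 index : 1 * 128 + 5     = 133
    //   i8 load index       : 133 floordiv 2  = 66   (#MAP2)
    //   bit offset          : (5 mod 2) * 4   = 4
    // i.e. byte 66 of the 512xi8 view is loaded, shifted right by 4 bits,
    // and truncated to i4.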

diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
new file mode 100644
index 00000000000000..9d63b9d1acd08f
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions.
+// CHECK-LABEL: func @memref_i8
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT:    return
+func.func @memref_i8() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i8
+    %m = memref.alloc() : memref<4xi8, 1>
+    %v = memref.load %m[%c0] : memref<4xi8, 1>
+    memref.store %c1, %m[%c0] : memref<4xi8, 1>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME:    (%[[ARG:.*]]: index)
+// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<4xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT:    return
+func.func @memref_load_i4(%arg0: index) {
+    %0 = memref.alloc() : memref<4xi4>
+    %1 = memref.load %0[%arg0] : memref<4xi4>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME:    (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:    memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT:    %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT:    return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+    memref.assume_alignment %0, 64 : memref<4x128xi4>
+    %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+    return
+}

diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
index df3fdacd09358f..0498de3eb93178 100644
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestEmulateNarrowType.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
new file mode 100644
index 00000000000000..b1f23084449c2b
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -0,0 +1,118 @@
+//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestEmulateNarrowTypePass
+    : public PassWrapper<TestEmulateNarrowTypePass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
+
+  TestEmulateNarrowTypePass() = default;
+  TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                vector::VectorDialect, affine::AffineDialect>();
+  }
+  StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+  StringRef getDescription() const final {
+    return "Function pass to test Narrow Integer Emulation";
+  }
+
+  void runOnOperation() override {
+    if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
+        loadStoreEmulateBitwidth < 8) {
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth);
+
+    // Convert scalar type.
+    typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+      unsigned width = ty.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+    });
+
+    // Convert vector type.
+    typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+      auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+      if (!intTy)
+        return ty;
+
+      unsigned width = intTy.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return VectorType::get(
+          to_vector(ty.getShape()),
+          IntegerType::get(ty.getContext(), arithComputeBitwidth));
+    });
+
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+        affine::AffineDialect>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+
+    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  Option<unsigned> loadStoreEmulateBitwidth{
+      *this, "memref-load-bitwidth",
+      llvm::cl::desc("memref load/store emulation bit width"),
+      llvm::cl::init(8)};
+
+  Option<unsigned> arithComputeBitwidth{
+      *this, "arith-compute-bitwidth",
+      llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowTypePass() {
+  PassRegistration<TestEmulateNarrowTypePass>();
+}
+} // namespace mlir::test

diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d75b54e8bd45ac..5b956635b960a1 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,6 +87,7 @@ void registerTestDiagnosticsPass();
 void registerTestDialectConversionPasses();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
@@ -205,6 +206,7 @@ void registerTestPasses() {
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
+  mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();


        

