[Mlir-commits] [mlir] 5a1cdcb - [mlir] Narrow bitwidth emulation for MemRef load
Hanhan Wang
llvmlistbot at llvm.org
Mon Jun 26 14:21:44 PDT 2023
Author: yzhang93
Date: 2023-06-26T14:18:30-07:00
New Revision: 5a1cdcbd8698cd263696b38e2672fccac9ec793c
URL: https://github.com/llvm/llvm-project/commit/5a1cdcbd8698cd263696b38e2672fccac9ec793c
DIFF: https://github.com/llvm/llvm-project/commit/5a1cdcbd8698cd263696b38e2672fccac9ec793c.diff
LOG: [mlir] Narrow bitwidth emulation for MemRef load
This patch adds support for narrow bitwidth storage emulation. The goal is to support sub-byte type
codegen for LLVM CPU. Specifically, a type converter is added to convert memrefs with narrow bitwidth
element types (e.g., i4) into memrefs with a supported wider bitwidth (e.g., i8). The other focus of this
patch is to populate the conversion pattern for i4 memref.load; the memref.store pattern will be added in a separate patch.
Reviewed By: hanchung, mravishankar
Differential Revision: https://reviews.llvm.org/D151519
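For readers who want to drive the new entry points from their own pass, the sketch below shows how they are intended to fit together, modeled on the TestEmulateNarrowType.cpp pass added by this patch. It is only an illustrative sketch: the helper name emulateNarrowMemRefTypes, the fixed 8-bit load/store width, and the choice of which ops to mark dynamically legal are assumptions for the example, not APIs introduced here.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver: rewrite sub-byte (e.g. i4) memref loads into i8 ones.
static LogicalResult emulateNarrowMemRefTypes(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();

  // Type converter that widens sub-byte integer element types up to the
  // 8-bit load/store width; memref-specific conversions are appended to it.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  // Ops become legal once their types no longer need conversion.
  ConversionTarget target(*ctx);
  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
    return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
  });
  target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
  target.addDynamicallyLegalDialect<memref::MemRefDialect>(
      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });

  // func.* signature/call/return patterns plus the memref patterns
  // (alloc, assume_alignment, load) added by this patch.
  RewritePatternSet patterns(ctx);
  arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  return applyPartialConversion(rootOp, target, std::move(patterns));
}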
Added:
mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
mlir/test/Dialect/Arith/emulate-narrow-type.mlir
mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/MemRef/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
new file mode 100644
index 00000000000000..528bb51b214a57
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
@@ -0,0 +1,31 @@
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE -----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, we only
+/// handle power-of-two integer types and convert them to wider
+/// integers that are equal or larger than 8 bits.
+class NarrowTypeEmulationConverter : public TypeConverter {
+public:
+ explicit NarrowTypeEmulationConverter(unsigned targetBitwidth);
+
+ unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }
+
+private:
+ unsigned loadStoreBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index c4010b7c0b57a1..de36cb48e6d024 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -22,6 +22,7 @@ namespace arith {
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
/// Create a pass to bufferize Arith ops.
std::unique_ptr<Pass> createArithBufferizePass();
@@ -35,6 +36,12 @@ std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
void populateArithWideIntEmulationPatterns(
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
+/// Adds patterns to emulate narrow Arith and Function ops into wide
+/// supported types. Users need to add conversions about the computation
+/// domain of narrow types.
+void populateArithNarrowTypeEmulationPatterns(
+ NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ceil/floor division ops.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 91ef1620fce643..0b1af47e1becec 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -25,6 +25,7 @@ class ValueRange;
namespace arith {
class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
} // namespace arith
namespace memref {
@@ -73,6 +74,17 @@ void populateMemRefWideIntEmulationPatterns(
void populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter);
+/// Appends patterns for emulating memref operations over narrow types with ops
+/// over wider types.
+void populateMemRefNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating memref operations over narrow types
+/// with ops over wider types.
+void populateMemRefNarrowTypeEmulationConversions(
+ arith::NarrowTypeEmulationConverter &typeConverter);
+
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 87d9bebfd2a7c2..b969389f223995 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
EmulateWideInt.cpp
+ EmulateNarrowType.cpp
ExpandOps.cpp
IntNarrowing.cpp
IntRangeOptimizations.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
new file mode 100644
index 00000000000000..e0e1385b6bc174
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,61 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
+ unsigned targetBitwidth)
+ : loadStoreBitwidth(targetBitwidth) {
+ assert(llvm::isPowerOf2_32(targetBitwidth) &&
+ "Only power-of-two integers are supported");
+
+ // Allow unknown types.
+ addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+ // Function case.
+ addConversion([this](FunctionType ty) -> std::optional<Type> {
+ SmallVector<Type> inputs;
+ if (failed(convertTypes(ty.getInputs(), inputs)))
+ return std::nullopt;
+
+ SmallVector<Type> results;
+ if (failed(convertTypes(ty.getResults(), results)))
+ return std::nullopt;
+
+ return FunctionType::get(ty.getContext(), inputs, results);
+ });
+}
+
+void arith::populateArithNarrowTypeEmulationPatterns(
+ NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+ // Populate `func.*` conversion patterns.
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index a16d8505b5c72c..10ca179a6c9e4c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
ExpandOps.cpp
ExpandStridedMetadata.cpp
EmulateWideInt.cpp
+ EmulateNarrowType.cpp
ExtractAddressComputations.cpp
FoldMemRefAliasOps.cpp
IndependenceTransforms.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
new file mode 100644
index 00000000000000..a876bc74801d1b
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,315 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// The emulation only works on 1D memref types.
+/// To make this work on N-D memref, we need to linearize the offset.
+///
+/// For example, to emulate i4 to i8, the following op:
+///
+/// %0 = memref.load %arg0[%v0, %v1] :
+/// memref<?x?xi4, strided<[?, ?], offset: ?>>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+///
+/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+/// %linearized_size = %size0 * %size1
+/// %scaled_linear_offset = %linearized_offset / 8 * 4
+/// %scaled_base_offset = %offset / 8 * 4
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+/// sizes = [%linearized_size], strides = [%stride#1]
+///
+/// %new_load = memref.load %linearized[%scaled_linear_offset] :
+/// memref<?xi8, strided<[?], offset: ?>>
+
+static Value
+linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
+ int dstBits, SmallVector<Value> indices,
+ memref::ExtractStridedMetadataOp stridedMetadata,
+ OpBuilder &builder) {
+ auto srcElementType = sourceType.getElementType();
+ unsigned sourceRank = indices.size();
+
+ Value baseBuffer = stridedMetadata.getBaseBuffer();
+ SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+ SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+ Value baseOffset = stridedMetadata.getOffset();
+ assert(indices.size() == baseStrides.size());
+
+ // Create the affine symbols and values for linearization.
+ SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+ bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+ symbols[0] = builder.getAffineSymbolExpr(0);
+ AffineExpr addMulMap = symbols.front();
+ AffineExpr mulMap = symbols.front();
+
+ SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+ offsetValues[0] = builder.getIndexAttr(0);
+ SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+ sizeValues[0] = builder.getIndexAttr(1);
+
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ unsigned offsetIdx = 2 * i + 1;
+ addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+ offsetValues[offsetIdx] = indices[i];
+ offsetValues[offsetIdx + 1] = baseStrides[i];
+
+ unsigned sizeIdx = i + 1;
+ mulMap = mulMap * symbols[sizeIdx];
+ sizeValues[sizeIdx] = baseSizes[i];
+ }
+
+ // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+ OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+ AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+ offsetValues.back() = scaler;
+
+ OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+ builder, loc, scaledAddMulMap, offsetValues);
+ OpFoldResult linearizedSize =
+ affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+ // Adjust baseOffset by the scale factor (dstBits / srcBits).
+ AffineExpr s0, s1;
+ bindSymbols(builder.getContext(), s0, s1);
+ OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+ // Flatten n-D MemRef to 1-D MemRef.
+ auto layoutAttr = StridedLayoutAttr::get(
+ sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+ int64_t staticShape = sourceType.hasStaticShape()
+ ? sourceType.getNumElements()
+ : ShapedType::kDynamic;
+ auto flattenMemrefType = MemRefType::get(
+ staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+ auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+ loc, flattenMemrefType, baseBuffer,
+ getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+ getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+ baseStrides.back());
+
+ return builder.create<memref::LoadOp>(
+ loc, srcElementType, reinterpret.getResult(),
+ getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
+
+/// When data is loaded/stored in `targetBits` granularity, but is used in
+/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
+/// treated as an array of elements of width `sourceBits`.
+/// Return the bit offset of the value at position `srcIdx`. For example, if
+/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
+/// located at (x % 2) * 4. Because there are two elements in one i8, and one
+/// element has 4 bits.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+ int targetBits, OpBuilder &builder) {
+ assert(targetBits % sourceBits == 0);
+ IntegerType targetType = builder.getIntegerType(targetBits);
+ IntegerAttr idxAttr =
+ builder.getIntegerAttr(targetType, targetBits / sourceBits);
+ auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+ IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+ auto srcBitsValue =
+ builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+ auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+ return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ rewriter.replaceOpWithNewOp<memref::AllocOp>(
+ op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAssumeAlignment
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAssumeAlignment final
+ : OpConversionPattern<memref::AssumeAlignmentOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemref().getType()));
+ }
+
+ rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
+ op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemRefType()));
+ }
+
+ if (op.getMemRefType() == newTy)
+ return failure();
+
+ auto loc = op.getLoc();
+ auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
+ unsigned sourceRank = sourceType.getRank();
+ SmallVector<Value> indices = adaptor.getIndices();
+ assert(indices.size() == sourceRank);
+
+ auto srcElementType = sourceType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = srcElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, adaptor.getMemref());
+
+ Value newLoad, lastIdx;
+ if (sourceRank == 0) {
+ newLoad = rewriter.create<memref::LoadOp>(
+ loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
+
+ lastIdx = stridedMetadata.getOffset();
+ } else {
+ newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
+ stridedMetadata, rewriter);
+
+ lastIdx = adaptor.getIndices().back();
+ }
+
+ // Get the offset and shift the bits to the rightmost.
+ // Note, currently only the big-endian is supported.
+ auto castLastIdx =
+ rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
+
+ Value BitwidthOffset =
+ getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
+ auto bitsLoad =
+ rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
+
+ // Get the corresponding bits. If the arith computation bitwidth equals
+ // to the emulated bitwidth, we apply a mask to extract the low bits.
+ // It is not clear if this case actually happens in practice, but we keep
+ // the operations just in case. Otherwise, if the arith computation bitwidth
+ // is different from the emulated bitwidth we truncate the result.
+ Operation *result;
+ auto resultTy = getTypeConverter()->convertType(oldElementType);
+ if (resultTy == srcElementType) {
+ auto mask = rewriter.create<arith::ConstantOp>(
+ loc, srcElementType,
+ rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+
+ result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+ } else {
+ result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+ }
+
+ rewriter.replaceOp(op, result->getResult(0));
+ return success();
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns) {
+
+ // Populate `memref.*` conversion patterns.
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
+ typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefNarrowTypeEmulationConversions(
+ arith::NarrowTypeEmulationConverter &typeConverter) {
+ typeConverter.addConversion(
+ [&typeConverter](MemRefType ty) -> std::optional<Type> {
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+ if (!intTy)
+ return ty;
+
+ unsigned width = intTy.getWidth();
+ unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
+ if (width >= loadStoreWidth)
+ return ty;
+
+ auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
+ intTy.getSignedness());
+ if (!newElemTy)
+ return std::nullopt;
+
+ return ty.cloneWith(std::nullopt, newElemTy);
+ });
+}
diff --git a/mlir/test/Dialect/Arith/emulate-narrow-type.mlir b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
new file mode 100644
index 00000000000000..7120882b0c07c0
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME: ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT: return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+ return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
+func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> {
+ return %a : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: i8) -> i8
+// CHECK-NEXT: return [[ARG]] : i8
+func.func @identity_scalar(%x : i4) -> i4 {
+ return %x : i4
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<4xi8>
+func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> {
+ return %x : vector<4xi4>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<3x4xi8>
+func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> {
+ return %x : vector<3x4xi4>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[RES]] : vector<4xi8>
+func.func @call(%a : vector<4xi4>) -> vector<4xi4> {
+ %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4>
+ return %res : vector<4xi4>
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
new file mode 100644
index 00000000000000..85d4cc18c0d38f
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: return
+func.func @memref_i32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i32
+ %m = memref.alloc() : memref<4xi32, 1>
+ %v = memref.load %m[%c0] : memref<4xi32, 1>
+ memref.store %c1, %m[%c0] : memref<4xi32, 1>
+ return
+}
+
+// -----
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1.0 : f32
+ %m = memref.alloc() : memref<4xf32, 1>
+ %v = memref.load %m[%c0] : memref<4xf32, 1>
+ memref.store %c1, %m[%c0] : memref<4xf32, 1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_zero_rank
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<i8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_zero_rank() {
+ %0 = memref.alloc() : memref<i4>
+ %1 = memref.load %0[] : memref<i4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+ %0 = memref.alloc() : memref<4xi4>
+ %1 = memref.load %0[%arg0] : memref<4xi4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+ memref.assume_alignment %0, 64 : memref<4x128xi4>
+ %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
new file mode 100644
index 00000000000000..9d63b9d1acd08f
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions.
+// CHECK-LABEL: func @memref_i8
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: return
+func.func @memref_i8() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i8
+ %m = memref.alloc() : memref<4xi8, 1>
+ %v = memref.load %m[%c0] : memref<4xi8, 1>
+ memref.store %c1, %m[%c0] : memref<4xi8, 1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+ %0 = memref.alloc() : memref<4xi4>
+ %1 = memref.load %0[%arg0] : memref<4xi4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+ memref.assume_alignment %0, 64 : memref<4x128xi4>
+ %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+ return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
index df3fdacd09358f..0498de3eb93178 100644
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMemRefTestPasses
TestComposeSubView.cpp
+ TestEmulateNarrowType.cpp
TestMultiBuffer.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
new file mode 100644
index 00000000000000..b1f23084449c2b
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -0,0 +1,118 @@
+//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ------*- c++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestEmulateNarrowTypePass
+ : public PassWrapper<TestEmulateNarrowTypePass,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
+
+ TestEmulateNarrowTypePass() = default;
+ TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+ vector::VectorDialect, affine::AffineDialect>();
+ }
+ StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+ StringRef getDescription() const final {
+ return "Function pass to test Narrow Integer Emulation";
+ }
+
+ void runOnOperation() override {
+ if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
+ loadStoreEmulateBitwidth < 8) {
+ signalPassFailure();
+ return;
+ }
+
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+
+ arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth);
+
+ // Convert scalar type.
+ typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+ unsigned width = ty.getWidth();
+ if (width >= arithComputeBitwidth)
+ return ty;
+
+ return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+ });
+
+ // Convert vector type.
+ typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+ if (!intTy)
+ return ty;
+
+ unsigned width = intTy.getWidth();
+ if (width >= arithComputeBitwidth)
+ return ty;
+
+ return VectorType::get(
+ to_vector(ty.getShape()),
+ IntegerType::get(ty.getContext(), arithComputeBitwidth));
+ });
+
+ memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+ return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+ });
+ auto opLegalCallback = [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op);
+ };
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+ target.addDynamicallyLegalDialect<
+ arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+ affine::AffineDialect>(
+ [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+ RewritePatternSet patterns(ctx);
+
+ arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+ memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+
+ Option<unsigned> loadStoreEmulateBitwidth{
+ *this, "memref-load-bitwidth",
+ llvm::cl::desc("memref load/store emulation bit width"),
+ llvm::cl::init(8)};
+
+ Option<unsigned> arithComputeBitwidth{
+ *this, "arith-compute-bitwidth",
+ llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowTypePass() {
+ PassRegistration<TestEmulateNarrowTypePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d75b54e8bd45ac..5b956635b960a1 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,6 +87,7 @@ void registerTestDiagnosticsPass();
void registerTestDialectConversionPasses();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowTypePass();
void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
@@ -205,6 +206,7 @@ void registerTestPasses() {
mlir::test::registerTestDeadCodeAnalysisPass();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
+ mlir::test::registerTestEmulateNarrowTypePass();
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
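To make the bit manipulation emitted by ConvertMemRefLoad concrete (linearize the indices, scale the offset by dstBits / srcBits, shift by (index mod (dstBits / srcBits)) * srcBits, then mask or truncate), here is a small self-contained C++ model of loading an i4 element from an i8-backed buffer. This is only an illustration of the arithmetic the generated IR performs; it is not MLIR code and does not use the APIs added by this patch.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Model of the emulated i4 load: two i4 elements are packed into each i8
// byte; the element at linear index `idx` lives in byte idx / 2 at bit
// offset (idx % 2) * 4, matching getOffsetForBitwidth in the patch.
static int loadPackedI4(const std::vector<uint8_t> &buffer, unsigned idx) {
  unsigned byteIdx = idx / 2;          // linearized offset floordiv (8 / 4)
  unsigned bitOffset = (idx % 2) * 4;  // (idx % (dstBits / srcBits)) * srcBits
  assert(byteIdx < buffer.size());
  // The IR uses arith.shrsi on the loaded i8 followed by arith.trunci to i4.
  // Since only the low 4 bits survive and we sign-extend them ourselves, a
  // logical shift followed by a mask models the same result.
  int nibble = (buffer[byteIdx] >> bitOffset) & 0xF;
  return (nibble & 0x8) ? nibble - 16 : nibble;  // interpret as signed i4
}

int main() {
  // buffer[0] = 0xA3 packs elements {3, -6}; buffer[1] = 0x17 packs {7, 1}.
  std::vector<uint8_t> buffer = {0xA3, 0x17};
  for (unsigned idx = 0; idx < 4; ++idx)
    std::printf("i4 element %u = %d\n", idx, loadPackedI4(buffer, idx));
  return 0;
}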