[Mlir-commits] [mlir] 0f8bab8 - [mlir] Revamp implementation of sub-byte load/store emulation.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Aug 17 13:28:06 PDT 2023
Author: Mahesh Ravishankar
Date: 2023-08-17T20:27:53Z
New Revision: 0f8bab8d590ed0cb5402f72009b32cbad115f013
URL: https://github.com/llvm/llvm-project/commit/0f8bab8d590ed0cb5402f72009b32cbad115f013
DIFF: https://github.com/llvm/llvm-project/commit/0f8bab8d590ed0cb5402f72009b32cbad115f013.diff
LOG: [mlir] Revamp implementation of sub-byte load/store emulation.
When handling sub-byte emulation, the sizes of the converted `memref`s
also need to be updated (this was not done in the current
implementation). This adds the additional complexity of having to
linearize the `memref`s as well. Consider a `memref<3x3xi4>` where the
`i4` elements are packed. This has an overall size of 5 bytes (rounded
up to a whole number of bytes). This can only be represented by a
`memref<5xi8>`. A `memref<3x2xi8>` would imply an implicit padding of
4 bits at the end of each row. So incorporate linearization into the
sub-byte load-store emulation.
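For example, the allocation conversion described above looks roughly like this
(an illustrative sketch using an 8-bit load/store width; the exact IR is
covered by the new emulate-narrow-type.mlir test added in this patch):

  // Original allocation: 9 packed i4 elements = 36 bits.
  %src = memref.alloc() : memref<3x3xi4>
  // Emulated allocation: linearized and rounded up, ceil(36 / 8) = 5 bytes.
  %dst = memref.alloc() : memref<5xi8>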
This patch also updates some of the utility functions to make better
use of statically available information using `OpFoldResult` and
`makeComposedFoldedAffineApply`.
Reviewed By: hanchung, yzhang93
Differential Revision: https://reviews.llvm.org/D158125
Added:
mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
Removed:
mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index e1504f030defa6..9ff009b8c27bf6 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -28,36 +28,37 @@ namespace memref {
/// contiguous chunk of memory.
bool isStaticShapeAndContiguousRowMajor(MemRefType type);
-/// Returns the flattened 1-D memref and linearized offset for narrow type
-/// emulation.
-///
-/// The emulation only works on 1D memref types. To make this work on N-D
-/// memref, we need to linearize the offset.
-///
-/// For example, to emulate i4 to i8, the following op:
-///
-/// %0 = memref.load %arg0[%v0, %v1] :
-/// memref<?x?xi4, strided<[?, ?], offset: ?>>
-///
-/// can be replaced with
-///
-/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
-///
-/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
-/// %linearized_size = %size0 * %size1
-/// %scaled_linear_offset = %linearized_offset / 8 * 4
-/// %scaled_base_offset = %offset / 8 * 4
-///
-/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
-/// sizes = [%linearized_size], strides = [%stride#1]
-///
-/// %new_load = memref.load %linearized[%scaled_linear_offset] :
-/// memref<?xi8, strided<[?], offset: ?>>
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
- int dstBits, SmallVector<Value> indices,
- memref::ExtractStridedMetadataOp stridedMetadata,
- OpBuilder &builder);
+/// For a `memref` with `offset`, `sizes` and `strides`, returns the
+/// offset and size to use for the linearized `memref`.
+/// - If the linearization is done for emulating load/stores of
+/// element type with bitwidth `srcBits` using element type with
+/// bitwidth `dstBits`, the linearized offset and size are
+/// scaled down by `dstBits`/`srcBits`.
+/// - If `indices` is provided, it represents the position in the
+/// original `memref` being accessed. The method then returns the
+/// index to use in the linearized `memref`. The linearized index
+/// is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided,
+/// 0 is returned for the linearized index.
+struct LinearizedMemRefInfo {
+ OpFoldResult linearizedOffset;
+ OpFoldResult linearizedSize;
+};
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+ OpBuilder &builder, Location loc, int srcBits, int dstBits,
+ OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {});
+
+/// For a `memref` with `offset` and `sizes`, returns the
+/// offset and size to use for the linearized `memref`, assuming that
+/// the strides are computed from a row-major ordering of the sizes;
+/// - If the linearization is done for emulating load/stores of
+/// element type with bitwidth `srcBits` using element type with
+/// bitwidth `dstBits`, the linearized offset and size are
+/// scaled down by `dstBits`/`srcBits`.
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+ int dstBits, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes);
} // namespace memref
} // namespace mlir
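To make the new contract concrete, here is a sketch of the IR the emulation
builds with getLinearizedMemRefOffsetAndSize for a load of element (%i, %j)
from a memref<3x125xi4> through an i8 memref (adapted from the
memref_load_i4_rank2 case in the new test below; %i, %j and %alloc are
placeholders):

  // Linearized byte index: two i4 elements per byte.
  #map0 = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
  // Bit offset of the requested element within that byte.
  #map1 = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
  %index = affine.apply #map0()[%i, %j]
  %byte = memref.load %alloc[%index] : memref<188xi8>
  %bitoff = affine.apply #map1()[%i, %j]
  %shamt = arith.index_cast %bitoff : index to i8
  %shifted = arith.shrsi %byte, %shamt : i8
  %val = arith.trunci %shifted : i8 to i4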
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index fa5c90c24551da..2a524ceb9db887 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -35,18 +35,18 @@ using namespace mlir;
/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
/// located at (x % 2) * 4. Because there are two elements in one i8, and one
/// element has 4 bits.
-static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
- int targetBits, OpBuilder &builder) {
+static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+ int sourceBits, int targetBits,
+ OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
- IntegerType targetType = builder.getIntegerType(targetBits);
- IntegerAttr idxAttr =
- builder.getIntegerAttr(targetType, targetBits / sourceBits);
- auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
- IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
- auto srcBitsValue =
- builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
- auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
- return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
+ int scaleFactor = targetBits / sourceBits;
+ OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+ builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+ Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+ IntegerType dstType = builder.getIntegerType(targetBits);
+ return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
namespace {
@@ -61,15 +61,43 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type newTy = getTypeConverter()->convertType(op.getType());
- if (!newTy) {
+ auto currentType = op.getMemref().getType().cast<MemRefType>();
+ auto newResultType =
+ getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}
+ // Special case zero-rank memrefs.
+ if (currentType.getRank() == 0) {
+ rewriter.replaceOpWithNewOp<memref::AllocOp>(
+ op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
+ return success();
+ }
+
+ Location loc = op.getLoc();
+ OpFoldResult zero = rewriter.getIndexAttr(0);
+ SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
+
+ // Get linearized type.
+ int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
+ int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
+ SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+ memref::LinearizedMemRefInfo linearizedMemRefInfo =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
+ SmallVector<Value> dynamicLinearizedSize;
+ if (!newResultType.hasStaticShape()) {
+ dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
+ rewriter, loc, linearizedMemRefInfo.linearizedSize));
+ }
+
rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+ op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
return success();
}
@@ -109,62 +137,57 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type newTy = getTypeConverter()->convertType(op.getMemRefType());
- if (!newTy) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
- op.getMemRefType()));
- }
-
- if (op.getMemRefType() == newTy)
- return failure();
-
- auto loc = op.getLoc();
- auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
- unsigned sourceRank = sourceType.getRank();
- SmallVector<Value> indices = adaptor.getIndices();
- assert(indices.size() == sourceRank);
-
- auto srcElementType = sourceType.getElementType();
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getMemRefType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = srcElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, adaptor.getMemref());
-
- Value newLoad, lastIdx;
- if (sourceRank == 0) {
- newLoad = rewriter.create<memref::LoadOp>(
- loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
-
- lastIdx = stridedMetadata.getOffset();
+ Location loc = op.getLoc();
+ // Special case 0-rank memref loads.
+ Value bitsLoad;
+ if (convertedType.getRank() == 0) {
+ bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
+ ValueRange{});
} else {
- auto [reinterpret, linearizedOffset] =
- memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
- adaptor.getIndices(),
- stridedMetadata, rewriter);
-
- newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
- reinterpret, linearizedOffset);
-
- lastIdx = adaptor.getIndices().back();
+ SmallVector<OpFoldResult> indices =
+ getAsOpFoldResult(adaptor.getIndices());
+
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, op.getMemRef());
+
+ // Linearize the indices of the original load instruction. Do not account
+ // for the scaling yet. This will be accounted for later.
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, srcBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ int64_t scaler = dstBits / srcBits;
+ OpFoldResult scaledLinearizedIndices =
+ affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+ Value newLoad = rewriter.create<memref::LoadOp>(
+ loc, adaptor.getMemref(),
+ getValueOrCreateConstantIndexOp(rewriter, loc,
+ scaledLinearizedIndices));
+
+ // Get the offset and shift the bits to the rightmost.
+ // Note, currently only the big-endian is supported.
+ Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
+ srcBits, dstBits, rewriter);
+ bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
}
- // Get the offset and shift the bits to the rightmost.
- // Note, currently only the big-endian is supported.
- auto castLastIdx =
- rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
-
- Value BitwidthOffset =
- getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
- auto bitsLoad =
- rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
-
// Get the corresponding bits. If the arith computation bitwidth equals
// to the emulated bitwidth, we apply a mask to extract the low bits.
// It is not clear if this case actually happens in practice, but we keep
@@ -172,10 +195,10 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// is different from the emulated bitwidth we truncate the result.
Operation *result;
auto resultTy = getTypeConverter()->convertType(oldElementType);
- if (resultTy == srcElementType) {
+ if (resultTy == convertedElementType) {
auto mask = rewriter.create<arith::ConstantOp>(
- loc, srcElementType,
- rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+ loc, convertedElementType,
+ rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
} else {
@@ -200,6 +223,25 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
patterns
.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
typeConverter, patterns.getContext());
+ memref::populateResolveExtractStridedMetadataPatterns(patterns);
+}
+
+static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
+ int dstBits) {
+ if (ty.getRank() == 0)
+ return {};
+
+ int64_t linearizedShape = 1;
+ for (auto shape : ty.getShape()) {
+ if (shape == ShapedType::kDynamic)
+ return {ShapedType::kDynamic};
+ linearizedShape *= shape;
+ }
+ int scale = dstBits / srcBits;
+ // Scale the size to ceilDiv(linearizedShape, scale)
+ // to accommodate all the values.
+ linearizedShape = (linearizedShape + scale - 1) / scale;
+ return {linearizedShape};
}
void memref::populateMemRefNarrowTypeEmulationConversions(
@@ -215,11 +257,26 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
if (width >= loadStoreWidth)
return ty;
+ // Currently only handle the case where the innermost stride is 1.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(getStridesAndOffset(ty, strides, offset)))
+ return std::nullopt;
+ if (!strides.empty() && strides.back() != 1)
+ return std::nullopt;
+
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
intTy.getSignedness());
if (!newElemTy)
return std::nullopt;
- return ty.cloneWith(std::nullopt, newElemTy);
+ StridedLayoutAttr layoutAttr;
+ if (offset != 0) {
+ layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+ ArrayRef<int64_t>{1});
+ }
+
+ return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
+ newElemTy, layoutAttr, ty.getMemorySpace());
});
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index ff2c4107ee46dc..672ef3eb4cd50f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -687,13 +687,17 @@ struct ExtractStridedMetadataOpAllocFolder
auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
- if (allocLikeOp.getType() == baseBufferType)
- results.push_back(allocLikeOp);
- else
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, allocLikeOp, offset,
- /*sizes=*/ArrayRef<int64_t>(),
- /*strides=*/ArrayRef<int64_t>()));
+ if (op.getBaseBuffer().use_empty()) {
+ results.push_back(nullptr);
+ } else {
+ if (allocLikeOp.getType() == baseBufferType)
+ results.push_back(allocLikeOp);
+ else
+ results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+ loc, baseBufferType, allocLikeOp, offset,
+ /*sizes=*/ArrayRef<int64_t>(),
+ /*strides=*/ArrayRef<int64_t>()));
+ }
// Offset.
results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index f6e3f97f455bee..e640248af6e499 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -46,79 +46,78 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
return curDim < 0;
}
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
- int dstBits, SmallVector<Value> indices,
- memref::ExtractStridedMetadataOp stridedMetadata,
- OpBuilder &builder) {
- auto srcElementType = sourceType.getElementType();
- unsigned sourceRank = indices.size();
-
- Value baseBuffer = stridedMetadata.getBaseBuffer();
- SmallVector<Value> baseSizes = stridedMetadata.getSizes();
- SmallVector<Value> baseStrides = stridedMetadata.getStrides();
- Value baseOffset = stridedMetadata.getOffset();
- assert(indices.size() == baseStrides.size());
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+ OpBuilder &builder, Location loc, int srcBits, int dstBits,
+ OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
+ unsigned sourceRank = sizes.size();
+ assert(sizes.size() == strides.size() &&
+ "expected as many sizes as strides for a memref");
+ SmallVector indicesVec = llvm::to_vector(indices);
+ if (indices.empty())
+ indicesVec.resize(sourceRank, builder.getIndexAttr(0));
+ assert(indicesVec.size() == strides.size() &&
+ "expected as many indices as rank of memref");
// Create the affine symbols and values for linearization.
- SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+ SmallVector<AffineExpr> symbols(2 * sourceRank);
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
- symbols[0] = builder.getAffineSymbolExpr(0);
- AffineExpr addMulMap = symbols.front();
- AffineExpr mulMap = symbols.front();
+ AffineExpr addMulMap = builder.getAffineConstantExpr(0);
+ AffineExpr mulMap = builder.getAffineConstantExpr(1);
- SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
- offsetValues[0] = builder.getIndexAttr(0);
- SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
- sizeValues[0] = builder.getIndexAttr(1);
+ SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
+ SmallVector<OpFoldResult> sizeValues(sourceRank);
for (unsigned i = 0; i < sourceRank; ++i) {
- unsigned offsetIdx = 2 * i + 1;
+ unsigned offsetIdx = 2 * i;
addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
- offsetValues[offsetIdx] = indices[i];
- offsetValues[offsetIdx + 1] = baseStrides[i];
+ offsetValues[offsetIdx] = indicesVec[i];
+ offsetValues[offsetIdx + 1] = strides[i];
- unsigned sizeIdx = i + 1;
- mulMap = mulMap * symbols[sizeIdx];
- sizeValues[sizeIdx] = baseSizes[i];
+ mulMap = mulMap * symbols[i];
}
- // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
- OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
- AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
- offsetValues.back() = scaler;
+ // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
+ // srcBits).
+ int64_t scaler = dstBits / srcBits;
+ addMulMap = addMulMap.floorDiv(scaler);
+ mulMap = mulMap.floorDiv(scaler);
- OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
- builder, loc, scaledAddMulMap, offsetValues);
+ OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
+ builder, loc, addMulMap, offsetValues);
OpFoldResult linearizedSize =
- affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+ affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
// Adjust baseOffset by the scale factor (dstBits / srcBits).
- AffineExpr s0, s1;
- bindSymbols(builder.getContext(), s0, s1);
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
- builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
-
- // Flatten n-D MemRef to 1-D MemRef.
- std::optional<int64_t> stride =
- getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
- auto layoutAttr =
- StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
- {stride ? stride.value() : ShapedType::kDynamic});
- int64_t staticShape = sourceType.hasStaticShape()
- ? sourceType.getNumElements()
- : ShapedType::kDynamic;
- auto flattenMemrefType = MemRefType::get(
- staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
-
- auto reinterpret = builder.create<memref::ReinterpretCastOp>(
- loc, flattenMemrefType, baseBuffer,
- getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
- getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
- baseStrides.back());
-
- return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
- builder, loc, linearizedOffset));
+ builder, loc, s0.floorDiv(scaler), {offset});
+
+ return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+}
+
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+ int dstBits, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<OpFoldResult> strides(sizes.size());
+ if (sizes.size() > 0) {
+ strides.back() = builder.getIndexAttr(1);
+ AffineExpr s0, s1;
+ bindSymbols(builder.getContext(), s0, s1);
+ for (int index = sizes.size() - 1; index > 0; --index) {
+ strides[index - 1] = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0 * s1,
+ ArrayRef<OpFoldResult>{strides[index], sizes[index]});
+ }
+ }
+
+ LinearizedMemRefInfo linearizedMemRefInfo;
+ std::tie(linearizedMemRefInfo, std::ignore) =
+ getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
+ sizes, strides);
+ return linearizedMemRefInfo;
}
} // namespace memref
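Restated, the second overload simply assumes the row-major stride convention
and then reuses the general routine; in formulas (a summary of the code above,
not part of the patch):

  \[
    \mathrm{stride}_i = \prod_{j>i} \mathrm{size}_j, \qquad
    \mathrm{linearizedIndex} = \Big\lfloor \frac{\sum_i \mathrm{index}_i \cdot \mathrm{stride}_i}{\mathrm{dstBits}/\mathrm{srcBits}} \Big\rfloor
  \]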
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7e747aaa450ab5..01ee43354a711e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -69,19 +69,24 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
if (origElements % scale != 0)
return failure();
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, adaptor.getBase());
-
- auto [reinterpret, linearizedOffset] = memref::getLinearizeMemRefAndOffset(
- loc, sourceType, srcBits, dstBits, adaptor.getIndices(),
- stridedMetadata, rewriter);
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(),
+ getAsOpFoldResult(adaptor.getIndices()));
auto srcElementType = sourceType.getElementType();
auto numElements =
static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
auto newLoad = rewriter.create<vector::LoadOp>(
- loc, VectorType::get(numElements, srcElementType), reinterpret,
- linearizedOffset);
+ loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
numElements *= scale;
auto castType = VectorType::get(numElements, oldElementType);
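After this change the vector path reuses the same linearization and then
bitcasts the wider load back to the narrow element type, roughly (adapted from
the updated vector-emulate-narrow-type.mlir test below; %alloc and %index are
placeholders):

  // Emulating a vector<8xi4> load through an i8 memref: load 4 bytes and
  // reinterpret them as 8 i4 values.
  %bytes = vector.load %alloc[%index] : memref<12xi8>, vector<4xi8>
  %vals = vector.bitcast %bytes : vector<4xi8> to vector<8xi4>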
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
deleted file mode 100644
index f2db905370a588..00000000000000
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
+++ /dev/null
@@ -1,107 +0,0 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
-
-// Expect no conversions, i32 is supported.
-// CHECK-LABEL: func @memref_i32
-// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
-// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
-// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
-// CHECK-NEXT: return
-func.func @memref_i32() {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : i32
- %m = memref.alloc() : memref<4xi32, 1>
- %v = memref.load %m[%c0] : memref<4xi32, 1>
- memref.store %c1, %m[%c0] : memref<4xi32, 1>
- return
-}
-
-// -----
-
-// Expect no conversions, f32 is not an integer type.
-// CHECK-LABEL: func @memref_f32
-// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
-// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
-// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
-// CHECK-NEXT: return
-func.func @memref_f32() {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1.0 : f32
- %m = memref.alloc() : memref<4xf32, 1>
- %v = memref.load %m[%c0] : memref<4xf32, 1>
- memref.store %c1, %m[%c0] : memref<4xf32, 1>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4_zero_rank
-// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<i8>
-// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
-// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
-// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
-// CHECK-NEXT: return
-func.func @memref_load_i4_zero_rank() {
- %0 = memref.alloc() : memref<i4>
- %1 = memref.load %0[] : memref<i4>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4
-// CHECK-SAME: (%[[ARG:.*]]: index)
-// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
-// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
-// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
-// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
-// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
-// CHECK-NEXT: return
-func.func @memref_load_i4(%arg0: index) {
- %0 = memref.alloc() : memref<4xi4>
- %1 = memref.load %0[%arg0] : memref<4xi4>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4_rank2
-// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
-// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
-// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
-// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
-// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
-// CHECK-NEXT: return
-func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
- memref.assume_alignment %0, 64 : memref<4x128xi4>
- %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
- return
-}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
deleted file mode 100644
index 19625e2f2beb17..00000000000000
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
+++ /dev/null
@@ -1,72 +0,0 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
-
-// Expect no conversions.
-// CHECK-LABEL: func @memref_i8
-// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
-// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
-// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
-// CHECK-NEXT: return
-func.func @memref_i8() {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : i8
- %m = memref.alloc() : memref<4xi8, 1>
- %v = memref.load %m[%c0] : memref<4xi8, 1>
- memref.store %c1, %m[%c0] : memref<4xi8, 1>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4
-// CHECK-SAME: (%[[ARG:.*]]: index)
-// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
-// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
-// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
-// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
-// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
-// CHECK-NEXT: return
-func.func @memref_load_i4(%arg0: index) {
- %0 = memref.alloc() : memref<4xi4>
- %1 = memref.load %0[%arg0] : memref<4xi4>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_load_i4_rank2
-// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
-// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
-// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
-// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
-// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
-// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
-// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
-// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
-// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
-// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
-// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
-// CHECK-NEXT: return
-func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
- memref.assume_alignment %0, 64 : memref<4x128xi4>
- %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
- return
-}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
new file mode 100644
index 00000000000000..c0050d8c510d53
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+
+// Expect no conversions.
+func.func @memref_i8() -> i8 {
+ %c3 = arith.constant 3 : index
+ %m = memref.alloc() : memref<4xi8, 1>
+ %v = memref.load %m[%c3] : memref<4xi8, 1>
+ return %v : i8
+}
+// CHECK-LABEL: func @memref_i8()
+// CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
+// CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: return %[[V]]
+
+// CHECK32-LABEL: func @memref_i8()
+// CHECK32: %[[M:.+]] = memref.alloc() : memref<1xi32, 1>
+// CHECK32: %[[C0:.+]] = arith.constant 0 : index
+// CHECK32: %[[V:.+]] = memref.load %[[M]][%[[C0]]] : memref<1xi32, 1>
+// CHECK32: %[[C24:.+]] = arith.constant 24 : index
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
+// CHECK32-NEXT: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloc() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
+ %0 = memref.alloc() : memref<3x125xi4>
+ memref.assume_alignment %0, 64 : memref<3x125xi4>
+ %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)
+// CHECK: func @memref_load_i4_rank2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)
+// CHECK32: func @memref_load_i4_rank2(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index) -> i4 {
+ %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+ %1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_load_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+// CHECK32: func @memref_load_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @rank_zero_memref() -> i4 {
+ %0 = memref.alloc() : memref<i4>
+ %1 = memref.load %0[] : memref<i4>
+ return %1 : i4
+}
+// CHECK-LABEL: func @rank_zero_memref()
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i8>
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-LABEL: func @rank_zero_memref()
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 71c133778087d6..e3c6b098b70ba4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,36 +1,81 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
-
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
+ return %1 : vector<4xi8>
+}
// Expect no conversions, i8 is supported.
-// CHECK-LABEL: func @vector_load_i8
-// CHECK-SAME: (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
-// CHECK-NEXT: [[L:%.+]] = vector.load %[[ARG]][%[[IDX0]], %[[IDX1]]] : memref<3x4xi8>, vector<4xi8>
-// CHECK-NEXT: return
-func.func @vector_load_i8(%arg0: memref<3x4xi8>, %arg1: index, %arg2: index) {
- %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
- return
+// CHECK: func @vector_load_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
+// CHECK-NEXT: [[L:%.+]] = vector.load %[[ALLOC]][%[[ARG0]], %[[ARG1]]] : memref<3x4xi8>, vector<4xi8>
+// CHECK-NEXT: return
+
+// CHECK32: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
+// CHECK32: func @vector_load_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VECLOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VECLOAD]] : vector<1xi32> to vector<4xi8>
+// CHECK32: return %[[VEC_I4]]
+
+// -----
+
+func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
+ %0 = memref.alloc() : memref<3x8xi4>
+ %cst = arith.constant dense<0> : vector<3x8xi4>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x8xi4>, vector<8xi4>
+ %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
+ return %2 : vector<3x8xi4>
}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_load_i4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_load_i4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
// -----
-// CHECK-LABEL: func @vector_load_i4
-// CHECK-SAME: (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
-// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0> : vector<3x4xi4>
-// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x4xi8> -> memref<i8>, index, index, index, index, index
-// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[IDX0]], %[[STRIDES]]#0, %[[IDX1]], %[[STRIDES]]#1]
-// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
-// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<12xi8, strided<[1], offset: ?>>
-// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[CAST]][%[[INDEX]]] : memref<12xi8, strided<[1], offset: ?>>, vector<2xi8>
-// CHECK-NEXT: %[[BITCAST:.*]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<4xi4>
-// CHECK-NEXT: %[[INSERT:.*]] = vector.insert %[[BITCAST]], %[[CST]] [0] : vector<4xi4> into vector<3x4xi4>
-// CHECK-NEXT: return
-func.func @vector_load_i4(%arg0: memref<3x4xi4>, %arg1: index, %arg2: index) {
- %cst = arith.constant dense<0> : vector<3x4xi4>
- %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi4>, vector<4xi4>
- %1 = vector.insert %0, %cst [0] : vector<4xi4> into vector<3x4xi4>
- return
+func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
+ %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+ %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
+ return %1 : vector<8xi4>
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK: func.func @vector_load_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
+// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32: func.func @vector_load_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
+// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 64646b01b9f515..eeb26d1876c1db 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -89,8 +89,7 @@ struct TestEmulateNarrowTypePass
target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
target.addDynamicallyLegalDialect<
arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
- affine::AffineDialect>(
- [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+ affine::AffineDialect>(opLegalCallback);
RewritePatternSet patterns(ctx);