[Mlir-commits] [mlir] [MLIR][MemRef] Extend narrow-type emulation for dynamic offsets (PR #196945)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 11 06:21:29 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Alan Li (lialan)
<details>
<summary>Changes</summary>
This patch adds three related extensions to the MemRef narrow-type emulation patterns.
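* `ConvertMemRefSubview` now accepts a dynamic innermost offset, under the alignment contract that the caller guarantees the offset is a multiple of `dstBits / srcBits`; a statically provable misalignment is still rejected.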
* `ConvertMemRefReinterpretCast` is generalized from the previous form, which only handled rank-0 or rank-1 results with a static offset, to accept results of any rank and dynamic offsets (sizes must still be static), with the same alignment contract as the subview pattern.
* A new `ConvertMemRefCast` pattern handles `memref.cast` between equivalent narrow-typed memref types so that emulation does not get blocked by trivial casts.
---
Full diff: https://github.com/llvm/llvm-project/pull/196945.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+98-47)
- (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+100)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 8686f22c9e3c2..71aa345121bbd 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -32,15 +32,21 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
/// Converts a memref::ReinterpretCastOp to the converted type. The result
-/// MemRefType of the old op must have a rank and stride of 1, with static
-/// offset and size. The number of bits in the offset must evenly divide the
-/// bitwidth of the new converted type.
+/// memref is linearized to a rank-1 view of the converted element type (or
+/// rank-0 if the result is rank-0). Dynamic offsets are accepted under the
+/// alignment contract that the caller guarantees the offset is a multiple of
+/// `dstBits / srcBits`; statically-provable misalignment is rejected.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
memref::ReinterpretCastOp::Adaptor adaptor,
memref::ReinterpretCastOp op, MemRefType newTy) {
- auto convertedElementType = newTy.getElementType();
- auto oldElementType = op.getType().getElementType();
+ if (newTy == op.getType()) {
+ return rewriter.notifyMatchFailure(
+ op, "result type was not converted by narrow-type emulation");
+ }
+
+ Type convertedElementType = newTy.getElementType();
+ Type oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
@@ -48,35 +54,54 @@ convertCastingOp(ConversionPatternRewriter &rewriter,
"only dstBits % srcBits == 0 supported");
}
- // Only support stride of 1.
- if (llvm::any_of(op.getStaticStrides(),
- [](int64_t stride) { return stride != 1; })) {
+ ArrayRef<int64_t> staticStrides = op.getStaticStrides();
+ if (!staticStrides.empty() && staticStrides.back() != 1) {
return rewriter.notifyMatchFailure(op->getLoc(),
- "stride != 1 is not supported");
+ "innermost stride != 1 is not supported");
}
- auto sizes = op.getStaticSizes();
- int64_t offset = op.getStaticOffset(0);
- // Only support static sizes and offsets.
- if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
- offset == ShapedType::kDynamic) {
- return rewriter.notifyMatchFailure(
- op, "dynamic size or offset is not supported");
+ if (llvm::is_contained(op.getStaticSizes(), ShapedType::kDynamic)) {
+ return rewriter.notifyMatchFailure(op, "dynamic sizes are not supported");
}
- int elementsPerByte = dstBits / srcBits;
- if (offset % elementsPerByte != 0) {
+ if (!memref::isStaticShapeAndContiguousRowMajor(op.getType())) {
return rewriter.notifyMatchFailure(
- op, "offset not multiple of elementsPerByte is not supported");
+ op, "result memref is not row-major contiguous");
}
- SmallVector<int64_t> size;
- if (!sizes.empty())
- size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
- offset = offset / elementsPerByte;
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ OpFoldResult origOffset = op.getMixedOffsets()[0];
+
+ SmallVector<OpFoldResult> newSizes;
+ SmallVector<OpFoldResult> newStrides;
+ OpFoldResult newOffset;
+ OpFoldResult intraOffset;
+ if (mixedSizes.empty()) {
+ int64_t elementsPerByte = dstBits / srcBits;
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ newOffset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0.floorDiv(elementsPerByte), {origOffset});
+ intraOffset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 % elementsPerByte, {origOffset});
+ } else {
+ memref::LinearizedMemRefInfo info =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits, origOffset, mixedSizes);
+ newOffset = info.linearizedOffset;
+ intraOffset = info.intraDataOffset;
+ newSizes.push_back(info.linearizedSize);
+ newStrides.push_back(rewriter.getIndexAttr(1));
+ }
+
+ if (auto cst = getConstantIntValue(intraOffset); cst && *cst != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "offset is provably not a multiple of dstBits / srcBits");
+ }
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
- op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
+ op, newTy, adaptor.getSource(), newOffset, newSizes, newStrides);
return success();
}
@@ -349,6 +374,32 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCast
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a `memref.cast` between narrow-typed memrefs as a cast from the
+/// already-converted source to the type-converted (wider element) result type.
+struct ConvertMemRefCast final : OpConversionPattern<memref::CastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+ if (newTy == op.getType())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<memref::CastOp>(op, newTy, adaptor.getSource());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//
@@ -377,8 +428,7 @@ struct ConvertMemRefMemorySpaceCast final
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//
-/// Output types should be at most one dimensional, so only the 0 or 1
-/// dimensional cases are supported.
+/// Forwards to `convertCastingOp`, which enforces all preconditions.
struct ConvertMemRefReinterpretCast final
: OpConversionPattern<memref::ReinterpretCastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -394,12 +444,6 @@ struct ConvertMemRefReinterpretCast final
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}
- // Only support for 0 or 1 dimensional cases.
- if (op.getType().getRank() > 1) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with rank > 1 is not supported");
- }
-
return convertCastingOp(rewriter, adaptor, op, newTy);
}
};
@@ -503,9 +547,11 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
//===----------------------------------------------------------------------===//
/// Emulating narrow ints on subview have limited support, supporting only
-/// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern should
-/// only run for cases that can't be folded.
+/// static sizes and stride of 1. Dynamic offsets are accepted under the
+/// alignment contract that the caller guarantees the offset is a multiple of
+/// `dstBits / srcBits`. Ideally, the subview should be folded away before
+/// running narrow type emulation, and this pattern should only run for cases
+/// that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;
@@ -543,12 +589,9 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
auto sizes = subViewOp.getStaticSizes();
- int64_t lastOffset = subViewOp.getStaticOffsets().back();
- // Only support static sizes and offsets.
- if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
- lastOffset == ShapedType::kDynamic) {
- return rewriter.notifyMatchFailure(
- subViewOp->getLoc(), "dynamic size or offset is not supported");
+ if (llvm::is_contained(sizes, ShapedType::kDynamic)) {
+ return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+ "dynamic size is not supported");
}
// Transform the offsets, sizes and strides according to the emulation.
@@ -566,6 +609,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
rewriter));
+ if (auto cst = getConstantIntValue(linearizedInfo.intraDataOffset);
+ cst && *cst != 0) {
+ return rewriter.notifyMatchFailure(
+ subViewOp,
+ "subview offset is provably not a multiple of dstBits / srcBits");
+ }
+
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
subViewOp, newTy, adaptor.getSource(), linearizedIndices,
linearizedInfo.linearizedSize, strides.back());
@@ -634,12 +684,13 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
- ConvertMemRefDealloc, ConvertMemRefCollapseShape,
- ConvertMemRefExpandShape, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
- ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
- typeConverter, patterns.getContext());
+ ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCast,
+ ConvertMemRefCopy, ConvertMemRefDealloc,
+ ConvertMemRefCollapseShape, ConvertMemRefExpandShape,
+ ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+ ConvertMemRefMemorySpaceCast, ConvertMemRefSubview,
+ ConvertMemRefReinterpretCast>(typeConverter,
+ patterns.getContext());
patterns.insert<ConvertMemrefStore>(typeConverter, patterns.getContext(),
disableAtomicRMW);
memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index dd64ecc98721a..b4de76cd6d29c 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -238,6 +238,51 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
// -----
+func.func @memref_subview_dynamic_inner_offset_i4(%off: index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<128xi4>
+ %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>>
+ %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>>
+ return %ld : i4
+}
+
+// CHECK-LABEL: func.func @memref_subview_dynamic_inner_offset_i4(
+// CHECK-SAME: %[[OFF:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
+// CHECK: %[[IDX:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[IDX]]] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: ?>>
+// CHECK: memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL: func.func @memref_subview_dynamic_inner_offset_i4(
+// CHECK32-SAME: %[[OFF:[a-zA-Z0-9_]+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
+// CHECK32: %[[IDX:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[IDX]]] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: ?>>
+// CHECK32: memref.load %[[SUBVIEW]]
+
+// -----
+
+// Dynamic innermost offset that is provably aligned (multiple of
+// `dstBits / srcBits`). The affine simplifier folds the `floordiv` away.
+
+func.func @memref_subview_aligned_dynamic_inner_offset_i4(%x: index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %off = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%x]
+ %arr = memref.alloc() : memref<128xi4>
+ %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>>
+ %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>>
+ return %ld : i4
+}
+
+// CHECK-LABEL: func.func @memref_subview_aligned_dynamic_inner_offset_i4(
+// CHECK-SAME: %[[X:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
+// CHECK-NOT: affine.apply
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[X]]] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: ?>>
+// CHECK: memref.load %[[SUBVIEW]]
+
+// -----
+
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<40x40xi4>
@@ -249,6 +294,61 @@ func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
// -----
+// Rank-3 reinterpret_cast on a sub-byte (i4) memref with a static, aligned
+// offset.
+
+func.func @reinterpret_cast_memref_rank3_static_offset_i4(%arg0: memref<2x4x8xi4>) -> memref<4x4x8xi4, strided<[32, 8, 1]>> {
+ %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1]>>
+ return %r : memref<4x4x8xi4, strided<[32, 8, 1]>>
+}
+
+// CHECK-LABEL: func @reinterpret_cast_memref_rank3_static_offset_i4(
+// CHECK-SAME: %[[ARG0:.+]]: memref<32xi8>
+// CHECK: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [64], strides: [1] : memref<32xi8> to memref<64xi8>
+// CHECK: return %[[R]]
+
+// CHECK32-LABEL: func @reinterpret_cast_memref_rank3_static_offset_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: memref<8xi32>
+// CHECK32: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [16], strides: [1] : memref<8xi32> to memref<16xi32>
+// CHECK32: return %[[R]]
+
+// -----
+
+// Rank-3 reinterpret_cast with a dynamic offset accepted under the alignment
+// contract.
+
+func.func @reinterpret_cast_memref_rank3_dynamic_offset_i4(%arg0: memref<2x4x8xi4>, %off: index) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> {
+ %r = memref.reinterpret_cast %arg0 to offset: [%off], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>>
+ return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>>
+}
+
+// CHECK-LABEL: func @reinterpret_cast_memref_rank3_dynamic_offset_i4(
+// CHECK-SAME: %[[ARG0:.+]]: memref<32xi8>,
+// CHECK-SAME: %[[OFF:.+]]: index
+// CHECK: %[[NEWOFF:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[NEWOFF]]{{\]}}, sizes: [64], strides: [1] : memref<32xi8> to memref<64xi8, strided<[1], offset: ?>>
+// CHECK: return %[[R]]
+
+// CHECK32-LABEL: func @reinterpret_cast_memref_rank3_dynamic_offset_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: memref<8xi32>,
+// CHECK32-SAME: %[[OFF:.+]]: index
+// CHECK32: %[[NEWOFF:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK32: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[NEWOFF]]{{\]}}, sizes: [16], strides: [1] : memref<8xi32> to memref<16xi32, strided<[1], offset: ?>>
+// CHECK32: return %[[R]]
+
+// -----
+
+// Provably-misaligned static offset (1 is not a multiple of i4 -> i8 ratio
+// of 2). Lowering must fail.
+
+func.func @negative_reinterpret_cast_memref_misaligned_static_offset_i4(%arg0: memref<2x4x8xi4>) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>> {
+ // expected-error @+1 {{failed to legalize operation 'memref.reinterpret_cast' that was explicitly marked illegal}}
+ %r = memref.reinterpret_cast %arg0 to offset: [1], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>>
+ return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>>
+}
+
+// -----
+
func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
``````````
</details>
https://github.com/llvm/llvm-project/pull/196945
More information about the Mlir-commits
mailing list