[Mlir-commits] [mlir] [MLIR][MemRef] Extend narrow-type emulation for dynamic offsets (PR #196945)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 11 06:21:29 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

<details>
<summary>Changes</summary>

This patch adds three related extensions to the MemRef narrow-type emulation patterns.

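* `ConvertMemRefSubview` now accepts a dynamic innermost offset. The caller must guarantee that the offset is a multiple of `dstBits / srcBits`; an offset that is provably misaligned at compile time is rejected.
* `ConvertMemRefReinterpretCast` is generalized from the previous restriction (rank-0 or rank-1 results with static sizes, a static offset, and unit strides) to accept results of any rank with dynamic offsets, under the same alignment contract as the subview pattern. Sizes must still be static, and the result must be contiguous in row-major order.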
* A new `ConvertMemRefCast` pattern rewrites a `memref.cast` between narrow-typed memrefs into a cast between the corresponding converted memref types, so that emulation is no longer blocked by such casts.

---
Full diff: https://github.com/llvm/llvm-project/pull/196945.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+98-47) 
- (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+100) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 8686f22c9e3c2..71aa345121bbd 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -32,15 +32,21 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
 /// Converts a memref::ReinterpretCastOp to the converted type. The result
-/// MemRefType of the old op must have a rank and stride of 1, with static
-/// offset and size. The number of bits in the offset must evenly divide the
-/// bitwidth of the new converted type.
+/// memref is linearized to a rank-1 byte view (or rank-0 if the source is
+/// rank-0). Dynamic offsets are accepted under the alignment contract that
+/// the caller guarantees the offset is a multiple of `dstBits / srcBits`;
+/// statically-provable misalignment is rejected.
 static LogicalResult
 convertCastingOp(ConversionPatternRewriter &rewriter,
                  memref::ReinterpretCastOp::Adaptor adaptor,
                  memref::ReinterpretCastOp op, MemRefType newTy) {
-  auto convertedElementType = newTy.getElementType();
-  auto oldElementType = op.getType().getElementType();
+  if (newTy == op.getType()) {
+    return rewriter.notifyMatchFailure(
+        op, "result type was not converted by narrow-type emulation");
+  }
+
+  Type convertedElementType = newTy.getElementType();
+  Type oldElementType = op.getType().getElementType();
   int srcBits = oldElementType.getIntOrFloatBitWidth();
   int dstBits = convertedElementType.getIntOrFloatBitWidth();
   if (dstBits % srcBits != 0) {
@@ -48,35 +54,54 @@ convertCastingOp(ConversionPatternRewriter &rewriter,
                                        "only dstBits % srcBits == 0 supported");
   }
 
-  // Only support stride of 1.
-  if (llvm::any_of(op.getStaticStrides(),
-                   [](int64_t stride) { return stride != 1; })) {
+  ArrayRef<int64_t> staticStrides = op.getStaticStrides();
+  if (!staticStrides.empty() && staticStrides.back() != 1) {
     return rewriter.notifyMatchFailure(op->getLoc(),
-                                       "stride != 1 is not supported");
+                                       "innermost stride != 1 is not supported");
   }
 
-  auto sizes = op.getStaticSizes();
-  int64_t offset = op.getStaticOffset(0);
-  // Only support static sizes and offsets.
-  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
-      offset == ShapedType::kDynamic) {
-    return rewriter.notifyMatchFailure(
-        op, "dynamic size or offset is not supported");
+  if (llvm::is_contained(op.getStaticSizes(), ShapedType::kDynamic)) {
+    return rewriter.notifyMatchFailure(op, "dynamic sizes are not supported");
   }
 
-  int elementsPerByte = dstBits / srcBits;
-  if (offset % elementsPerByte != 0) {
+  if (!memref::isStaticShapeAndContiguousRowMajor(op.getType())) {
     return rewriter.notifyMatchFailure(
-        op, "offset not multiple of elementsPerByte is not supported");
+        op, "result memref is not row-major contiguous");
   }
 
-  SmallVector<int64_t> size;
-  if (!sizes.empty())
-    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
-  offset = offset / elementsPerByte;
+  Location loc = op.getLoc();
+  SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+  OpFoldResult origOffset = op.getMixedOffsets()[0];
+
+  SmallVector<OpFoldResult> newSizes;
+  SmallVector<OpFoldResult> newStrides;
+  OpFoldResult newOffset;
+  OpFoldResult intraOffset;
+  if (mixedSizes.empty()) {
+    int64_t elementsPerByte = dstBits / srcBits;
+    AffineExpr s0;
+    bindSymbols(rewriter.getContext(), s0);
+    newOffset = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, s0.floorDiv(elementsPerByte), {origOffset});
+    intraOffset = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, s0 % elementsPerByte, {origOffset});
+  } else {
+    memref::LinearizedMemRefInfo info =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits, origOffset, mixedSizes);
+    newOffset = info.linearizedOffset;
+    intraOffset = info.intraDataOffset;
+    newSizes.push_back(info.linearizedSize);
+    newStrides.push_back(rewriter.getIndexAttr(1));
+  }
+
+  if (auto cst = getConstantIntValue(intraOffset); cst && *cst != 0) {
+    return rewriter.notifyMatchFailure(
+        op, "offset is provably not a multiple of dstBits / srcBits");
+  }
 
   rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
+      op, newTy, adaptor.getSource(), newOffset, newSizes, newStrides);
   return success();
 }
 
@@ -349,6 +374,32 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCast
+//===----------------------------------------------------------------------===//
+
+/// `memref.cast` between two narrow-typed memrefs forwards through the type
+/// converter to a cast between the converted byte-typed memrefs.
+struct ConvertMemRefCast final : OpConversionPattern<memref::CastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+    if (newTy == op.getType())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<memref::CastOp>(op, newTy, adaptor.getSource());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefMemorySpaceCast
 //===----------------------------------------------------------------------===//
@@ -377,8 +428,7 @@ struct ConvertMemRefMemorySpaceCast final
 // ConvertMemRefReinterpretCast
 //===----------------------------------------------------------------------===//
 
-/// Output types should be at most one dimensional, so only the 0 or 1
-/// dimensional cases are supported.
+/// Forwards to `convertCastingOp`, which enforces all preconditions.
 struct ConvertMemRefReinterpretCast final
     : OpConversionPattern<memref::ReinterpretCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -394,12 +444,6 @@ struct ConvertMemRefReinterpretCast final
           llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }
 
-    // Only support for 0 or 1 dimensional cases.
-    if (op.getType().getRank() > 1) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
-    }
-
     return convertCastingOp(rewriter, adaptor, op, newTy);
   }
 };
@@ -503,9 +547,11 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 //===----------------------------------------------------------------------===//
 
 /// Emulating narrow ints on subview have limited support, supporting only
-/// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern should
-/// only run for cases that can't be folded.
+/// static sizes and stride of 1. Dynamic offsets are accepted under the
+/// alignment contract that the caller guarantees the offset is a multiple of
+/// `dstBits / srcBits`. Ideally, the subview should be folded away before
+/// running narrow type emulation, and this pattern should only run for cases
+/// that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -543,12 +589,9 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
     }
 
     auto sizes = subViewOp.getStaticSizes();
-    int64_t lastOffset = subViewOp.getStaticOffsets().back();
-    // Only support static sizes and offsets.
-    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
-        lastOffset == ShapedType::kDynamic) {
-      return rewriter.notifyMatchFailure(
-          subViewOp->getLoc(), "dynamic size or offset is not supported");
+    if (llvm::is_contained(sizes, ShapedType::kDynamic)) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "dynamic size is not supported");
     }
 
     // Transform the offsets, sizes and strides according to the emulation.
@@ -566,6 +609,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
             getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                            rewriter));
 
+    if (auto cst = getConstantIntValue(linearizedInfo.intraDataOffset);
+        cst && *cst != 0) {
+      return rewriter.notifyMatchFailure(
+          subViewOp,
+          "subview offset is provably not a multiple of dstBits / srcBits");
+    }
+
     rewriter.replaceOpWithNewOp<memref::SubViewOp>(
         subViewOp, newTy, adaptor.getSource(), linearizedIndices,
         linearizedInfo.linearizedSize, strides.back());
@@ -634,12 +684,13 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
-               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
-               ConvertMemRefExpandShape, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
-               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
-      typeConverter, patterns.getContext());
+               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCast,
+               ConvertMemRefCopy, ConvertMemRefDealloc,
+               ConvertMemRefCollapseShape, ConvertMemRefExpandShape,
+               ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+               ConvertMemRefMemorySpaceCast, ConvertMemRefSubview,
+               ConvertMemRefReinterpretCast>(typeConverter,
+                                             patterns.getContext());
   patterns.insert<ConvertMemrefStore>(typeConverter, patterns.getContext(),
                                       disableAtomicRMW);
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index dd64ecc98721a..b4de76cd6d29c 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -238,6 +238,51 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
 
 // -----
 
+func.func @memref_subview_dynamic_inner_offset_i4(%off: index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<128xi4>
+  %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>>
+  %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>>
+  return %ld : i4
+}
+
+// CHECK-LABEL:   func.func @memref_subview_dynamic_inner_offset_i4(
+// CHECK-SAME:        %[[OFF:[a-zA-Z0-9_]+]]: index
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
+// CHECK:           %[[IDX:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK:           %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[IDX]]] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: ?>>
+// CHECK:           memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL:   func.func @memref_subview_dynamic_inner_offset_i4(
+// CHECK32-SAME:        %[[OFF:[a-zA-Z0-9_]+]]: index
+// CHECK32:           %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
+// CHECK32:           %[[IDX:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK32:           %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[IDX]]] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: ?>>
+// CHECK32:           memref.load %[[SUBVIEW]]
+
+// -----
+
+// Dynamic innermost offset that is provably aligned (multiple of
+// `dstBits / srcBits`). The affine simplifier folds the `floordiv` away.
+
+func.func @memref_subview_aligned_dynamic_inner_offset_i4(%x: index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %off = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%x]
+  %arr = memref.alloc() : memref<128xi4>
+  %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>>
+  %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>>
+  return %ld : i4
+}
+
+// CHECK-LABEL:   func.func @memref_subview_aligned_dynamic_inner_offset_i4(
+// CHECK-SAME:        %[[X:[a-zA-Z0-9_]+]]: index
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
+// CHECK-NOT:       affine.apply
+// CHECK:           %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[X]]] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: ?>>
+// CHECK:           memref.load %[[SUBVIEW]]
+
+// -----
+
 func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
   %c0 = arith.constant 0 : index
   %arr = memref.alloc() : memref<40x40xi4>
@@ -249,6 +294,61 @@ func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
 
 // -----
 
+// Rank-3 reinterpret_cast on a sub-byte (i4) memref with a static, aligned
+// offset.
+
+func.func @reinterpret_cast_memref_rank3_static_offset_i4(%arg0: memref<2x4x8xi4>) -> memref<4x4x8xi4, strided<[32, 8, 1]>> {
+  %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1]>>
+  return %r : memref<4x4x8xi4, strided<[32, 8, 1]>>
+}
+
+// CHECK-LABEL:   func @reinterpret_cast_memref_rank3_static_offset_i4(
+// CHECK-SAME:        %[[ARG0:.+]]: memref<32xi8>
+// CHECK:           %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [64], strides: [1] : memref<32xi8> to memref<64xi8>
+// CHECK:           return %[[R]]
+
+// CHECK32-LABEL:   func @reinterpret_cast_memref_rank3_static_offset_i4(
+// CHECK32-SAME:        %[[ARG0:.+]]: memref<8xi32>
+// CHECK32:           %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [16], strides: [1] : memref<8xi32> to memref<16xi32>
+// CHECK32:           return %[[R]]
+
+// -----
+
+// Rank-3 reinterpret_cast with a dynamic offset accepted under the alignment
+// contract.
+
+func.func @reinterpret_cast_memref_rank3_dynamic_offset_i4(%arg0: memref<2x4x8xi4>, %off: index) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> {
+  %r = memref.reinterpret_cast %arg0 to offset: [%off], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>>
+  return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>>
+}
+
+// CHECK-LABEL:   func @reinterpret_cast_memref_rank3_dynamic_offset_i4(
+// CHECK-SAME:        %[[ARG0:.+]]: memref<32xi8>,
+// CHECK-SAME:        %[[OFF:.+]]: index
+// CHECK:           %[[NEWOFF:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK:           %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[NEWOFF]]{{\]}}, sizes: [64], strides: [1] : memref<32xi8> to memref<64xi8, strided<[1], offset: ?>>
+// CHECK:           return %[[R]]
+
+// CHECK32-LABEL:   func @reinterpret_cast_memref_rank3_dynamic_offset_i4(
+// CHECK32-SAME:        %[[ARG0:.+]]: memref<8xi32>,
+// CHECK32-SAME:        %[[OFF:.+]]: index
+// CHECK32:           %[[NEWOFF:.+]] = affine.apply {{.*}}%[[OFF]]
+// CHECK32:           %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[NEWOFF]]{{\]}}, sizes: [16], strides: [1] : memref<8xi32> to memref<16xi32, strided<[1], offset: ?>>
+// CHECK32:           return %[[R]]
+
+// -----
+
+// Provably-misaligned static offset (1 is not a multiple of i4 -> i8 ratio
+// of 2). Lowering must fail.
+
+func.func @negative_reinterpret_cast_memref_misaligned_static_offset_i4(%arg0: memref<2x4x8xi4>) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>> {
+  // expected-error @+1 {{failed to legalize operation 'memref.reinterpret_cast' that was explicitly marked illegal}}
+  %r = memref.reinterpret_cast %arg0 to offset: [1], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>>
+  return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>>
+}
+
+// -----
+
 func.func @reinterpret_cast_memref_load_0D() -> i4 {
     %0 = memref.alloc() : memref<5xi4>
     %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>

``````````

</details>


https://github.com/llvm/llvm-project/pull/196945


More information about the Mlir-commits mailing list