[Mlir-commits] [mlir] [mlir][memref] Rewrite scalar `memref.copy` through reinterpret_cast into load/store (PR #186118)

ioana ghiban llvmlistbot at llvm.org
Thu Mar 19 02:43:25 PDT 2026


https://github.com/ioghiban updated https://github.com/llvm/llvm-project/pull/186118

>From b2d848761c5f4d089f8db2868ca890f3e1764cc8 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Mon, 9 Mar 2026 20:11:27 +0100
Subject: [PATCH 1/2] Rewrite `memref.copy` of a 1-element MemRef as a scalar
 load-store pair

Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and tests before submission.
---
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |  11 +
 .../Dialect/MemRef/Transforms/Transforms.h    |   4 +
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/ElideReinterpretCast.cpp       | 233 +++++++++++++++++
 .../MemRef/elide-reinterpret-cast.mlir        | 237 ++++++++++++++++++
 5 files changed, 486 insertions(+)
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
 create mode 100644 mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d04ae101fe1dc..08c22e6418fd3 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ElideReinterpretCastPass : Pass<"memref-elide-reinterpret-cast"> {
+  let summary = "Rewrite ops using redundant memref.reinterpret_casts so they "
+                "can be converted to EmitC.";
+  let description = [{
+    Rewrites a `memref.copy` whose target is a single-element
+    `memref.reinterpret_cast` view of an identity-layout memref with at most
+    one non-unit dimension into a scalar `memref.load`/`memref.store` pair
+    that no longer needs the cast. This simplifies the conversion to EmitC.
+  }];
+}
+
 def ExpandOpsPass : Pass<"memref-expand"> {
   let summary = "Legalize memref operations to be convertible to LLVM.";
 }
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index b0069bce89588..654639ff4d06c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -39,6 +39,10 @@ class DeallocOp;
 // Patterns
 //===----------------------------------------------------------------------===//
 
+/// Collects patterns that rewrite memref ops so they no longer go through a
+/// redundant memref.reinterpret_cast (e.g. copies into single-element views).
+void populateElideReinterpretCastPatterns(RewritePatternSet &patterns);
+
 /// Collects a set of patterns to rewrite ops within the memref dialect.
 void populateExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 9049faccadef3..573c1cf8857f1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   AllocationOpInterfaceImpl.cpp
   BufferViewFlowOpInterfaceImpl.cpp
+  ElideReinterpretCast.cpp
   ComposeSubView.cpp
   ExpandOps.cpp
   ExpandRealloc.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
new file mode 100644
index 0000000000000..6f41c2e5aab3b
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -0,0 +1,233 @@
+//===- ElideReinterpretCast.cpp - Elide redundant reinterpret_casts -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <cassert>
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if `rc` creates a single-element view (all sizes are a static
+/// 1) of a ranked memref of the same, non-zero rank that has an identity
+/// layout and at most one non-unit dimension (at any position). The strides
+/// of the view are irrelevant: with a single element, the view only ever
+/// addresses linear index `offset` of the source buffer.
+///
+/// Examples that return true:
+///
+///   // Last dim is non-unit
+///   memref.reinterpret_cast %buff to offset: [%off],
+///     sizes: [1, 1, 1], strides: [1, 1, 1]
+///     : memref<1x1x8xi32>
+///       to memref<1x1x1xi32, strided<[1, 1, 1], offset: ?>>
+///
+///   // First dim is non-unit, arbitrary strides
+///   memref.reinterpret_cast %buff to offset: [%off],
+///     sizes: [1, 1], strides: [10, 100]
+///     : memref<2x1xf32> to memref<1x1xf32, strided<[10, 100], offset: ?>>
+///
+///   // Rank-1 case
+///   memref.reinterpret_cast %buf to offset: [%off],
+///     sizes: [1], strides: [1]
+///     : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>>
+///
+/// Examples that return false:
+///
+///   // More than one non-unit dim
+///   memref.reinterpret_cast %buff to offset: [0],
+///     sizes: [1, 1, 1], strides: [1, 1, 1]
+///     : memref<1x2x8xi32> to memref<1x1x1xi32>
+///
+///   // View is not a single element (size != 1)
+///   memref.reinterpret_cast %buff to offset: [0],
+///     sizes: [2, 1], strides: [1, 1]
+///     : memref<1x2xf32> to memref<2x1xf32>
+///
+///   // Source has a non-identity layout
+///   memref.reinterpret_cast %buff to offset: [0],
+///     sizes: [1, 1], strides: [1, 1]
+///     : memref<1x2xf32, strided<[1, 3]>> to memref<1x1xf32>
+///
+///   // Unranked source
+///   memref.reinterpret_cast %buff to offset: [0], sizes: [1], strides: [1]
+///     : memref<*xf32> to memref<1xf32>
+static bool isScalarSlice(memref::ReinterpretCastOp rc) {
+  // The source of a reinterpret_cast may be unranked: there is then no shape
+  // or layout to reason about, so bail out instead of dereferencing a null
+  // type.
+  auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
+  if (!rcInputTy)
+    return false;
+  auto rcOutputTy = cast<MemRefType>(rc.getType());
+
+  // Reject strided base - logic for computing the linear index is TODO.
+  if (!rcInputTy.getLayout().isIdentity())
+    return false;
+
+  // Reject non-matching ranks.
+  int64_t srcRank = rcInputTy.getRank();
+  if (srcRank != rcOutputTy.getRank())
+    return false;
+
+  // Reject 0-D memrefs: the rewrite places the offset into one index of the
+  // underlying memref, so at least one dimension is required.
+  if (srcRank == 0)
+    return false;
+
+  // View must be a single element: memref<1x...x1>.
+  if (!llvm::all_of(rcOutputTy.getShape(),
+                    [](int64_t dim) { return dim == 1; }))
+    return false;
+
+  // Sizes must all be statically 1.
+  if (!llvm::all_of(rc.getStaticSizes(), [](int64_t size) {
+        return !ShapedType::isDynamic(size) && size == 1;
+      }))
+    return false;
+
+  // The underlying memref must have at most one non-unit dimension, so that
+  // the linear offset maps onto a single index. The dimension may be at any
+  // position; dynamic dimensions count as non-unit.
+  bool seenNonUnitDim = false;
+  for (int64_t dim : rcInputTy.getShape()) {
+    if (dim == 1)
+      continue;
+    if (seenNonUnitDim)
+      return false;
+    seenNonUnitDim = true;
+  }
+  return true;
+}
+
+/// Rewrites a `memref.copy` into a single-element view as a scalar load-store
+/// pair.
+///
+/// The pattern matches a copy whose target is a reinterpret_cast accepted by
+/// `isScalarSlice`, i.e. a single-element view (`sizes = [1, ..., 1]`) into a
+/// ranked, identity-layout memref with at most one non-unit dimension. Since
+/// the view contains only one element, the accessed address is determined
+/// solely by the base pointer and the offset; the view's strides do not
+/// matter. The offset becomes the index of the non-unit dimension (of the last
+/// dimension if all are unit); every other index is 0. The cast is erased
+/// only when the copy was its sole user.
+///
+/// BEFORE
+///   %view = memref.reinterpret_cast %base
+///     to offset: [%off], sizes: [1, ..., 1], strides: [N, ..., 1]
+///       : memref<1x...xNxf32>
+///         to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
+///   memref.copy %src, %view
+///     : memref<1x...x1xf32>
+///       to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
+///
+/// AFTER
+///   %c0 = arith.constant 0 : index
+///   %v  = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
+///   memref.store %v, %base[%c0, ..., %off] : memref<1x...xNxf32>
+///
+/// With a non-unit first dimension (memref<Nx...x1xf32>) the store becomes
+///   memref.store %v, %base[%off, ..., %c0] : memref<Nx...x1xf32>
+struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CopyOp op,
+                                PatternRewriter &rewriter) const final {
+    auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
+    if (!rc)
+      return rewriter.notifyMatchFailure(
+          op, "target is not a memref.reinterpret_cast");
+
+    // Both the copy source and the cast source may be unranked, and a 0-D
+    // base has no dimension to carry the offset: reject these up front.
+    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
+    auto dstType = dyn_cast<MemRefType>(rc.getSource().getType());
+    if (!srcType || !dstType || dstType.getRank() == 0)
+      return rewriter.notifyMatchFailure(
+          op, "expected ranked memrefs and a non-0-D cast source");
+
+    if (!isScalarSlice(rc))
+      return rewriter.notifyMatchFailure(
+          op, "reinterpret_cast does not match scalar slice");
+
+    // The element sits at linear index `offset` of the identity-layout base,
+    // i.e. along its only non-unit dimension (the last one if all are unit).
+    unsigned dstRank = dstType.getRank();
+    unsigned offsetDim = dstRank - 1;
+    for (unsigned i = 0; i < dstRank; ++i) {
+      if (dstType.getDimSize(i) != 1) {
+        offsetDim = i;
+        break;
+      }
+    }
+
+    Location loc = op.getLoc();
+    Value src = op.getSource();
+    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    SmallVector<Value> loadIndices(srcType.getRank(), zero);
+    auto offsets = rc.getMixedOffsets();
+    assert(offsets.size() == 1 && "Expecting single offset");
+    Value storeOffset =
+        getValueOrCreateConstantIndexOp(rewriter, loc, offsets[0]);
+    SmallVector<Value> storeIndices(dstRank, zero);
+    storeIndices[offsetDim] = storeOffset;
+    Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
+    memref::StoreOp::create(rewriter, loc, val, rc.getSource(), storeIndices);
+
+    // Erase the copy before the cast: the generic eraseOp requires the erased
+    // op to have no uses. Record the use count first, since erasing the copy
+    // drops one use of the cast.
+    bool castHasSingleUse = op.getTarget().hasOneUse();
+    rewriter.eraseOp(op);
+    if (castHasSingleUse)
+      rewriter.eraseOp(rc);
+    return success();
+  }
+};
+
+/// Partial conversion: a copy is illegal only if the rewrite pattern applies.
+struct ElideReinterpretCastPass
+    : public memref::impl::ElideReinterpretCastPassBase<
+          ElideReinterpretCastPass> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+    RewritePatternSet patterns(&ctx);
+    memref::populateElideReinterpretCastPatterns(patterns);
+    ConversionTarget target(ctx);
+    target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
+      // Unranked copy sources cannot be indexed by a load: keep them legal.
+      auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
+      return !rc || !isa<MemRefType>(op.getSource().getType()) ||
+             !isScalarSlice(rc);
+    });
+    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::memref::populateElideReinterpretCastPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CopyToScalarLoadAndStore>(patterns.getContext(), /*benefit=*/1);
+}
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
new file mode 100644
index 0000000000000..39f5270cacef0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @concat_zero_offset(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_zero_offset(%src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  /// The single-use reinterpret_cast is removed
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [1, 1]
+    : memref<1x108xf32> to memref<1x1xf32>
+
+  /// The copy is replaced by a scalar load/store pair
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32> to memref<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_nonzero_offset(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_nonzero_offset(%src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [1], sizes: [1, 1], strides: [1, 1]
+    : memref<1x108xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: 1>>
+
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C1:.*]] = arith.constant 1 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: 1>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_dynamic_offset(
+// CHECK-SAME:   %[[OFF:.*]]: index
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [%offset], sizes: [1, 1], strides: [1, 1]
+    : memref<1x108xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: ?>>
+
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: : memref<1x1xf32>
+  /// Dynamic offset used in store
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32>
+      to memref<1x1xf32, strided<[1, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_strided(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_strided(%src : memref<1x1xf32>,
+  %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [107, 2]
+    : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>>
+
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_dynamic_stride(
+// CHECK-SAME:   %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME:   %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32>
+// CHECK-SAME:   %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32>
+func.func private @concat_dynamic_stride(%stride0: index,
+  %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+    : memref<1x108xf32>
+      to memref<1x1xf32, strided<[?, ?]>>
+
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+  /// Static zero offset used in store; the dynamic strides are irrelevant
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32>
+      to memref<1x1xf32, strided<[?, ?]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_rank1(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<108xf32>
+func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1], strides: [1]
+    : memref<108xf32> to memref<1xf32>
+
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1xf32> to memref<1xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_rank3(
+// CHECK-SAME:   %[[SRC:.*]]: memref<1x1x1xf32>
+// CHECK-SAME:   %[[DST:.*]]: memref<1x1x108xf32>
+func.func private @concat_rank3(%src : memref<1x1x1xf32>,
+  %dst : memref<1x1x108xf32>) {
+  // CHECK-NOT:  memref.reinterpret_cast
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
+    : memref<1x1x108xf32> to memref<1x1x1xf32>
+
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32>
+  // CHECK:      memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32>
+  // CHECK-NOT:  memref.copy
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1x1xf32> to memref<1x1x1xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @concat_rank3_more_non_unit_dims(
+func.func private @concat_rank3_more_non_unit_dims(%src : memref<1x1x1xf32>,
+  %dst : memref<1x2x108xf32>) {
+  // CHECK:      %reinterpret_cast = memref.reinterpret_cast %arg1
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1, 1], strides: [0, 0, 0]
+    : memref<1x2x108xf32> to memref<1x1x1xf32, strided<[0, 0, 0]>>
+
+  // CHECK-NOT:  memref.load
+  // CHECK-NOT:  memref.store
+  // CHECK:      memref.copy %arg0, %reinterpret_cast
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1x1xf32> to memref<1x1x1xf32, strided<[0, 0, 0]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_strided_base(
+func.func private @concat_strided_base(%src: memref<1x1xf32>,
+  %dst: memref<8x1xf32, strided<[10, 2]>>) {
+  // CHECK:      %reinterpret_cast = memref.reinterpret_cast %arg1
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [6], sizes: [1, 1], strides: [11, 80]
+    : memref<8x1xf32, strided<[10, 2]>>
+      to memref<1x1xf32, strided<[11, 80], offset: 6>>
+
+  // CHECK-NOT:  memref.load
+  // CHECK-NOT:  memref.store
+  // CHECK:      memref.copy %arg0, %reinterpret_cast
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32> to memref<1x1xf32, strided<[11, 80], offset: 6>>
+
+  return
+}
+
+// CHECK-LABEL: func.func private @reshape_rank_change(
+func.func private @reshape_rank_change(%src : memref<2x3xf32>,
+  %dst : memref<6xf32>) {
+  // CHECK:      %reinterpret_cast = memref.reinterpret_cast %arg1
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [2, 3], strides: [3, 1]
+    : memref<6xf32> to memref<2x3xf32>
+
+  // CHECK-NOT:  memref.load
+  // CHECK-NOT:  memref.store
+  // CHECK:      memref.copy %arg0, %reinterpret_cast
+  memref.copy %src, %reinterpret_cast
+    : memref<2x3xf32> to memref<2x3xf32>
+  return
+}
+
+// CHECK-LABEL: func.func private @concat_multi_dim_vector(
+func.func private @concat_multi_dim_vector(%src : memref<1x1xf32>,
+  %dst : memref<2x108xf32>) {
+  // CHECK:      %reinterpret_cast = memref.reinterpret_cast %arg1
+  %reinterpret_cast = memref.reinterpret_cast %dst
+    to offset: [0], sizes: [1, 1], strides: [0, 0]
+    : memref<2x108xf32>
+      to memref<1x1xf32, strided<[0, 0]>>
+  // CHECK-NOT:  memref.load
+  // CHECK-NOT:  memref.store
+  // CHECK:      memref.copy %arg0, %reinterpret_cast
+  memref.copy %src, %reinterpret_cast
+    : memref<1x1xf32> to memref<1x1xf32, strided<[0, 0]>>
+  return
+}
+
+// CHECK-LABEL: func.func private @plain_copy(
+func.func private @plain_copy(%src : memref<1x1xf32>, %dst : memref<1x1xf32>) {
+  // CHECK-NOT:  memref.load
+  // CHECK-NOT:  memref.store
+  // CHECK:      memref.copy %arg0, %arg1
+  memref.copy %src, %dst
+    : memref<1x1xf32> to memref<1x1xf32>
+  return
+}

>From 6203cc7ba6a179c5945767aa20f7b559e82fbd4d Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Thu, 19 Mar 2026 10:42:53 +0100
Subject: [PATCH 2/2] fixup: keep CMake source list sorted

---
 mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 573c1cf8857f1..1c5e07f89b338 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,8 +1,8 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   AllocationOpInterfaceImpl.cpp
   BufferViewFlowOpInterfaceImpl.cpp
-  ElideReinterpretCast.cpp
   ComposeSubView.cpp
+  ElideReinterpretCast.cpp
   ExpandOps.cpp
   ExpandRealloc.cpp
   ExpandStridedMetadata.cpp



More information about the Mlir-commits mailing list