[Mlir-commits] [mlir] [mlir][memref] Rewrite scalar `memref.copy` through reinterpret_cast into load/store (PR #186118)
ioana ghiban
llvmlistbot at llvm.org
Thu Mar 19 02:43:25 PDT 2026
https://github.com/ioghiban updated https://github.com/llvm/llvm-project/pull/186118
>From b2d848761c5f4d089f8db2868ca890f3e1764cc8 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Mon, 9 Mar 2026 20:11:27 +0100
Subject: [PATCH 1/2] Rewrite `memref.copy` of a 1-element MemRef as a scalar
load-store pair
Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and tests before submission.
---
.../mlir/Dialect/MemRef/Transforms/Passes.td | 11 +
.../Dialect/MemRef/Transforms/Transforms.h | 4 +
.../Dialect/MemRef/Transforms/CMakeLists.txt | 1 +
.../Transforms/ElideReinterpretCast.cpp | 233 +++++++++++++++++
.../MemRef/elide-reinterpret-cast.mlir | 237 ++++++++++++++++++
5 files changed, 486 insertions(+)
create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
create mode 100644 mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d04ae101fe1dc..08c22e6418fd3 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -11,6 +11,17 @@
include "mlir/Pass/PassBase.td"
+def ElideReinterpretCastPass : Pass<"memref-elide-reinterpret-cast"> {
+  let summary = "Rewrite ops that depend on redundant memref.reinterpret_cast "
+                "ops so that they are convertible to EmitC.";
+  let description = [{
+    Replaces data-movement ops that depend on redundant
+    `memref.reinterpret_cast` ops with equivalent ops that achieve the same
+    data movement without requiring the cast. This simplifies the subsequent
+    conversion to EmitC.
+
+    Currently, a `memref.copy` whose target is a single-element
+    `memref.reinterpret_cast` view is rewritten into a `memref.load` /
+    `memref.store` pair on the buffer the cast was taken from.
+  }];
+}
+
def ExpandOpsPass : Pass<"memref-expand"> {
let summary = "Legalize memref operations to be convertible to LLVM.";
}
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index b0069bce89588..654639ff4d06c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -39,6 +39,10 @@ class DeallocOp;
// Patterns
//===----------------------------------------------------------------------===//
+/// Collects a set of patterns that rewrite memref ops depending on redundant
+/// `memref.reinterpret_cast` ops into equivalent ops that do not need the
+/// cast (currently: a `memref.copy` into a single-element cast view becomes a
+/// `memref.load` / `memref.store` pair on the cast's source buffer).
+void populateElideReinterpretCastPatterns(RewritePatternSet &patterns);
+
/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 9049faccadef3..573c1cf8857f1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
BufferViewFlowOpInterfaceImpl.cpp
+ ElideReinterpretCast.cpp
ComposeSubView.cpp
ExpandOps.cpp
ExpandRealloc.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
new file mode 100644
index 0000000000000..6f41c2e5aab3b
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -0,0 +1,233 @@
+//===- ElideReinterpretCast.cpp - Elide memref.reinterpret_cast ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <cassert>
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if `rc` creates a single-element view (the result shape and
+/// all static sizes are 1) into a ranked, non 0-D memref with an identity
+/// layout that has at most one non-unit dimension, located either first or
+/// last (i.e. the base is a "column" or a "row"). For such a view the element
+/// accessed is determined solely by the cast offset, which equals the index
+/// along the non-unit dimension; the strides of the cast are irrelevant.
+///
+/// Examples that return true:
+///
+///   // Row (last dim is non-unit)
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1, 1, 1], strides: [1, 1, 1]
+///       : memref<1x1x8xi32>
+///       to memref<1x1x1xi32, strided<[1, 1, 1], offset: ?>>
+///
+///   // Column (first dim is non-unit)
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1, 1], strides: [1, 1]
+///       : memref<2x1xf32> to memref<1x1xf32, strided<[1, 1], offset: ?>>
+///
+///   // Arbitrary strides on the cast
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1, 1], strides: [10, 100]
+///       : memref<2x1xf32> to memref<1x1xf32, strided<[10, 100], offset: ?>>
+///
+///   // Rank-1 case
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1], strides: [1]
+///       : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>>
+///
+/// Examples that return false:
+///
+///   // More than one non-unit dim
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1, 1, 1], strides: [1, 1, 1]
+///       : memref<1x2x8xi32>
+///       to memref<1x1x1xi32, strided<[1, 1, 1], offset: ?>>
+///
+///   // Non-unit dim that is neither first nor last
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1, 1, 1], strides: [1, 1, 1]
+///       : memref<1x8x1xi32>
+///       to memref<1x1x1xi32, strided<[1, 1, 1], offset: ?>>
+///
+///   // View is not scalar (size != 1)
+///   memref.reinterpret_cast %buf to offset: [0],
+///       sizes: [2, 1], strides: [1, 1]
+///       : memref<1x2xf32> to memref<2x1xf32>
+///
+///   // Base has a non-identity layout
+///   memref.reinterpret_cast %buf to offset: [%off],
+///       sizes: [1, 1], strides: [1, 1]
+///       : memref<1x2xf32, strided<[1, 3]>>
+///       to memref<1x1xf32, strided<[1, 1], offset: ?>>
+///
+///   // Unranked source (0-D memrefs are rejected as well)
+///   memref.reinterpret_cast %buf to offset: [%off], sizes: [1], strides: [1]
+///       : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
+static bool isScalarSlice(memref::ReinterpretCastOp rc) {
+  // The source of a reinterpret_cast may be unranked; only a ranked source
+  // has a shape and layout we can reason about.
+  auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
+  if (!rcInputTy)
+    return false;
+  auto rcOutputTy = cast<MemRefType>(rc.getType());
+
+  // Reject strided base - logic for computing linear idx is TODO.
+  if (!rcInputTy.getLayout().isIdentity())
+    return false;
+
+  // Reject non-matching ranks, and 0-D memrefs (there is no dimension that
+  // could carry the offset in the rewritten store).
+  int64_t rank = rcInputTy.getRank();
+  if (rank == 0 || rank != rcOutputTy.getRank())
+    return false;
+
+  // View must be scalar: memref<1x...x1>.
+  if (!llvm::all_of(rcOutputTy.getShape(),
+                    [](int64_t dim) { return dim == 1; }))
+    return false;
+
+  // Sizes must all be statically 1.
+  if (!llvm::all_of(rc.getStaticSizes(), [](int64_t size) {
+        return !ShapedType::isDynamic(size) && size == 1;
+      }))
+    return false;
+
+  // The base may have at most one non-unit (possibly dynamic) dimension, and
+  // it must be the first or the last one.
+  int64_t nonUnitDim = -1;
+  for (int64_t i = 0; i < rank; ++i) {
+    if (rcInputTy.getDimSize(i) == 1)
+      continue;
+    if (nonUnitDim != -1)
+      return false;
+    nonUnitDim = i;
+  }
+  return nonUnitDim == -1 || nonUnitDim == 0 || nonUnitDim == rank - 1;
+}
+
+/// Rewrites a `memref.copy` into a single-element `memref.reinterpret_cast`
+/// view as a scalar load/store pair on the buffer the cast was taken from.
+///
+/// The pattern matches a copy whose target is produced by a reinterpret_cast
+/// accepted by `isScalarSlice`: a view with `sizes = [1, ..., 1]` into an
+/// identity-layout memref with at most one non-unit dimension. Since the view
+/// holds exactly one element, the accessed address is determined solely by
+/// the base buffer and the cast offset; with an identity layout and a single
+/// non-unit dimension, that offset is the index along the non-unit dimension.
+/// The strides of the cast do not matter.
+///
+/// BEFORE (row: last dim is non-unit)
+///   %view = memref.reinterpret_cast %base
+///       to offset: [%off], sizes: [1, ..., 1], strides: [...]
+///       : memref<1x...xNxf32>
+///       to memref<1x...x1xf32, strided<[...], offset: ?>>
+///   memref.copy %src, %view
+///       : memref<1x...x1xf32>
+///       to memref<1x...x1xf32, strided<[...], offset: ?>>
+///
+/// AFTER
+///   %c0 = arith.constant 0 : index
+///   %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
+///   memref.store %v, %base[%c0, ..., %off] : memref<1x...xNxf32>
+///
+/// For a column (first dim is non-unit) the offset goes into the first store
+/// index instead: `memref.store %v, %base[%off, ..., %c0]`. In general it goes
+/// into the index of the base's non-unit dimension, or into the last index
+/// when every dimension of the base is 1.
+///
+/// The reinterpret_cast is erased as well when the copy was its only user.
+struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CopyOp op,
+                                PatternRewriter &rewriter) const final {
+    Value rcOutput = op.getTarget();
+    auto rc = rcOutput.getDefiningOp<memref::ReinterpretCastOp>();
+    if (!rc)
+      return rewriter.notifyMatchFailure(
+          op, "target is not a memref.reinterpret_cast");
+
+    if (!isScalarSlice(rc))
+      return rewriter.notifyMatchFailure(
+          op, "reinterpret_cast does not match scalar slice");
+
+    Value src = op.getSource();
+    Value dst = rc.getSource();
+
+    // Both the copy source and the cast's base must be ranked so that they
+    // can be indexed by memref.load / memref.store (memref.copy and
+    // memref.reinterpret_cast both accept unranked operands).
+    auto srcType = dyn_cast<MemRefType>(src.getType());
+    auto dstType = dyn_cast<MemRefType>(dst.getType());
+    if (!srcType || !dstType)
+      return rewriter.notifyMatchFailure(op, "expected ranked memrefs");
+
+    int64_t dstRank = dstType.getRank();
+    if (dstRank == 0)
+      return rewriter.notifyMatchFailure(op, "expected a non 0-D base memref");
+
+    // With an identity layout and at most one non-unit dimension, the linear
+    // offset of the cast is the index along that dimension. When every
+    // dimension is unit, the offset is placed in the last index.
+    int64_t offsetDim = dstRank - 1;
+    for (int64_t i = 0; i < dstRank; ++i) {
+      if (dstType.getDimSize(i) != 1) {
+        offsetDim = i;
+        break;
+      }
+    }
+
+    // Decide before erasing anything whether the copy is the cast's only user.
+    bool eraseCast = rcOutput.hasOneUse();
+
+    Location loc = op.getLoc();
+    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    SmallVector<Value> loadIndices(srcType.getRank(), zero);
+
+    auto offsets = rc.getMixedOffsets();
+    assert(offsets.size() == 1 && "Expecting single offset");
+    SmallVector<Value> storeIndices(dstRank, zero);
+    storeIndices[offsetDim] =
+        getValueOrCreateConstantIndexOp(rewriter, loc, offsets.front());
+
+    Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
+    memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);
+
+    // Erase the copy before the cast so that the cast has no remaining users
+    // when it is erased (RewriterBase::eraseOp expects an unused op, e.g. when
+    // these patterns are run with the greedy driver).
+    rewriter.eraseOp(op);
+    if (eraseCast)
+      rewriter.eraseOp(rc);
+    return success();
+  }
+};
+
+/// Pass driver: runs the elide-reinterpret-cast patterns as a partial
+/// conversion in which a `memref.copy` is illegal exactly when its target is
+/// a reinterpret_cast accepted by `isScalarSlice`.
+struct ElideReinterpretCastPass
+    : public memref::impl::ElideReinterpretCastPassBase<
+          ElideReinterpretCastPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+
+    ConversionTarget target(*context);
+    target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp copy) {
+      auto castOp = copy.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
+      // Copies not fed by a scalar-slice cast are left untouched.
+      return !castOp || !isScalarSlice(castOp);
+    });
+    // The rewrite produces arith.constant and memref.load/store ops.
+    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
+
+    RewritePatternSet patterns(context);
+    memref::populateElideReinterpretCastPatterns(patterns);
+
+    LogicalResult result =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::memref::populateElideReinterpretCastPatterns(
+    RewritePatternSet &patterns) {
+  // Currently a single rewrite: scalar memref.copy -> memref.load/store.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<CopyToScalarLoadAndStore>(context);
+}
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
new file mode 100644
index 0000000000000..39f5270cacef0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// Static zero offset into a "row" base (only the last dim is non-unit): the
+// copy becomes a load from %src plus a store into %dst at [0, 0].
+// CHECK-LABEL: func.func private @concat_zero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_zero_offset(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ /// reinterpret_cast removed
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32> to memref<1x1xf32>
+
+ /// Ensure copy was replaced
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32>
+ return
+}
+
+// Static non-zero offset: the offset constant becomes the store index along
+// the last (non-unit) dimension of the base.
+// CHECK-LABEL: func.func private @concat_nonzero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_nonzero_offset(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [1], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: 1>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: 1>>
+ return
+}
+
+// Dynamic offset: the SSA offset value is used directly as the store index
+// along the non-unit dimension (no extra constant is created for it).
+// CHECK-LABEL: func.func private @concat_dynamic_offset(
+// CHECK-SAME: %[[OFF:.*]]: index
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [%offset], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: ?>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]]
+ // CHECK-SAME: : memref<1x1xf32>
+ /// Dynamic offset used in store
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: ?>>
+ return
+}
+
+// Static non-unit strides on the cast: they are irrelevant for a
+// single-element view and disappear with the rewrite.
+// CHECK-LABEL: func.func private @concat_strided(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_strided(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [107, 2]
+ : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>>
+ return
+}
+
+// Dynamic strides on the cast are likewise irrelevant; only the (static,
+// zero) offset reaches the store.
+// CHECK-LABEL: func.func private @concat_dynamic_stride(
+// CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32>
+func.func private @concat_dynamic_stride(%stride0: index,
+ %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[?, ?]>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ /// Static zero offset used in store; the dynamic strides are dropped
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[?, ?]>>
+ return
+}
+
+// Rank-1 base: the offset is the only store index.
+// CHECK-LABEL: func.func private @concat_rank1(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<108xf32>
+func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1], strides: [1]
+ : memref<108xf32> to memref<1xf32>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1xf32> to memref<1xf32>
+ return
+}
+
+// Rank-3 base whose only non-unit dimension is the last one: the offset goes
+// into the last store index, the leading indices are zero.
+// CHECK-LABEL: func.func private @concat_rank3(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x1x108xf32>
+func.func private @concat_rank3(%src : memref<1x1x1xf32>,
+ %dst : memref<1x1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
+ : memref<1x1x108xf32> to memref<1x1x1xf32>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1x1xf32> to memref<1x1x1xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// Base with two non-unit dimensions (2 and 108): not a scalar slice, so the
+// cast and the copy are kept.
+// CHECK-LABEL: func.func private @concat_rank3_more_non_unit_dims(
+func.func private @concat_rank3_more_non_unit_dims(%src : memref<1x1x1xf32>,
+ %dst : memref<1x2x108xf32>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1, 1], strides: [0, 0, 0]
+ : memref<1x2x108xf32> to memref<1x1x1xf32, strided<[0, 0, 0]>>
+
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1x1xf32> to memref<1x1x1xf32, strided<[0, 0, 0]>>
+ return
+}
+
+// Base with a non-identity (strided) layout: computing the linear index is
+// not supported, so the cast and the copy are kept.
+// CHECK-LABEL: func.func private @concat_strided_base(
+func.func private @concat_strided_base(%src: memref<1x1xf32>,
+ %dst: memref<8x1xf32, strided<[10, 2]>>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [6], sizes: [1, 1], strides: [11, 80]
+ : memref<8x1xf32, strided<[10, 2]>>
+ to memref<1x1xf32, strided<[11, 80], offset: 6>>
+
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32, strided<[11, 80], offset: 6>>
+
+ return
+}
+
+// Cast that changes the rank (1-D to 2-D) and produces a multi-element view:
+// not a scalar slice, so the cast and the copy are kept.
+// CHECK-LABEL: func.func private @reshape_rank_change(
+func.func private @reshape_rank_change(%src : memref<2x3xf32>,
+ %dst : memref<6xf32>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [2, 3], strides: [3, 1]
+ : memref<6xf32> to memref<2x3xf32>
+
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<2x3xf32> to memref<2x3xf32>
+ return
+}
+
+// Rank-2 base where both dimensions are non-unit (2x108): not a scalar slice,
+// so the cast and the copy are kept.
+// CHECK-LABEL: func.func private @concat_multi_dim_vector(
+func.func private @concat_multi_dim_vector(%src : memref<1x1xf32>,
+ %dst : memref<2x108xf32>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [0, 0]
+ : memref<2x108xf32>
+ to memref<1x1xf32, strided<[0, 0]>>
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32, strided<[0, 0]>>
+ return
+}
+
+// Copy whose target is a function argument rather than a reinterpret_cast
+// result: the pattern does not apply and the copy is left untouched.
+// CHECK-LABEL: func.func private @plain_copy(
+func.func private @plain_copy(%src : memref<1x1xf32>, %dst : memref<1x1xf32>) {
+ // CHECK: memref.copy %arg0, %arg1
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %dst
+ : memref<1x1xf32> to memref<1x1xf32>
+ return
+}
>From 6203cc7ba6a179c5945767aa20f7b559e82fbd4d Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Thu, 19 Mar 2026 10:42:53 +0100
Subject: [PATCH 2/2] fixup
---
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 573c1cf8857f1..1c5e07f89b338 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,8 +1,8 @@
add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
BufferViewFlowOpInterfaceImpl.cpp
- ElideReinterpretCast.cpp
ComposeSubView.cpp
+ ElideReinterpretCast.cpp
ExpandOps.cpp
ExpandRealloc.cpp
ExpandStridedMetadata.cpp
More information about the Mlir-commits
mailing list