[Mlir-commits] [mlir] c5c0b83 - [mlir][memref] Rewrite scalar `memref.copy` through reinterpret_cast into load/store (#186118)
llvmlistbot at llvm.org
Thu Mar 19 08:35:05 PDT 2026
Author: ioana ghiban
Date: 2026-03-19T15:34:59Z
New Revision: c5c0b8348e6cbb966675c7f4c5e55922da5f6a88
URL: https://github.com/llvm/llvm-project/commit/c5c0b8348e6cbb966675c7f4c5e55922da5f6a88
DIFF: https://github.com/llvm/llvm-project/commit/c5c0b8348e6cbb966675c7f4c5e55922da5f6a88.diff
LOG: [mlir][memref] Rewrite scalar `memref.copy` through reinterpret_cast into load/store (#186118)
This change adds a rewrite that simplifies `memref.copy` operations whose
destination is a scalar view produced by `memref.reinterpret_cast`.
The pattern matches cases where a reinterpret cast creates a scalar view
(`sizes = [1, ..., 1]`) into a memref that has a single non-unit dimension. In
this situation the view refers to exactly one element in the base buffer, so
the accessed address depends only on the base pointer and the offset.
The stride information of the view does not affect the accessed element,
because the only valid index into the view is `[0, ..., 0]`.
Therefore the copy can be rewritten into a direct load from the source and a
store into the base memref using the offset from the reinterpret cast.
This makes the `memref.reinterpret_cast` redundant for the copy and simplifies
the IR.
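The stride argument above can be checked with a small standalone model. This is only a sketch in plain C++, not MLIR API code: `linearPosition` is a hypothetical helper that assumes the usual strided-layout formula `offset + sum(index[i] * stride[i])`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): linear element position addressed
// through a strided view, i.e. offset + sum_i(indices[i] * strides[i]).
std::int64_t linearPosition(std::int64_t offset,
                            const std::vector<std::int64_t> &strides,
                            const std::vector<std::int64_t> &indices) {
  assert(strides.size() == indices.size() && "rank mismatch");
  std::int64_t pos = offset;
  for (std::size_t i = 0; i < strides.size(); ++i)
    pos += indices[i] * strides[i];
  return pos;
}
```

With the only valid index `[0, 0]`, `linearPosition(5, {108, 1}, {0, 0})` and `linearPosition(5, {107, 2}, {0, 0})` both evaluate to the offset, 5, which is why the rewrite can ignore the strides of the view.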
Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and
tests before submission.
### Example
Before:
```mlir
func.func private @concat() {
%src = memref.alloc() : memref<1x1xf32>
%base = memref.alloc() : memref<1x108xf32>
%view = memref.reinterpret_cast %base
to offset: [0], sizes: [1, 1], strides: [108, 1]
: memref<1x108xf32>
to memref<1x1xf32, strided<[108, 1]>>
memref.copy %src, %view
: memref<1x1xf32>
to memref<1x1xf32, strided<[108, 1]>>
}
```
After:
```mlir
func.func private @concat() {
%src = memref.alloc() : memref<1x1xf32>
%base = memref.alloc() : memref<1x108xf32>
%c0 = arith.constant 0 : index
%v = memref.load %src[%c0, %c0] : memref<1x1xf32>
memref.store %v, %base[%c0, %c0] : memref<1x108xf32>
}
```
### Motivation
This rewrite simplifies IR and helps eliminate `memref.reinterpret_cast`
operations in preparation for later lowerings (e.g. EmitC lowering), where
pointer-based access patterns are easier to handle once scalar accesses are
explicit.
### Scope
This rewrite is intentionally narrow:
- It only applies when both source and destination reduce to scalar accesses.
- It does not attempt to rewrite general `memref.copy` operations.
- It does not introduce loops or handle multi-element copies.
The pass currently performs only this transformation, so it is expected to be
used intentionally rather than as part of a broad optimization pipeline.
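The shape conditions behind this narrow scope can be approximated on plain shape vectors. The sketch below is a simplified stand-in, not the code from the patch: `looksLikeScalarSlice` is a hypothetical name, it assumes fully static shapes, and it leaves out the identity-layout check that the pattern also applies to the base memref.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical model of the precondition: the view has the same rank as the
// base, every view size is 1, and the base has exactly one non-unit dimension.
// Assumption: all shapes are static; layout checks are omitted here.
bool looksLikeScalarSlice(const std::vector<std::int64_t> &baseShape,
                          const std::vector<std::int64_t> &viewSizes) {
  if (baseShape.size() != viewSizes.size())
    return false;
  bool allOnes = std::all_of(viewSizes.begin(), viewSizes.end(),
                             [](std::int64_t size) { return size == 1; });
  if (!allOnes)
    return false;
  auto nonUnitDims =
      std::count_if(baseShape.begin(), baseShape.end(),
                    [](std::int64_t dim) { return dim != 1; });
  return nonUnitDims == 1;
}
```

Under these assumptions, a `1x108` base with sizes `[1, 1]` qualifies, while a `2x108` base or a rank-changing cast from `6` to `2x3` does not.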
### Why not use `memref.copy` directly?
`memref.copy` requires source and destination memrefs to have the same shape.
The destination of the copy here is a scalar view derived from a larger memref,
so copying directly into the base memref would violate this requirement.
Instead, the rewrite loads the scalar value from the source and stores it into
the base memref, at the index determined by the reinterpret cast offset.
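How the offset becomes a store index can be illustrated with a small model. This is a sketch over plain integers rather than MLIR values, and `storeIndicesFor` is a hypothetical name; it assumes the base has a single non-unit dimension in either the first or the last position, and places the offset in that dimension with zeros elsewhere.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of the index selection: if the leading dimension of the
// base is 1, the offset indexes the last dimension; otherwise it indexes the
// first. All other indices stay 0.
std::vector<std::int64_t>
storeIndicesFor(const std::vector<std::int64_t> &baseShape,
                std::int64_t offset) {
  std::vector<std::int64_t> indices(baseShape.size(), 0);
  if (baseShape.empty())
    return indices;
  std::size_t offsetDim = baseShape.front() == 1 ? baseShape.size() - 1 : 0;
  indices[offsetDim] = offset;
  return indices;
}
```

For a `1x108` base and offset 1 this yields `[0, 1]`; for an `8x1` base and offset 6 it yields `[6, 0]`.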
Added:
mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d04ae101fe1dc..3fb0588df395a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -11,6 +11,16 @@
include "mlir/Pass/PassBase.td"
+def ElideReinterpretCastPass : Pass<"memref-elide-reinterpret-cast"> {
+ let summary = "Replace ops depending on redundant reinterpret_cast(s) to be "
+ "convertible to EmitC.";
+ let description = [{
+ Replace data-movement ops that depend on redundant memref.reinterpret_cast
+ operations to obtain compatible shapes with equivalent ops that operate on
+ compatible shapes directly. This simplifies conversion to EmitC.
+}];
+}
+
def ExpandOpsPass : Pass<"memref-expand"> {
let summary = "Legalize memref operations to be convertible to LLVM.";
}
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index b0069bce89588..62745f8fa1dfa 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -39,6 +39,10 @@ class DeallocOp;
// Patterns
//===----------------------------------------------------------------------===//
+/// Collects a set of patterns that bypass memref.reinterpret_cast Ops. This
+/// simplifies the IR in the context of lowering to EmitC.
+void populateElideReinterpretCastPatterns(RewritePatternSet &patterns);
+
/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 9049faccadef3..1c5e07f89b338 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
BufferViewFlowOpInterfaceImpl.cpp
ComposeSubView.cpp
+ ElideReinterpretCast.cpp
ExpandOps.cpp
ExpandRealloc.cpp
ExpandStridedMetadata.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
new file mode 100644
index 0000000000000..dc139d892f5e5
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -0,0 +1,225 @@
+//===-ElideReinterpretCast.cpp - Expansion patterns for MemRef operations-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <cassert>
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if `rc` represents a scalar view (all sizes == 1)
+/// into a memref that has exactly one non-unit dimension located at
+/// either the first or last position (i.e. a "row" or "column").
+///
+/// Examples that return true:
+///
+/// // Row-major slice (last dim is non-unit)
+/// memref.reinterpret_cast %buff to offset: [%off],
+/// sizes: [1, 1, 1], strides: [1, 1, 1]
+/// : memref<1x1x8xi32> to memref<1x1x1xi32>
+///
+/// // Column-major slice (first dim is non-unit)
+/// memref.reinterpret_cast %buff to offset: [%off],
+/// sizes: [1, 1], strides: [1, 1]
+/// : memref<2x1xf32> to memref<1x1xf32>
+///
+/// // Random strides
+/// memref.reinterpret_cast %buff to offset: [%off],
+/// sizes: [1, 1], strides: [10, 100]
+/// : memref<2x1xf32, strided<[10, 100]>>
+/// to memref<1x1xf32>
+///
+/// // Rank-1 case
+/// memref.reinterpret_cast %buf to offset: [%off],
+/// sizes: [1], strides: [1]
+/// : memref<8xi32> to memref<1xi32>
+///
+/// Examples that return false:
+///
+/// // More non-unit dims
+/// memref.reinterpret_cast %buff to offset: [%off],
+/// sizes: [1, 1, 1], strides: [1, 1, 1]
+/// : memref<1x2x8xi32> to memref<1x1x1xi32>
+///
+/// // View is not scalar (size != 1)
+/// memref.reinterpret_cast %buff to offset: [%off],
+/// sizes: [2, 1], strides: [1, 1]
+/// : memref<1x2xf32> to memref<2x1xf32>
+///
+/// // Base has non-identity layout
+/// %buff = memref.alloc() : memref<1x2xf32, strided<[1, 3]>>
+/// memref.reinterpret_cast %buff to offset: [%off],
+/// sizes: [1, 1], strides: [1, 1]
+/// : memref<1x2xf32, strided<[1, 3]>> to memref<1x1xf32>
+static bool isScalarSlice(memref::ReinterpretCastOp rc) {
+ auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
+ auto rcOutputTy = dyn_cast<MemRefType>(rc.getType());
+
+ // Reject strided base - logic for computing linear idx is TODO
+ if (!rcInputTy.getLayout().isIdentity())
+ return false;
+
+ // Reject non-matching ranks
+ unsigned srcRank = rcInputTy.getRank();
+ if (srcRank != rcOutputTy.getRank())
+ return false;
+
+ ArrayRef<int64_t> sizes = rc.getStaticSizes();
+
+ // View must be scalar: memref<1x...x1>
+ if (!llvm::all_of(rcOutputTy.getShape(),
+ [](int64_t dim) { return dim == 1; }))
+ return false;
+
+ // Sizes must all be statically 1
+ if (!llvm::all_of(sizes, [](int64_t size) {
+ return !ShapedType::isDynamic(size) && size == 1;
+ }))
+ return false;
+
+ // Rank-1 special case
+ if (srcRank == 1) {
+ // Reject non-scalar output
+ if (rcOutputTy.getDimSize(0) > 1)
+ return false;
+ }
+
+ int nonUnitCount =
+ std::count_if(rcInputTy.getShape().begin(), rcInputTy.getShape().end(),
+ [](int dim) { return dim != 1; });
+ return nonUnitCount == 1;
+}
+
+/// Rewrites `memref.copy` of a 1-element MemRef as a scalar load-store pair
+///
+/// The pattern matches a reinterpret_cast that creates a scalar view
+/// (`sizes = [1, ..., 1]`) into a memref with a single non-unit dimension.
+/// Since the view contains only one element, the accessed address is
+/// determined solely by the base pointer and the offset.
+///
+/// Two layouts are supported:
+/// * row-major slice (stride pattern [N, ..., 1])
+/// * column-major slice (stride pattern [1, ..., N])
+///
+/// BEFORE (row-major slice)
+/// %view = memref.reinterpret_cast %base
+/// to offset: [%off], sizes: [1, ..., 1], strides: [N, ..., 1]
+/// : memref<1x...xNxf32>
+/// to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
+/// memref.copy %src, %view
+/// : memref<1x...x1xf32>
+/// to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
+///
+/// AFTER
+/// %c0 = arith.constant 0 : index
+/// %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
+/// memref.store %v, %base[%c0, ..., %off] : memref<1x...xNxf32>
+///
+/// BEFORE (column-major slice)
+/// %view = memref.reinterpret_cast %base
+/// to offset: [%off], sizes: [1, ..., 1], strides: [1, ..., N]
+/// : memref<Nx...x1xf32>
+/// to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
+/// memref.copy %src, %view
+/// : memref<1x...x1xf32>
+/// to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
+///
+/// AFTER
+/// %c0 = arith.constant 0 : index
+/// %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
+/// memref.store %v, %base[%off, ..., %c0] : memref<Nx...x1xf32>
+struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::CopyOp op,
+ PatternRewriter &rewriter) const final {
+ Value rcOutput = op.getTarget();
+ auto rc = rcOutput.getDefiningOp<memref::ReinterpretCastOp>();
+ if (!rc)
+ return rewriter.notifyMatchFailure(
+ op, "target is not a memref.reinterpret_cast");
+
+ if (!isScalarSlice(rc))
+ return rewriter.notifyMatchFailure(
+ op, "reinterpret_cast does not match scalar slice");
+
+ Location loc = op.getLoc();
+
+ Value src = op.getSource();
+ Value dst = rc.getSource();
+
+ auto dstType = cast<MemRefType>(dst.getType());
+ unsigned dstRank = dstType.getRank();
+
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ auto srcType = cast<MemRefType>(src.getType());
+ SmallVector<Value> loadIndices(srcType.getRank(), zero);
+ auto offsets = rc.getMixedOffsets();
+ assert(offsets.size() == 1 && "Expecting single offset");
+ OpFoldResult offset = offsets[0];
+ Value storeOffset = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+ unsigned offsetDim = dstType.getDimSize(0) == 1 ? dstRank - 1 : 0;
+ SmallVector<Value> storeIndices(dstRank, zero);
+ storeIndices[offsetDim] = storeOffset;
+ // If the only user of `rc` is the current Op (which is about to be erased),
+ // we can safely erase it.
+ if (rcOutput.hasOneUse())
+ rewriter.eraseOp(rc);
+
+ Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
+ memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct ElideReinterpretCastPass
+ : public memref::impl::ElideReinterpretCastPassBase<
+ ElideReinterpretCastPass> {
+ void runOnOperation() override {
+ MLIRContext &ctx = getContext();
+
+ RewritePatternSet patterns(&ctx);
+ memref::populateElideReinterpretCastPatterns(patterns);
+ ConversionTarget target(ctx);
+ target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
+ auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
+ if (!rc)
+ return true;
+ return !isScalarSlice(rc);
+ });
+ target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::memref::populateElideReinterpretCastPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<CopyToScalarLoadAndStore>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
new file mode 100644
index 0000000000000..da47562e9c0d6
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @concat_zero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_zero_offset(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ /// reinterpret_cast removed
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32> to memref<1x1xf32>
+
+ /// Ensure copy was replaced
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @concat_nonzero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_nonzero_offset(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [1], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: 1>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: 1>>
+ return
+}
+
+// CHECK-LABEL: func.func private @concat_dynamic_offset(
+// CHECK-SAME: %[[OFF:.*]]: index
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [%offset], sizes: [1, 1], strides: [1, 1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: ?>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]]
+ // CHECK-SAME: : memref<1x1xf32>
+ /// Dynamic offset used in store
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[1, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func private @concat_strided(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
+func.func private @concat_strided(%src : memref<1x1xf32>,
+ %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [107, 2]
+ : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @concat_dynamic_stride(
+// CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32>
+// CHECK-SAME: %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32>
+func.func private @concat_dynamic_stride(%stride0: index,
+ %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+ : memref<1x108xf32>
+ to memref<1x1xf32, strided<[?, ?]>>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
+ /// Dynamic strides are ignored; the store uses the static zero offset
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32>
+ to memref<1x1xf32, strided<[?, ?]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @concat_rank1(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<108xf32>
+func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1], strides: [1]
+ : memref<108xf32> to memref<1xf32>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1xf32> to memref<1xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @concat_rank3(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xf32>
+// CHECK-SAME: %[[DST:.*]]: memref<1x1x108xf32>
+func.func private @concat_rank3(%src : memref<1x1x1xf32>,
+ %dst : memref<1x1x108xf32>) {
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
+ : memref<1x1x108xf32> to memref<1x1x1xf32>
+
+ // CHECK-NOT: memref.copy
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32>
+ // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32>
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1x1xf32> to memref<1x1x1xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @negative_concat_strided_base(
+func.func private @negative_concat_strided_base(%src: memref<1x1xf32>,
+ %dst: memref<8x1xf32, strided<[10, 2]>>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [6], sizes: [1, 1], strides: [11, 80]
+ : memref<8x1xf32, strided<[10, 2]>>
+ to memref<1x1xf32, strided<[11, 80], offset: 6>>
+
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32, strided<[11, 80], offset: 6>>
+
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_rank_change(
+func.func private @negative_reshape_rank_change(%src : memref<2x3xf32>,
+ %dst : memref<6xf32>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [2, 3], strides: [3, 1]
+ : memref<6xf32> to memref<2x3xf32>
+
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<2x3xf32> to memref<2x3xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_concat_multiple_non_unit_dims(
+func.func private @negative_concat_multiple_non_unit_dims(
+ %src : memref<1x1xf32>, %dst : memref<2x108xf32>) {
+ // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
+ %reinterpret_cast = memref.reinterpret_cast %dst
+ to offset: [0], sizes: [1, 1], strides: [1, 1]
+ : memref<2x108xf32>
+ to memref<1x1xf32>
+ // CHECK: memref.copy %arg0, %reinterpret_cast
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %reinterpret_cast
+ : memref<1x1xf32> to memref<1x1xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_plain_copy(
+func.func private @negative_plain_copy(%src : memref<1x1xf32>,
+ %dst : memref<1x1xf32>) {
+ // CHECK: memref.copy %arg0, %arg1
+ // CHECK-NOT: memref.load
+ // CHECK-NOT: memref.store
+ memref.copy %src, %dst
+ : memref<1x1xf32> to memref<1x1xf32>
+ return
+}