[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)
ioana ghiban
llvmlistbot at llvm.org
Thu Apr 2 05:08:02 PDT 2026
https://github.com/ioghiban updated https://github.com/llvm/llvm-project/pull/188459
>From f285b1dc9d61d679c4d49c52ce8a9b7ac3d1fef5 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Tue, 24 Mar 2026 17:34:52 +0100
Subject: [PATCH 1/2] [memref] Simplify loads from reinterpret_cast of 1D
contiguous memrefs
Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and tests before submission.
---
.../Transforms/ElideReinterpretCast.cpp | 241 +++++++++++++-
.../MemRef/elide-reinterpret-cast.mlir | 309 +++++++++++++++++-
2 files changed, 548 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index dc139d892f5e5..49d764fc5aee1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
@@ -195,6 +196,237 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
}
};
+static bool isConstZero(Value v) { return matchPattern(v, m_Zero()); }
+
+static bool isPureRankReshape(memref::ReinterpretCastOp rc, memref::LoadOp op) {
+ auto inputTy = cast<MemRefType>(rc.getSource().getType());
+ auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+ // This fold only handles reinterpret_casts that behave like pure rank
+ // reshapes of a single logical dimension:
+ //
+ // - all metadata is static
+ // - offset is 0
+ // - source/result each have at most one non-unit dim
+ // - if a non-unit dim exists, it is at the left or right boundary
+ //
+ // Examples accepted by this shape restriction:
+ // memref<999xf32> <-> memref<1x1x999xf32>
+ // memref<1x108xf32> <-> memref<1x1x1x108xf32>
+ // memref<100x1xf32> <-> memref<100x1x1xf32>
+ //
+ // General reinterpret_casts are intentionally rejected.
+
+ auto offsets = rc.getStaticOffsets();
+ assert(offsets.size() == 1 && "Expecting single offset");
+
+ // The rewrite drops the reinterpret_cast and remaps indices directly to the
+ // source memref. That is only correct if there is no storage shift.
+ if (ShapedType::isDynamic(offsets[0]) || offsets[0] != 0)
+ return false;
+
+ auto sizes = rc.getStaticSizes();
+ auto strides = rc.getStaticStrides();
+
+ // Require fully static metadata. The fold relies on knowing exactly which
+ // dimensions are unit dimensions and which indices may be ignored.
+ if (llvm::any_of(sizes, ShapedType::isDynamic))
+ return false;
+ if (llvm::any_of(strides, ShapedType::isDynamic))
+ return false;
+
+ // Count non-unit dims and remember their positions.
+ //
+ // The rewrite supports shapes with at most one non-unit dimension.
+ // This excludes underlying multi-dimensional layouts and keeps the
+ // fold limited to unit-dim insertion/removal reshapes.
+ unsigned inputRank = inputTy.getRank();
+ int inputNonUnitCount = 0;
+ int64_t inputNonUnitSize = 1;
+ unsigned inputNonUnitPos = 0;
+ for (unsigned i = 0; i < inputRank; ++i) {
+ if (inputTy.getDimSize(i) != 1) {
+ ++inputNonUnitCount;
+ inputNonUnitPos = i;
+ inputNonUnitSize = inputTy.getDimSize(i);
+ }
+ }
+
+ unsigned outputRank = outputTy.getRank();
+ int outputNonUnitCount = 0;
+ int64_t outputNonUnitSize = 1;
+ unsigned outputNonUnitPos = 0;
+ for (unsigned i = 0; i < outputRank; ++i) {
+ if (outputTy.getDimSize(i) != 1) {
+ ++outputNonUnitCount;
+ outputNonUnitPos = i;
+ outputNonUnitSize = outputTy.getDimSize(i);
+ }
+ }
+
+ // Reject reshapes with > 1 non-unit-dimension.
+ //
+ // The source and result must have the same number of non-unit dimensions:
+ // either both are all-ones, or both have exactly one non-unit dimension.
+ if (inputNonUnitCount > 1 || outputNonUnitCount > 1 ||
+ inputNonUnitCount != outputNonUnitCount)
+ return false;
+
+ // If there is a non-unit dimension, it must sit at a boundary (first or
+ // last dimension) of both the input and the output memref. Each side is
+ // checked on its own; the check does not compare the two boundaries.
+ // The index-preserving rewrite logic is exclusive to these cases.
+ if (inputNonUnitCount == 1) {
+ auto isBoundary = [](unsigned pos, unsigned rank) {
+ return pos == 0 || pos == rank - 1;
+ };
+ if (!isBoundary(inputNonUnitPos, inputRank) ||
+ !isBoundary(outputNonUnitPos, outputRank))
+ return false;
+ }
+
+ // Size of non-unit dimension must be the same
+ if (inputNonUnitCount == 1 && outputNonUnitCount == 1 &&
+ inputNonUnitSize != outputNonUnitSize)
+ return false;
+
+ SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
+ SmallVector<unsigned> nonZeroIdxPositions;
+ nonZeroIdxPositions.reserve(idxs.size());
+
+ // Record non-zero indices.
+ //
+ // During rank expansion, the rewrite drops the extra unit-dimension indices.
+ // That is only semantics-preserving if every dropped index is zero.
+ for (auto [pos, idx] : llvm::enumerate(idxs)) {
+ if (!isConstZero(idx))
+ nonZeroIdxPositions.push_back(pos);
+ }
+
+ // Position of the unique non-unit dim in the output, if present:
+ // - 0 for shapes like [N, 1, 1]
+ // - outputRank-1 for shapes like [1, 1, N]
+ //
+ // For the all-ones case, treat it like the "non-unit on the right" case.
+ unsigned nonUnitDimPos =
+ (outputNonUnitCount == 1 && outputTy.getDimSize(0) != 1) ? 0
+ : outputRank - 1;
+
+ if (outputRank >= inputRank) {
+ // Rank expansion case.
+ //
+ // The rewrite keeps only inputRank indices. Any non-zero index in an
+ // expanded unit dimension that would be discarded makes the fold invalid.
+ if (nonUnitDimPos == 0) {
+ // Expansion on the right: keep the leftmost inputRank indices.
+ // Therefore any non-zero index in the suffix would be lost.
+ for (unsigned pos : nonZeroIdxPositions) {
+ if (pos >= inputRank)
+ return false;
+ }
+ } else {
+ // Expansion on the left: keep the rightmost inputRank indices.
+ // Therefore any non-zero index in the prefix would be lost.
+ unsigned firstValidPos = outputRank - inputRank;
+ for (unsigned pos : nonZeroIdxPositions) {
+ if (pos < firstValidPos)
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::LoadOp op,
+ PatternRewriter &rewriter) const override {
+ auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+ if (!rc)
+ return failure();
+
+ // This fold is only correct for the narrow "pure rank reshape of a single
+ // logical dimension" cases accepted by isPureRankReshape().
+ if (!isPureRankReshape(rc, op))
+ return failure();
+
+ auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
+ auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
+
+ int64_t rcOutputRank = rcOutputTy.getRank();
+ int64_t rcInputRank = rcInputTy.getRank();
+
+ SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
+ SmallVector<Value> rcInputIdxs;
+
+ // The fold only supports reshapes with at most one non-unit dimension,
+ // located at the left or right boundary.
+ //
+ // The higher-rank side tells which side the reshape has expanded/collapsed.
+ //
+ // expansion: rcOutput has the higher rank
+ // collapse : rcInput has the higher rank
+ //
+ // Example:
+ // memref<999> -> memref<1x1x999> : extra dims to the left
+ // memref<999x1x1> -> memref<999> : extra dims to the right
+ MemRefType expandedTy =
+ rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
+ bool nonUnitOnLeft = expandedTy.getDimSize(0) != 1;
+
+ if (rcOutputRank >= rcInputRank) {
+ // Rank expansion:
+ // memref<N> -> memref<1x1xN> : keep the last rcInputRank indices
+ // memref<N> -> memref<Nx1x1> : keep the first rcInputRank indices
+ //
+ // Any discarded indices are known to be zero from isPureRankReshape().
+ if (nonUnitOnLeft) {
+ for (int64_t dim = 0; dim < rcInputRank; ++dim)
+ rcInputIdxs.push_back(idxs[dim]);
+ } else {
+ for (int64_t dim = 0; dim < rcInputRank; ++dim)
+ rcInputIdxs.push_back(idxs[rcOutputRank - rcInputRank + dim]);
+ }
+ } else {
+ // Rank collapse:
+ // memref<1x1xN> -> memref<N> : reinsert leading zeros
+ // memref<Nx1x1> -> memref<N> : reinsert trailing zeros
+ //
+ // The collapsed-away dimensions are unit dims, so re-adding them with
+ // zero indices preserves semantics.
+ Value c0 = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ int64_t rankDiff = rcInputRank - rcOutputRank;
+
+ if (nonUnitOnLeft) {
+ rcInputIdxs.append(idxs.begin(), idxs.end());
+ rcInputIdxs.append(rankDiff, c0);
+ } else {
+ rcInputIdxs.append(rankDiff, c0);
+ rcInputIdxs.append(idxs.begin(), idxs.end());
+ }
+ }
+
+ // Sanity check: rewritten load must index the source memref with exactly
+ // as many indices as the rank.
+ if ((int64_t)rcInputIdxs.size() != rcInputRank)
+ return failure();
+
+ auto rcInput = rc.getSource();
+ // If the only user of rc is the current Op (which is about to be erased),
+ // we can safely erase it.
+ if (rc.getResult().hasOneUse())
+ rewriter.eraseOp(rc);
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, rcInput, rcInputIdxs);
+
+ // The reinterpret_cast is erased above only when this load was its sole
+ // use; otherwise it stays in place for its remaining users.
+ return success();
+ }
+};
+
struct ElideReinterpretCastPass
: public memref::impl::ElideReinterpretCastPassBase<
ElideReinterpretCastPass> {
@@ -210,6 +442,12 @@ struct ElideReinterpretCastPass
return true;
return !isScalarSlice(rc);
});
+ target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp op) {
+ auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+ if (!rc)
+ return true;
+ return !isPureRankReshape(rc, op);
+ });
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -221,5 +459,6 @@ struct ElideReinterpretCastPass
void mlir::memref::populateElideReinterpretCastPatterns(
RewritePatternSet &patterns) {
- patterns.add<CopyToScalarLoadAndStore>(patterns.getContext());
+ patterns.add<CopyToScalarLoadAndStore, FoldReinterpretCastLoad>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index da47562e9c0d6..5733c97ea8f3b 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -memref-elide-reinterpret-cast %s \
+// RUN: | FileCheck %s
//===----------------------------------------------------------------------===//
// Positive tests
@@ -220,3 +221,309 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
: memref<1x1xf32> to memref<1x1xf32>
return
}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Positive tests
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+ to memref<1x1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]]] : memref<1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+ to memref<1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1]
+ : memref<999xi64> to memref<1x1x999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x999xi64>) {
+func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999], strides: [1]
+ : memref<1x1x999xi64> to memref<999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C1]]] : memref<1x1x999xi64>
+ %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @reshape_expand_left_inner_unit_dims(
+ %src : memref<1x108xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+ : memref<1x108xf32> to memref<1x1x1x108xf32>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<1x108xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+ : memref<1x1x1x108xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @reshape_collapse_left_inner_unit_dims(
+ %src : memref<1x1x1x100xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 100], strides: [100, 1]
+ : memref<1x1x1x100xf32> to memref<1x100xf32>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1x100xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x100xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
+func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999, 1, 1], strides: [1, 999, 999]
+ : memref<999xi64> to memref<999x1x1xi64, strided<[1, 999, 999]>>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<999xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<999x1x1xi64,
+ strided<[1, 999, 999]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-SAME: %[[SRC:.*]]: memref<999x1x1xi64>) {
+func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [999], strides: [1]
+ : memref<999x1x1xi64> to memref<999xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0]]] : memref<999x1x1xi64>
+ %0 = memref.load %reinterpret_cast[%c1] : memref<999xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @reshape_expand_right_inner_unit_dims(
+ %src : memref<108x1xf32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+ : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<108x1xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+ : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<100x1x1x1xf32>) {
+func.func private @reshape_collapse_right_inner_unit_dims(
+ %src : memref<100x1x1x1xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [100, 1], strides: [1, 100]
+ : memref<100x1x1x1xf32> to memref<100x1xf32, strided<[1, 100]>>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]], %[[C0_0]], %[[C0_0]]] : memref<100x1x1x1xf32>
+ %0 = memref.load %reinterpret_cast[%c1, %c0] : memref<100x1xf32,
+ strided<[1, 100]>>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// Negative tests (must NOT rewrite)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_reshape_nonzero_offset(
+ %src : memref<1xi64>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+ to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1]
+ : memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_dynamic_shape(
+// CHECK-SAME: %[[DIM:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<?xi64>
+func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
+ %src : memref<?xi64>) {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, %[[DIM]]], strides: [1, 1] : memref<?xi64> to memref<1x?xi64>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, %dim], strides: [1, 1]
+ : memref<?xi64> to memref<1x?xi64>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %i] : memref<1x?xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_dynamic_stride(
+// CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
+// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xi64>
+func.func private @negative_reshape_dynamic_stride(%stride0: index,
+ %stride1: index, %src : memref<1x108xi64>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1], strides: [%[[STR0]], %[[STR1]]] : memref<1x108xi64> to memref<1x1xi64, strided<[?, ?]>>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
+ : memref<1x108xi64>
+ to memref<1x1xi64, strided<[?, ?]>>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c1]
+ : memref<1x1xi64, strided<[?, ?]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_multiple_non_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<2x1x1x100xf32>) {
+func.func private @negative_reshape_multiple_non_unit_dims(
+ %src : memref<2x1x1x100xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [2, 100], strides: [100, 1] : memref<2x1x1x100xf32> to memref<2x100xf32>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [2, 100], strides: [100, 1]
+ : memref<2x1x1x100xf32> to memref<2x100xf32>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<2x100xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_diff_non_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_reshape_diff_non_unit_dims(
+ %src : memref<1x1x1x100xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 101], strides: [101, 1] : memref<1x1x1x100xf32> to memref<1x101xf32>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 101], strides: [101, 1]
+ : memref<1x1x1x100xf32> to memref<1x101xf32>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x101xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_inner_non_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_reshape_inner_non_unit_dims(
+ %src : memref<1x1x1x100xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100] : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
+ : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
+ strided<[100, 1, 100]>>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_expand_left_discarded_indices(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x108xf32>) {
+func.func private @negative_reshape_expand_left_discarded_indices(
+ %src : memref<1x108xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1] : memref<1x108xf32> to memref<1x1x1x108xf32>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
+ : memref<1x108xf32> to memref<1x1x1x108xf32>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+ : memref<1x1x1x108xf32>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_reshape_expand_right_discarded_indices(
+// CHECK-SAME: %[[SRC:.*]]: memref<108x1xf32>) {
+func.func private @negative_reshape_expand_right_discarded_indices(
+ %src : memref<108x1xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108] : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
+ : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+ : memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ return
+}
>From e34a107942271924e1f3930452c16e65fffa7332 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Tue, 31 Mar 2026 12:09:03 +0200
Subject: [PATCH 2/2] Address first round of comments
---
.../Transforms/ElideReinterpretCast.cpp | 323 ++++++++++--------
.../MemRef/elide-reinterpret-cast.mlir | 181 ++++++----
2 files changed, 284 insertions(+), 220 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
index 49d764fc5aee1..b132ff0b77597 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ElideReinterpretCast.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
+#include <optional>
namespace mlir {
namespace memref {
@@ -198,147 +199,164 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
static bool isConstZero(Value v) { return matchPattern(v, m_Zero()); }
-static bool isPureRankReshape(memref::ReinterpretCastOp rc, memref::LoadOp op) {
+static std::optional<int64_t> getConstantIndex(Value v) {
+ if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
+ return cst.value();
+ return std::nullopt;
+}
+
+struct SingleNonUnitDimInfo {
+ bool Exists = false;
+ bool isOnLeft = true;
+ bool isOnRight = true;
+};
+
+/// Returns information about a MemRef with at most one non-unit dimension.
+///
+/// The single non-unit dimension, if present, must be on the left or right
+/// boundary. Rank-1 non-unit memrefs are treated as being on both boundaries.
+static std::optional<SingleNonUnitDimInfo>
+getSingleNonUnitDimInfo(MemRefType type) {
+ ArrayRef<int64_t> shape = type.getShape();
+ int64_t nonUnitCount =
+ llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+ if (nonUnitCount == 0)
+ return SingleNonUnitDimInfo{};
+ if (nonUnitCount > 1)
+ return std::nullopt;
+
+ bool isOnLeft = shape.front() != 1;
+ bool isOnRight = shape.back() != 1;
+ if (!isOnLeft && !isOnRight)
+ return std::nullopt;
+
+ return SingleNonUnitDimInfo{/*Exists=*/true, isOnLeft, isOnRight};
+}
+
+static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
+ ArrayRef<int64_t> offsets = rc.getStaticOffsets();
+ // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
+ // only a single offset. That should be fixed at the op definition level.
+ return offsets.size() == 1 && !ShapedType::isDynamic(offsets[0]) &&
+ offsets[0] == 0;
+}
+
+static bool isConstantIndexInBounds(Value idx, int64_t upperBound) {
+ if (isConstZero(idx))
+ return true;
+
+ std::optional<int64_t> idxVal = getConstantIndex(idx);
+ return idxVal && *idxVal >= 0 && *idxVal < upperBound;
+}
+
+/// Checks for pure rank expansion/collapsing of a single logical dimension:
+/// - all metadata is static
+/// - offset is 0
+/// - source/result each have at most one non-unit dim
+/// - if a non-unit dim exists, it is at the left or right boundary
+///
+/// Examples accepted by this shape restriction:
+/// memref<999xf32> <-> memref<1x1x999xf32>
+/// memref<1x108xf32> <-> memref<1x1x1x108xf32>
+/// memref<100x1xf32> <-> memref<100x1x1xf32>
+/// memref<1> <-> memref<1x1x1>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
auto inputTy = cast<MemRefType>(rc.getSource().getType());
auto outputTy = cast<MemRefType>(rc.getResult().getType());
- // This fold only handles reinterpret_casts that behave like pure rank
- // reshapes of a single logical dimension:
- //
- // - all metadata is static
- // - offset is 0
- // - source/result each have at most one non-unit dim
- // - if a non-unit dim exists, it is at the left or right boundary
- //
- // Examples accepted by this shape restriction:
- // memref<999xf32> <-> memref<1x1x999xf32>
- // memref<1x108xf32> <-> memref<1x1x1x108xf32>
- // memref<100x1xf32> <-> memref<100x1x1xf32>
- //
- // General reinterpret_casts are intentionally rejected.
-
- auto offsets = rc.getStaticOffsets();
- assert(offsets.size() == 1 && "Expecting single offset");
-
- // The rewrite drops the reinterpret_cast and remaps indices directly to the
- // source memref. That is only correct if there is no storage shift.
- if (ShapedType::isDynamic(offsets[0]) || offsets[0] != 0)
+ // The rewrite re-uses the load indices as-is and performs no index
+ // adjustment. Thus, storage shift (non-zero offset) and statically unknown
+ // offsets are rejected.
+ if (!hasStaticZeroOffset(rc))
return false;
- auto sizes = rc.getStaticSizes();
- auto strides = rc.getStaticStrides();
-
- // Require fully static metadata. The fold relies on knowing exactly which
- // dimensions are unit dimensions and which indices may be ignored.
- if (llvm::any_of(sizes, ShapedType::isDynamic))
- return false;
- if (llvm::any_of(strides, ShapedType::isDynamic))
+ // The rewrite requires fully static sizes and strides.
+ if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
+ llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
return false;
- // Count non-unit dims and remember their positions.
- //
- // The rewrite supports shapes with at most one non-unit dimension.
- // This excludes underlying multi-dimensional layouts and keeps the
- // fold limited to unit-dim insertion/removal reshapes.
- unsigned inputRank = inputTy.getRank();
- int inputNonUnitCount = 0;
- int64_t inputNonUnitSize = 1;
- unsigned inputNonUnitPos = 0;
- for (unsigned i = 0; i < inputRank; ++i) {
- if (inputTy.getDimSize(i) != 1) {
- ++inputNonUnitCount;
- inputNonUnitPos = i;
- inputNonUnitSize = inputTy.getDimSize(i);
- }
- }
-
- unsigned outputRank = outputTy.getRank();
- int outputNonUnitCount = 0;
- int64_t outputNonUnitSize = 1;
- unsigned outputNonUnitPos = 0;
- for (unsigned i = 0; i < outputRank; ++i) {
- if (outputTy.getDimSize(i) != 1) {
- ++outputNonUnitCount;
- outputNonUnitPos = i;
- outputNonUnitSize = outputTy.getDimSize(i);
- }
- }
+ // The check assumes the rewrite supports shapes with at most one non-unit
+ // dimension. This excludes underlying multi-dimensional layouts and keeps the
+ // rewrite limited to unit-dim insertion/removal `reinterpret_cast`s.
+ std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
+ getSingleNonUnitDimInfo(inputTy);
+ std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
+ getSingleNonUnitDimInfo(outputTy);
+ if (!inputNonUnitDim || !outputNonUnitDim)
+ return false;
- // Reject reshapes with > 1 non-unit-dimension.
- //
- // The source and result must have the same number of non-unit dimensions:
- // either both are all-ones, or both have exactly one non-unit dimension.
- if (inputNonUnitCount > 1 || outputNonUnitCount > 1 ||
- inputNonUnitCount != outputNonUnitCount)
+ // The source and result must either both be all-ones, or both have a single
+ // non-unit dimension.
+ if (inputNonUnitDim->Exists != outputNonUnitDim->Exists)
return false;
+ if (!inputNonUnitDim->Exists)
+ return true;
- // If there is a non-unit dimension, it must live at the same boundary
- // (first or last dimension) on both input and output memrefs.
- // The rewrite logic for preserving the load index is exclusive to these
- // cases.
- if (inputNonUnitCount == 1) {
- auto isBoundary = [](unsigned pos, unsigned rank) {
- return pos == 0 || pos == rank - 1;
- };
- if (!isBoundary(inputNonUnitPos, inputRank) ||
- !isBoundary(outputNonUnitPos, outputRank))
- return false;
- }
+ // If there is a non-unit dimension, it must be preserved on the same
+ // boundary. Rank-1 memrefs are accepted against either boundary.
+ return (inputNonUnitDim->isOnLeft && outputNonUnitDim->isOnLeft) ||
+ (inputNonUnitDim->isOnRight && outputNonUnitDim->isOnRight);
+}
- // Size of non-unit dimension must be the same
- if (inputNonUnitCount == 1 && outputNonUnitCount == 1 &&
- inputNonUnitSize != outputNonUnitSize)
+/// Checks that each index of `load` (which reads a reinterpret_cast result) is
+/// a constant that is in bounds, and that rank expansion drops only zeros.
+/// Returns false for out-of-bounds indices or non-constant indices.
+static bool areIndicesForUnitDimsInBounds(memref::LoadOp load) {
+ auto rc = load.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
+ auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
+ auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
+ int64_t rcInputRank = rcInputTy.getRank();
+ int64_t rcOutputRank = rcOutputTy.getRank();
+
+ std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
+ getSingleNonUnitDimInfo(rcInputTy);
+ std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
+ getSingleNonUnitDimInfo(rcOutputTy);
+ if (!inputNonUnitDim || !outputNonUnitDim)
return false;
- SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
SmallVector<unsigned> nonZeroIdxPositions;
- nonZeroIdxPositions.reserve(idxs.size());
+ nonZeroIdxPositions.reserve(load.getIndices().size());
- // Record non-zero indices.
- //
- // During rank expansion, the rewrite drops the extra unit-dimension indices.
- // That is only semantics-preserving if every dropped index is zero.
- for (auto [pos, idx] : llvm::enumerate(idxs)) {
- if (!isConstZero(idx))
- nonZeroIdxPositions.push_back(pos);
- }
+ // Validate every index that is not a constant zero: it must be a constant
+ // that is in bounds for the matching dim of the reinterpret_cast result
+ // (the memref being loaded from).
+ for (auto [pos, idx] : llvm::enumerate(load.getIndices())) {
+ if (isConstZero(idx))
+ continue;
- // Position of the unique non-unit dim in the output, if present:
- // - 0 for shapes like [N, 1, 1]
- // - outputRank-1 for shapes like [1, 1, N]
- //
- // For the all-ones case, treat it like the "non-unit on the right" case.
- unsigned nonUnitDimPos =
- (outputNonUnitCount == 1 && outputTy.getDimSize(0) != 1) ? 0
- : outputRank - 1;
-
- if (outputRank >= inputRank) {
- // Rank expansion case.
- //
- // The rewrite keeps only inputRank indices. Any non-zero index in an
- // expanded unit dimension that would be discarded makes the fold invalid.
- if (nonUnitDimPos == 0) {
- // Expansion on the right: keep the leftmost inputRank indices.
- // Therefore any non-zero index in the suffix would be lost.
- for (unsigned pos : nonZeroIdxPositions) {
- if (pos >= inputRank)
- return false;
- }
- } else {
- // Expansion on the left: keep the rightmost inputRank indices.
- // Therefore any non-zero index in the prefix would be lost.
- unsigned firstValidPos = outputRank - inputRank;
- for (unsigned pos : nonZeroIdxPositions) {
- if (pos < firstValidPos)
- return false;
- }
- }
+ // All-ones case: only constant-zero indices are accepted, so bail out.
+ if (!inputNonUnitDim->Exists)
+ return false;
+
+ // FIXME: This should be ensured by the memref.load semantics.
+ if (!isConstantIndexInBounds(idx, rcOutputTy.getDimSize(pos)))
+ return false;
+
+ nonZeroIdxPositions.push_back(pos);
}
- return true;
+ // During rank expansion, the rewrite drops the extra unit-dimension indices.
+ // That is only semantics-preserving if every dropped index is zero.
+ // This check is only relevant for expansions with a non-unit dimension.
+ if (rcOutputRank < rcInputRank || !inputNonUnitDim->Exists)
+ return true;
+
+ // The rewrite keeps either a prefix or suffix of length `rcInputRank`.
+ // Any non-zero index outside the preserved slice would be discarded.
+ if (outputNonUnitDim->isOnLeft)
+ return llvm::none_of(nonZeroIdxPositions,
+ [&](unsigned pos) { return pos >= rcInputRank; });
+
+ unsigned firstKeptPos = rcOutputRank - rcInputRank;
+ return llvm::none_of(nonZeroIdxPositions,
+ [&](unsigned pos) { return pos < firstKeptPos; });
}
-struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
+struct RewriteLoadFromReinterpretCast
+ : public OpRewritePattern<memref::LoadOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -348,9 +366,11 @@ struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
if (!rc)
return failure();
- // This fold is only correct for the narrow "pure rank reshape of a single
- // logical dimension" cases accepted by isPureRankReshape().
- if (!isPureRankReshape(rc, op))
+ // This rewrite is only correct for the narrow "pure rank expansion or
+ // collapsing of a single logical dimension" cases accepted by these two
+ // checks.
+ if (!isPureRankExpansionOrCollapsingRC(rc) ||
+ !areIndicesForUnitDimsInBounds(op))
return failure();
auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
@@ -361,46 +381,51 @@ struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
SmallVector<Value> rcInputIdxs;
+ rcInputIdxs.reserve(rcInputRank);
- // The fold only supports reshapes with at most one non-unit dimension,
- // located at the left or right boundary.
+ // The rewrite only supports reinterpret_casts with at most one non-unit
+ // dimension, located at the left or right boundary.
//
- // The higher-rank side tells which side the reshape has expanded/collapsed.
+ // The higher-rank side tells which side the reinterpret_cast has
+ // expanded/collapsed.
//
// expansion: rcOutput has the higher rank
- // collapse : rcInput has the higher rank
+ // collapsing : rcInput has the higher rank
//
// Example:
// memref<999> -> memref<1x1x999> : extra dims to the left
// memref<999x1x1> -> memref<999> : extra dims to the right
MemRefType expandedTy =
rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
- bool nonUnitOnLeft = expandedTy.getDimSize(0) != 1;
+ std::optional<SingleNonUnitDimInfo> expandedNonUnitDim =
+ getSingleNonUnitDimInfo(expandedTy);
+ assert(expandedNonUnitDim && "expected a single boundary non-unit dim");
+ bool keepLeadingIndices = expandedNonUnitDim->isOnLeft;
if (rcOutputRank >= rcInputRank) {
// Rank expansion:
- // memref<N> -> memref<1x1xN> : keep the last rcInputRank indices
- // memref<N> -> memref<Nx1x1> : keep the first rcInputRank indices
+ // memref<N> -> memref<1x1xN> : keep the last rcInputRank indices
+ // memref<N> -> memref<Nx1x1> : keep the first rcInputRank indices
+ // memref<1> -> memref<1x1x1> : all indices are zero
//
- // Any discarded indices are known to be zero from isPureRankReshape().
- if (nonUnitOnLeft) {
- for (int64_t dim = 0; dim < rcInputRank; ++dim)
- rcInputIdxs.push_back(idxs[dim]);
- } else {
- for (int64_t dim = 0; dim < rcInputRank; ++dim)
- rcInputIdxs.push_back(idxs[rcOutputRank - rcInputRank + dim]);
- }
+ // Any discarded indices are known to be zero from
+ // areIndicesForUnitDimsInBounds().
+ int64_t firstKeptPos =
+ keepLeadingIndices ? 0 : rcOutputRank - rcInputRank;
+ rcInputIdxs.append(idxs.begin() + firstKeptPos,
+ idxs.begin() + firstKeptPos + rcInputRank);
} else {
- // Rank collapse:
- // memref<1x1xN> -> memref<N> : reinsert leading zeros
- // memref<Nx1x1> -> memref<N> : reinsert trailing zeros
+ // Rank collapsing:
+ // memref<1x1xN> -> memref<N> : reinsert leading zeros
+ // memref<Nx1x1> -> memref<N> : reinsert trailing zeros
+ // memref<1x1x1> -> memref<1> : all indices are zero
//
- // The collapsed-away dimensions are unit dims, so readding them with
+ // The collapsed-away dimensions are unit dims, so re-adding them with
// zero indices preserves semantics.
Value c0 = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
int64_t rankDiff = rcInputRank - rcOutputRank;
- if (nonUnitOnLeft) {
+ if (keepLeadingIndices) {
rcInputIdxs.append(idxs.begin(), idxs.end());
rcInputIdxs.append(rankDiff, c0);
} else {
@@ -409,10 +434,7 @@ struct FoldReinterpretCastLoad : public OpRewritePattern<memref::LoadOp> {
}
}
- // Sanity check: rewritten load must index the source memref with exactly
- // as many indices as the rank.
- if ((int64_t)rcInputIdxs.size() != rcInputRank)
- return failure();
+ assert(rcInputIdxs.size() == rcInputRank && "Incorrect number of indices!");
auto rcInput = rc.getSource();
// If the only user of rc is the current Op (which is about to be erased),
@@ -446,7 +468,8 @@ struct ElideReinterpretCastPass
auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
if (!rc)
return true;
- return !isPureRankReshape(rc, op);
+ return !(isPureRankExpansionOrCollapsingRC(rc) &&
+ areIndicesForUnitDimsInBounds(op));
});
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
if (failed(applyPartialConversion(getOperation(), target,
@@ -459,6 +482,6 @@ struct ElideReinterpretCastPass
void mlir::memref::populateElideReinterpretCastPatterns(
RewritePatternSet &patterns) {
- patterns.add<CopyToScalarLoadAndStore, FoldReinterpretCastLoad>(
+ patterns.add<CopyToScalarLoadAndStore, RewriteLoadFromReinterpretCast>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
index 5733c97ea8f3b..944751dbb005f 100644
--- a/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
+++ b/mlir/test/Dialect/MemRef/elide-reinterpret-cast.mlir
@@ -179,8 +179,8 @@ func.func private @negative_concat_strided_base(%src: memref<1x1xf32>,
return
}
-// CHECK-LABEL: func.func private @negative_reshape_rank_change(
-func.func private @negative_reshape_rank_change(%src : memref<2x3xf32>,
+// CHECK-LABEL: func.func private @negative_rank_change(
+func.func private @negative_rank_change(%src : memref<2x3xf32>,
%dst : memref<6xf32>) {
// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg1
%reinterpret_cast = memref.reinterpret_cast %dst
@@ -229,42 +229,38 @@ func.func private @negative_plain_copy(%src : memref<1x1xf32>,
// Positive tests
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: func.func private @reshape_expand_scalar(
+// CHECK-LABEL: func.func private @expand_scalar(
// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
-func.func private @reshape_expand_scalar(%src : memref<1xi64>) {
+func.func private @expand_scalar(%src : memref<1xi64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C1:.*]] = arith.constant 1 : index
%c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
// CHECK-NOT: memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
to memref<1x1x1xi64>
- // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]]] : memref<1xi64>
- %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xi64>
return
}
-// CHECK-LABEL: func.func private @reshape_collapse_scalar(
+// CHECK-LABEL: func.func private @collapse_scalar(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xi64>) {
-func.func private @reshape_collapse_scalar(%src : memref<1x1x1xi64>) {
+func.func private @collapse_scalar(%src : memref<1x1x1xi64>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
%c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
// CHECK-NOT: memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
to memref<1x1xi64>
- // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0]], %[[C1]]] : memref<1x1x1xi64>
- %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x1xi64>
+ %0 = memref.load %reinterpret_cast[%c0, %c0] : memref<1x1xi64>
return
}
-// CHECK-LABEL: func.func private @reshape_expand_left_vector(
+// CHECK-LABEL: func.func private @expand_left_vector(
// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
-func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
+func.func private @expand_left_vector(%src : memref<999xi64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-NOT: memref.reinterpret_cast
@@ -276,9 +272,9 @@ func.func private @reshape_expand_left_vector(%src : memref<999xi64>) {
return
}
-// CHECK-LABEL: func.func private @reshape_collapse_left_vector(
+// CHECK-LABEL: func.func private @collapse_left_vector(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1x999xi64>) {
-func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
+func.func private @collapse_left_vector(%src : memref<1x1x999xi64>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
@@ -291,9 +287,9 @@ func.func private @reshape_collapse_left_vector(%src : memref<1x1x999xi64>) {
return
}
-// CHECK-LABEL: func.func private @reshape_expand_left_inner_unit_dims(
+// CHECK-LABEL: func.func private @expand_left_inner_unit_dims(
// CHECK-SAME: %[[SRC:.*]]: memref<1x108xf32>) {
-func.func private @reshape_expand_left_inner_unit_dims(
+func.func private @expand_left_inner_unit_dims(
%src : memref<1x108xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -303,15 +299,15 @@ func.func private @reshape_expand_left_inner_unit_dims(
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
: memref<1x108xf32> to memref<1x1x1x108xf32>
- // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<1x108xf32>
- %0 = memref.load %reinterpret_cast[%c0, %c0, %c1, %c0]
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<1x108xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c0, %c1]
: memref<1x1x1x108xf32>
return
}
-// CHECK-LABEL: func.func private @reshape_collapse_left_inner_unit_dims(
+// CHECK-LABEL: func.func private @collapse_left_inner_unit_dims(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @reshape_collapse_left_inner_unit_dims(
+func.func private @collapse_left_inner_unit_dims(
%src : memref<1x1x1x100xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
@@ -327,9 +323,9 @@ func.func private @reshape_collapse_left_inner_unit_dims(
return
}
-// CHECK-LABEL: func.func private @reshape_expand_right_vector(
+// CHECK-LABEL: func.func private @expand_right_vector(
// CHECK-SAME: %[[SRC:.*]]: memref<999xi64>) {
-func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
+func.func private @expand_right_vector(%src : memref<999xi64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-NOT: memref.reinterpret_cast
@@ -342,9 +338,9 @@ func.func private @reshape_expand_right_vector(%src : memref<999xi64>) {
return
}
-// CHECK-LABEL: func.func private @reshape_collapse_right_vector(
+// CHECK-LABEL: func.func private @collapse_right_vector(
// CHECK-SAME: %[[SRC:.*]]: memref<999x1x1xi64>) {
-func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
+func.func private @collapse_right_vector(%src : memref<999x1x1xi64>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
@@ -357,9 +353,9 @@ func.func private @reshape_collapse_right_vector(%src : memref<999x1x1xi64>) {
return
}
-// CHECK-LABEL: func.func private @reshape_expand_right_inner_unit_dims(
+// CHECK-LABEL: func.func private @expand_right_inner_unit_dims(
// CHECK-SAME: %[[SRC:.*]]: memref<108x1xf32>) {
-func.func private @reshape_expand_right_inner_unit_dims(
+func.func private @expand_right_inner_unit_dims(
%src : memref<108x1xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -369,15 +365,15 @@ func.func private @reshape_expand_right_inner_unit_dims(
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
: memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
- // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0]], %[[C1]]] : memref<108x1xf32>
- %0 = memref.load %reinterpret_cast[%c0, %c1, %c0, %c0]
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C1]], %[[C0]]] : memref<108x1xf32>
+ %0 = memref.load %reinterpret_cast[%c1, %c0, %c0, %c0]
: memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
return
}
-// CHECK-LABEL: func.func private @reshape_collapse_right_inner_unit_dims(
+// CHECK-LABEL: func.func private @collapse_right_inner_unit_dims(
// CHECK-SAME: %[[SRC:.*]]: memref<100x1x1x1xf32>) {
-func.func private @reshape_collapse_right_inner_unit_dims(
+func.func private @collapse_right_inner_unit_dims(
%src : memref<100x1x1x1xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
@@ -394,17 +390,35 @@ func.func private @reshape_collapse_right_inner_unit_dims(
return
}
+// CHECK-LABEL: func.func private @negative_diff_non_unit_dims(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
+func.func private @negative_diff_non_unit_dims(
+ %src : memref<1x1x1x100xf32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C98:.*]] = arith.constant 98 : index
+ %c0 = arith.constant 0 : index
+ %c98 = arith.constant 98 : index
+ // CHECK-NOT: memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 99], strides: [99, 1]
+ : memref<1x1x1x100xf32> to memref<1x99xf32>
+ // CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[C0_0]], %[[C0_0]], %[[C0]], %[[C98]]] : memref<1x1x1x100xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c98] : memref<1x99xf32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// Negative tests (must NOT rewrite)
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: func.func private @negative_reshape_nonzero_offset(
+// CHECK-LABEL: func.func private @negative_nonzero_offset(
// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
-func.func private @negative_reshape_nonzero_offset(
+func.func private @negative_nonzero_offset(
%src : memref<1xi64>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64> to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [1], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
to memref<1x1x1xi64, strided<[1, 1, 1], offset: 1>>
@@ -414,13 +428,13 @@ func.func private @negative_reshape_nonzero_offset(
return
}
-// CHECK-LABEL: func.func private @negative_reshape_dynamic_shape(
+// CHECK-LABEL: func.func private @negative_dynamic_shape(
// CHECK-SAME: %[[DIM:[A-Za-z][A-Za-z0-9-]*]]: index
// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<?xi64>
-func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
+func.func private @negative_dynamic_shape(%dim : index, %i : index,
%src : memref<?xi64>) {
%c0 = arith.constant 0 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, %[[DIM]]], strides: [1, 1] : memref<?xi64> to memref<1x?xi64>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [1, %dim], strides: [1, 1]
: memref<?xi64> to memref<1x?xi64>
@@ -429,15 +443,15 @@ func.func private @negative_reshape_dynamic_shape(%dim : index, %i : index,
return
}
-// CHECK-LABEL: func.func private @negative_reshape_dynamic_stride(
+// CHECK-LABEL: func.func private @negative_dynamic_stride(
// CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
// CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xi64>
-func.func private @negative_reshape_dynamic_stride(%stride0: index,
+func.func private @negative_dynamic_stride(%stride0: index,
%stride1: index, %src : memref<1x108xi64>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1], strides: [%[[STR0]], %[[STR1]]] : memref<1x108xi64> to memref<1x1xi64, strided<[?, ?]>>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
: memref<1x108xi64>
@@ -448,13 +462,13 @@ func.func private @negative_reshape_dynamic_stride(%stride0: index,
return
}
-// CHECK-LABEL: func.func private @negative_reshape_multiple_non_unit_dims(
+// CHECK-LABEL: func.func private @negative_multiple_non_unit_dims(
// CHECK-SAME: %[[SRC:.*]]: memref<2x1x1x100xf32>) {
-func.func private @negative_reshape_multiple_non_unit_dims(
- %src : memref<2x1x1x100xf32>) {
+func.func private @negative_multiple_non_unit_dims(
+ %src : memref<2x1x1x100xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [2, 100], strides: [100, 1] : memref<2x1x1x100xf32> to memref<2x100xf32>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [2, 100], strides: [100, 1]
: memref<2x1x1x100xf32> to memref<2x100xf32>
@@ -463,46 +477,73 @@ func.func private @negative_reshape_multiple_non_unit_dims(
return
}
-// CHECK-LABEL: func.func private @negative_reshape_diff_non_unit_dims(
+// CHECK-LABEL: func.func private @negative_inner_non_unit_dims(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @negative_reshape_diff_non_unit_dims(
+func.func private @negative_inner_non_unit_dims(
%src : memref<1x1x1x100xf32>) {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 101], strides: [101, 1] : memref<1x1x1x100xf32> to memref<1x101xf32>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
- to offset: [0], sizes: [1, 101], strides: [101, 1]
- : memref<1x1x1x100xf32> to memref<1x101xf32>
+ to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
+ : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
// CHECK: memref.load %[[RC]]
- %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x101xf32>
+ %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
+ strided<[100, 1, 100]>>
return
}
-// CHECK-LABEL: func.func private @negative_reshape_inner_non_unit_dims(
-// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1x100xf32>) {
-func.func private @negative_reshape_inner_non_unit_dims(
- %src : memref<1x1x1x100xf32>) {
+// CHECK-LABEL: func.func private @negative_expand_out_of_bounds(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_expand_out_of_bounds(%src : memref<1xi64>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100] : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
- to offset: [0], sizes: [1, 100, 1], strides: [100, 1, 100]
- : memref<1x1x1x100xf32> to memref<1x100x1xf32, strided<[100, 1, 100]>>
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+ to memref<1x1x1xi64>
// CHECK: memref.load %[[RC]]
- %0 = memref.load %reinterpret_cast[%c0, %c1, %c0] : memref<1x100x1xf32,
- strided<[100, 1, 100]>>
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %c1] : memref<1x1x1xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_collapse_out_of_bounds(
+// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xi64>) {
+func.func private @negative_collapse_out_of_bounds(%src : memref<1x1x1xi64>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1], strides: [1, 1] : memref<1x1x1xi64>
+ to memref<1x1xi64>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c1] : memref<1x1xi64>
+ return
+}
+
+// CHECK-LABEL: func.func private @negative_expand_negative_index(
+// CHECK-SAME: %[[SRC:.*]]: memref<1xi64>) {
+func.func private @negative_expand_negative_index(%src : memref<1xi64>) {
+ %c0 = arith.constant 0 : index
+ %cneg1 = arith.constant -1 : index
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
+ %reinterpret_cast = memref.reinterpret_cast %src
+ to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] : memref<1xi64>
+ to memref<1x1x1xi64>
+ // CHECK: memref.load %[[RC]]
+ %0 = memref.load %reinterpret_cast[%c0, %c0, %cneg1] : memref<1x1x1xi64>
return
}
-// CHECK-LABEL: func.func private @negative_reshape_expand_left_discarded_indices(
+// CHECK-LABEL: func.func private @negative_expand_left_discarded_indices(
// CHECK-SAME: %[[SRC:.*]]: memref<1x108xf32>) {
-func.func private @negative_reshape_expand_left_discarded_indices(
+func.func private @negative_expand_left_discarded_indices(
%src : memref<1x108xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1] : memref<1x108xf32> to memref<1x1x1x108xf32>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [1, 1, 1, 108], strides: [108, 108, 108, 1]
: memref<1x108xf32> to memref<1x1x1x108xf32>
@@ -512,13 +553,13 @@ func.func private @negative_reshape_expand_left_discarded_indices(
return
}
-// CHECK-LABEL: func.func private @negative_reshape_expand_right_discarded_indices(
+// CHECK-LABEL: func.func private @negative_expand_right_discarded_indices(
// CHECK-SAME: %[[SRC:.*]]: memref<108x1xf32>) {
-func.func private @negative_reshape_expand_right_discarded_indices(
+func.func private @negative_expand_right_discarded_indices(
%src : memref<108x1xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[RC:.*]] = memref.reinterpret_cast %[[SRC]] to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108] : memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
+ // CHECK: %[[RC:.*]] = memref.reinterpret_cast
%reinterpret_cast = memref.reinterpret_cast %src
to offset: [0], sizes: [108, 1, 1, 1], strides: [1, 108, 108, 108]
: memref<108x1xf32> to memref<108x1x1x1xf32, strided<[1, 108, 108, 108]>>
More information about the Mlir-commits
mailing list