[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Apr 23 01:42:22 PDT 2026
================
@@ -196,6 +198,234 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
}
};
+/// Describes where the single non-unit dimension of a MemRef shape sits.
+///
+/// Only meaningful for shapes that have at most one non-unit dimension. A
+/// default-constructed value (`exists == false`) stands for an all-ones shape;
+/// in that case `isOnLeft` carries no information and stays false.
+///
+/// When `exists` is true, `isOnLeft` tells whether the non-unit dimension is
+/// the leftmost one; if it is not, the dimension is the rightmost one. A
+/// rank-1 non-unit shape touches both boundaries, so callers that care about
+/// the right boundary must account for that from the MemRef type.
+struct SingleNonUnitDimInfo {
+ bool exists = false; // True iff the shape has a non-unit dimension.
+ bool isOnLeft = false; // True iff that dimension is the leftmost one.
+};
+
+/// Classifies the shape of `type` with respect to its non-unit dimensions.
+///
+/// Returns:
+///   * a default `SingleNonUnitDimInfo` (`exists == false`) when every
+///     dimension equals 1 (this includes an empty shape);
+///   * a populated `SingleNonUnitDimInfo` when exactly one dimension differs
+///     from 1 and it is the first or the last dimension. For a rank-1
+///     non-unit shape both boundaries coincide and `isOnLeft` is true;
+///   * `std::nullopt` when more than one dimension differs from 1, or when
+///     the only such dimension is not on a boundary.
+static std::optional<SingleNonUnitDimInfo>
+getSingleNonUnitDimInfo(MemRefType type) {
+  ArrayRef<int64_t> dims = type.getShape();
+  auto isNonUnit = [](int64_t extent) { return extent != 1; };
+  const int64_t numNonUnit = llvm::count_if(dims, isNonUnit);
+
+  // Several non-unit dimensions: outside the supported family of shapes.
+  if (numNonUnit > 1)
+    return std::nullopt;
+
+  // No non-unit dimension at all: report the defaults.
+  SingleNonUnitDimInfo info;
+  if (numNonUnit == 0)
+    return info;
+
+  // Exactly one non-unit dimension from here on, so the shape is non-empty
+  // and inspecting its two ends is safe. The dimension must be one of them.
+  const bool leadingIsNonUnit = isNonUnit(dims.front());
+  const bool trailingIsNonUnit = isNonUnit(dims.back());
+  if (!leadingIsNonUnit && !trailingIsNonUnit)
+    return std::nullopt;
+
+  info.exists = true;
+  info.isOnLeft = leadingIsNonUnit;
+  return info;
+}
+
+/// Returns true iff the offset of `rc` is statically known and equal to 0.
+/// A dynamic offset yields false.
+static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
+  ArrayRef<int64_t> staticOffsets = rc.getStaticOffsets();
+  // FIXME: The plural in `getStaticOffsets` is misleading: `reinterpret_cast`
+  // carries exactly one offset. This should be fixed in the op definition.
+  assert(staticOffsets.size() == 1 && "Expecting single offset");
+
+  const int64_t offset = staticOffsets.front();
+  if (ShapedType::isDynamic(offset))
+    return false;
+  return offset == 0;
+}
+
+/// Returns the value of `v` when it is defined by an
+/// `arith::ConstantIndexOp`; returns `std::nullopt` for any other value.
+static std::optional<int64_t> getConstantIndex(Value v) {
+  auto constantOp = v.getDefiningOp<arith::ConstantIndexOp>();
+  if (!constantOp)
+    return std::nullopt;
+  return constantOp.value();
+}
+
+/// Returns true only when `idx` is a constant index (as recognised by
+/// `getConstantIndex`) whose value lies outside `[0, upperBound)`. Values
+/// that are not recognised as constants are never reported as out of bounds.
+static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
+                                                 int64_t upperBound) {
+  const std::optional<int64_t> constantIdx = getConstantIndex(idx);
+  if (!constantIdx)
+    return false;
+
+  const bool belowRange = *constantIdx < 0;
+  const bool atOrAboveBound = *constantIdx >= upperBound;
+  return belowRange || atOrAboveBound;
+}
+
+/// Examples accepted by this shape restriction:
+/// memref<999xf32> <-> memref<1x1x999xf32>
+/// memref<1x108xf32> <-> memref<1x1x1x108xf32>
+/// memref<100x1xf32> <-> memref<100x1x1xf32>
+/// memref<1xf32> <-> memref<1x1x1xf32>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
+ auto inputTy = cast<MemRefType>(rc.getSource().getType());
+ auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+ // This rewrite assumes "index re-use" and misses "index
+ // re-write/adjustment" logic, hence the requirement for the offset to be 0.
+ // Thus, storage shift and statically unknown offsets are rejected.
+ if (!hasStaticZeroOffset(rc))
+ return false;
+
+ // The check assumes the rewrite relies on completely static shape info.
----------------
banach-space wrote:
[nit] We should aim to decouple checks and rewrites - this comment implies otherwise.
Basically, a "check" should provide some guarantees and that's it. Whether that's going to be used by a particular transformation is a separate thing. Also, importantly, we cannot guarantee that there won't be other transformations later using this helper method.
In this case, AFAIK, the presence of dynamic shapes would mean that we cannot really reason about the underlying reinterpret_cast and that's why we should reject these cases.
https://github.com/llvm/llvm-project/pull/188459
More information about the Mlir-commits
mailing list