[Mlir-commits] [mlir] [memref] Simplify loads from reinterpret_cast of 1D contiguous memrefs (PR #188459)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Apr 23 05:20:43 PDT 2026
================
@@ -196,6 +198,234 @@ struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
}
};
/// Summary of where the (at most one) non-unit dimension of a MemRef shape
/// is located.
///
/// Only meaningful for shapes that have zero or one non-unit dimension:
///   * `exists == false`: every dimension is 1 (all-ones shape).
///   * `exists == true`, `isOnLeft == true`: the non-unit dimension is the
///     first dimension.
///   * `exists == true`, `isOnLeft == false`: the non-unit dimension is the
///     last dimension.
///
/// A rank-1 non-unit MemRef sits on both boundaries at once. It is reported
/// with `isOnLeft == true`, so callers that care about the right boundary must
/// additionally consult the MemRef rank.
struct SingleNonUnitDimInfo {
  /// True when the shape has a non-unit dimension at all.
  bool exists{false};
  /// True when that dimension is at index 0; only meaningful if `exists`.
  bool isOnLeft{false};
};
+
+/// Returns information about a MemRef if it contains at most one non-unit
+/// dimension.
+///
+/// The single non-unit dimension, if present, must be on the left or right
+/// boundary. Rank-1 non-unit MemRefs are treated as being on both boundaries.
+static std::optional<SingleNonUnitDimInfo>
+getSingleNonUnitDimInfo(MemRefType type) {
+ ArrayRef<int64_t> shape = type.getShape();
+ int64_t nonUnitCount =
+ llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
+ // Return default values if missing nonUnitDim
+ if (nonUnitCount == 0)
+ return SingleNonUnitDimInfo{};
+ // Return no info if MemRef breaks nonUnitDim requirements (more nonUnitDims)
+ if (nonUnitCount > 1)
+ return std::nullopt;
+
+ bool isOnLeft = shape.front() != 1;
+ // Return no info if MemRef breaks nonUnitDim requirements (nonUnitDim in
+ // non-boundary pos)
+ if (!isOnLeft && shape.back() == 1)
+ return std::nullopt;
+
+ return SingleNonUnitDimInfo{/*exists=*/true, isOnLeft};
+}
+
+static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
+ ArrayRef<int64_t> offsets = rc.getStaticOffsets();
+ // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
+ // only a single offset. That should be fixed at the op definition level.
+ assert(offsets.size() == 1 && "Expecting single offset");
+ return !ShapedType::isDynamic(offsets[0]) && offsets[0] == 0;
+}
+
+static std::optional<int64_t> getConstantIndex(Value v) {
+ if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
+ return cst.value();
+ return std::nullopt;
+}
+
+static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
+ int64_t upperBound) {
+ std::optional<int64_t> idxVal = getConstantIndex(idx);
+ return idxVal && (*idxVal < 0 || *idxVal >= upperBound);
+}
+
+/// Examples accepted by this shape restriction:
+/// memref<999xf32> <-> memref<1x1x999xf32>
+/// memref<1x108xf32> <-> memref<1x1x1x108xf32>
+/// memref<100x1xf32> <-> memref<100x1x1xf32>
+/// memref<1> <-> memref<1x1x1>
+///
+/// General reinterpret_casts are intentionally rejected.
+static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
+ auto inputTy = cast<MemRefType>(rc.getSource().getType());
+ auto outputTy = cast<MemRefType>(rc.getResult().getType());
+
+ // This rewrite assumes "index re-use" and misses "index
+ // re-write/adjustment" logic, hence the requirement for the offset to be 0.
+ // Thus, storage shift and statically unknown offsets are rejected.
+ if (!hasStaticZeroOffset(rc))
+ return false;
+
+ // The check assumes the rewrite relies on completely static shape info.
+ if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
+ llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
+ return false;
+
+ // The check assumes the rewrite supports shapes with at most one non-unit
+ // dimension. This excludes underlying multi-dimensional layouts and keeps the
+ // rewrite limited to unit-dim insertion/removal `reinterpret_cast`s.
+ std::optional<SingleNonUnitDimInfo> inputNonUnitDim =
+ getSingleNonUnitDimInfo(inputTy);
+ std::optional<SingleNonUnitDimInfo> outputNonUnitDim =
+ getSingleNonUnitDimInfo(outputTy);
+ // Bail out early if nonUnitDims don't follow rewrite assumptions.
+ if (!inputNonUnitDim || !outputNonUnitDim)
+ return false;
+
+ // The source and result must either both have a single non-unit dimension
+ // or both be all-ones.
+ if (inputNonUnitDim->exists != outputNonUnitDim->exists)
+ return false;
+ if (!inputNonUnitDim->exists)
+ return true;
+
+ // The preserved non-unit dimension must have the same size.
+ if (inputTy.getDimSize(inputNonUnitDim->isOnLeft ? 0
+ : inputTy.getRank() - 1) !=
+ outputTy.getDimSize(outputNonUnitDim->isOnLeft ? 0
+ : outputTy.getRank() - 1))
+ return false;
+
+ // If both sides have rank > 1, the non-unit dimension must be on the same
+ // boundary. Rank-1 MemRefs are accepted against either boundary.
+ if (inputTy.getRank() != 1 && outputTy.getRank() != 1 &&
+ inputNonUnitDim->isOnLeft != outputNonUnitDim->isOnLeft)
+ return false;
----------------
banach-space wrote:
How about:
```mlir
%reinterpret_cast = memref.reinterpret_cast %src
: memref<1x1x1x100xf32> to memref<100x1x1xf32, strided<[100, 1, 100]>>
```
? This is interesting because, in isolation, both the input and the output types meet the criteria for your transformation. However, the op itself does not.
https://github.com/llvm/llvm-project/pull/188459
More information about the Mlir-commits
mailing list