[Mlir-commits] [mlir] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (PR #189533)
Hocky Yudhiono
llvmlistbot at llvm.org
Tue Mar 31 00:24:21 PDT 2026
================
@@ -2273,41 +2273,70 @@ struct ReinterpretCastOpConstantFolder
LogicalResult matchAndRewrite(ReinterpretCastOp op,
PatternRewriter &rewriter) const override {
- unsigned srcStaticCount = llvm::count_if(
- llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
- op.getMixedStrides()),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+ MemRefType resultType = op.getType();
+ SmallVector<OpFoldResult> srcSizes = op.getMixedSizes();
- SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ OpFoldResult offset = op.getConstifiedMixedOffset();
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
- // TODO: Using counting comparison instead of direct comparison because
- // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
- // IntegerAttrs, while constifyIndexValues (and therefore
- // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
- if (srcStaticCount ==
- llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
- [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
- return failure();
+ int64_t layoutOffset = ShapedType::kDynamic;
- // Do not fold if the offset is a negative constant; ViewLikeInterface
- // verifies that static offsets are non-negative.
- if (auto cst = getConstantIntValue(offsets[0]))
+ if (auto cst = getConstantIntValue(offset)) {
+ // If the offset is a negative constant, we can't fold it because the
+ // resulting memref type would be invalid. In that case, we keep the
+ // original offset.
if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant offset is invalid");
+ offset = op.getMixedOffsets()[0];
+ else
+ layoutOffset = *cst;
+ }
- // Do not fold if any size is a negative constant; MemRefType::get asserts
- // non-negative static sizes.
- for (OpFoldResult sizeOfr : sizes)
- if (auto cst = getConstantIntValue(sizeOfr))
- if (*cst < 0)
- return rewriter.notifyMatchFailure(
- op, "negative constant size is invalid");
+ int64_t rank = resultType.getRank();
+ int64_t lastStride = 1;
+ bool isContiguousMemrefType = (layoutOffset == 0);
+ SmallVector<int64_t> layoutStrides(rank), shapes(rank);
+
+ for (int64_t dim = rank - 1; dim >= 0; --dim) {
----------------
hockyy wrote:
I personally would separate this logic into another function.
https://github.com/llvm/llvm-project/pull/189533
More information about the Mlir-commits
mailing list