[Mlir-commits] [mlir] [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (PR #188980)
Hocky Yudhiono
llvmlistbot at llvm.org
Thu Apr 2 03:11:28 PDT 2026
================
@@ -230,19 +210,98 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
.Default([&](auto op) { return success(); });
}
+// Pattern for memref::AllocOp and memref::AllocaOp.
+//
+// The "source" memref for these ops IS the op's own result, so the generic
+// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
+// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
+// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
+// new ReinterpretCastOp, leaving the earlier ops with forward references
+// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
+//
+// Instead, sizes and strides are computed from the op's operands and type
+// (which all dominate the op), avoiding any reference to op.result until the
+// final replaceOpWithNewOp inside castAllocResult.
+template <typename T>
+struct AllocLikeFlattenPattern : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
+      return failure();
+
+    Location loc = op->getLoc();
+    auto memrefType = cast<MemRefType>(op.getType());
+    auto elemType = memrefType.getElementType();
+    if (!elemType.isIntOrFloat())
+      return failure();
+    unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
+
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    int64_t staticOffset;
+    SmallVector<int64_t> staticStrides;
+    if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
+      return failure();
+    // Reject non-zero or dynamic base offsets (alloc results start at 0).
+    if (staticOffset != 0)
----------------
hockyy wrote:
I think in the old pass these memrefs were supposed to be flattened too (see #136797):
> After the linearization, a MemRef's offset is kept, so a memref<4x8xf32, strided<[8, 1], offset: 100>> becomes memref<32xf32, strided<[1], offset: 100>>.
Wouldn't rejecting non-zero offsets here change that behavior?
https://github.com/llvm/llvm-project/pull/188980