[Mlir-commits] [mlir] [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (PR #188980)
Hocky Yudhiono
llvmlistbot at llvm.org
Thu Apr 2 04:51:24 PDT 2026
================
@@ -230,19 +210,98 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
.Default([&](auto op) { return success(); });
}
+// Pattern for memref::AllocOp and memref::AllocaOp.
+//
+// The "source" memref for these ops IS the op's own result, so the generic
+// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
+// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
+// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
+// new ReinterpretCastOp, leaving the earlier ops with forward references
+// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
+//
+// Instead, sizes and strides are computed from the op's operands and type
+// (which all dominate the op), avoiding any reference to op.result until the
+// final replaceOpWithNewOp inside castAllocResult.
+template <typename T>
+struct AllocLikeFlattenPattern : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
+ return failure();
+
+ Location loc = op->getLoc();
+ auto memrefType = cast<MemRefType>(op.getType());
+ auto elemType = memrefType.getElementType();
+ if (!elemType.isIntOrFloat())
+ return failure();
+ unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
+
+ SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+ int64_t staticOffset;
+ SmallVector<int64_t> staticStrides;
+ if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
+ return failure();
+ // Reject non-zero or dynamic base offsets (alloc results start at 0).
+ if (staticOffset != 0)
+ return failure();
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(staticStrides.size());
+ for (int64_t stride : staticStrides) {
+ if (stride == ShapedType::kDynamic)
+ return failure();
+ strides.push_back(rewriter.getIndexAttr(stride));
+ }
+
+ // Compute the linearized flat size from sizes and strides (no SSA ops
+ // referencing op.result are created here).
+ memref::LinearizedMemRefInfo linearizedInfo;
+ OpFoldResult linearizedOffset;
+ std::tie(linearizedInfo, linearizedOffset) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
+ sizes, strides);
+ (void)linearizedOffset;
+
+ // Build the flat 1-D MemRefType. The linearized size may be static or
+ // dynamic (OpFoldResult of either IntegerAttr or a Value).
+ int64_t flatDimSize = ShapedType::kDynamic;
+ if (auto attr = dyn_cast<Attribute>(linearizedInfo.linearizedSize))
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ flatDimSize = intAttr.getInt();
+
+ auto flatMemrefType =
+ MemRefType::get({flatDimSize}, memrefType.getElementType(),
+ StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
+ memrefType.getMemorySpace());
+
+ // Collect the flat dynamic-size operand (empty for fully-static case).
+ SmallVector<Value, 1> dynSizes;
+ if (flatDimSize == ShapedType::kDynamic)
+ dynSizes.push_back(getValueFromOpFoldResult(
+ rewriter, loc, linearizedInfo.linearizedSize));
+
+ auto newOp = T::create(rewriter, loc, flatMemrefType, dynSizes,
----------------
hockyy wrote:
```suggestion
auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
```
https://github.com/llvm/llvm-project/pull/188980
More information about the Mlir-commits mailing list