[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Apr 1 15:22:11 PDT 2025
================
@@ -4930,23 +5002,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
return success();
}
- // Insert tensor.cast ops if static shape inference is available..
+ // Insert either tensor.cast or memref.cast ops
+ // if static shape inference is available..
+ bool hasTensorSemantics = packOp.hasPureTensorSemantics();
+
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
Location loc = packOp.getLoc();
Value source = packOp.getSource();
if (srcShape != packOp.getSourceType().getShape()) {
auto newSrcType = packOp.getSourceType().clone(srcShape);
- source =
- rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ if (hasTensorSemantics)
+ source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+ packOp.getSource());
+ else
+ source = rewriter.create<memref::CastOp>(loc, newSrcType,
+ packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
- dest =
- rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ if (hasTensorSemantics)
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
----------------
hanhanW wrote:
There are three types in pack ops on tensors: (1) the source type, (2) the dest type, and (3) the result type.
During shape inference we need casts for (1) and (2), so here you also need to take memref into account. (A new test will capture the failure.) For (3), which is updated in `modifyOpInPlace{...}`, we update the result type if and only if the op is on tensors.
(3) only exists on tensors because the memref variant has only (1) and (2).
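The rule above can be summarized with a small standalone model. This is a sketch, not MLIR code. `ToyShapedType`, `CastPlan`, `castOpFor`, and `planPackCasts` are hypothetical names introduced only for illustration. The real pattern builds `tensor::CastOp`/`memref::CastOp` through the `PatternRewriter` rather than returning op names as strings. The model assumes source and dest are either both tensors or both memrefs.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a ranked shaped type (NOT mlir::ShapedType).
struct ToyShapedType {
  bool isTensor;              // true: tensor<...>, false: memref<...>
  std::vector<int64_t> shape; // -1 marks a dynamic dimension in this model
};

// Hypothetical summary of what the canonicalization would do.
struct CastPlan {
  std::string sourceCast;        // empty if no cast is needed for (1)
  std::string destCast;          // empty if no cast is needed for (2)
  bool updateResultType = false; // (3): only the tensor variant has a result
};

// Picks the cast op matching the operand kind (tensor vs. memref).
inline std::string castOpFor(const ToyShapedType &t) {
  return t.isTensor ? "tensor.cast" : "memref.cast";
}

// Models the cast-insertion step after static shape inference has produced
// `inferredSrc` and `inferredDest`.
inline CastPlan planPackCasts(const ToyShapedType &src,
                              const ToyShapedType &dest,
                              const std::vector<int64_t> &inferredSrc,
                              const std::vector<int64_t> &inferredDest) {
  // Assumption of this model: no mixed tensor/memref operands.
  assert(src.isTensor == dest.isTensor);
  bool hasTensorSemantics = src.isTensor;

  CastPlan plan;
  // (1) Source type: cast when inference refined the shape.
  if (inferredSrc != src.shape)
    plan.sourceCast = castOpFor(src);
  // (2) Dest type: same rule; memrefs need memref.cast here too.
  bool needUpdateDestType = inferredDest != dest.shape;
  if (needUpdateDestType)
    plan.destCast = castOpFor(dest);
  // (3) Result type: updated only when the op is on tensors.
  plan.updateResultType = hasTensorSemantics && needUpdateDestType;
  return plan;
}
```

For example, a memref pack whose dest shape gets refined yields a `memref.cast` on the dest and no result-type update, while the same refinement on tensors yields a `tensor.cast` and a result-type update.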
https://github.com/llvm/llvm-project/pull/129036