[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.pack static shapes inference. (PR #80848)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Feb 13 00:28:32 PST 2024
================
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.finalizeOpModification(packOp);
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(packOp, srcShape, destShape)) {
+ Location loc = packOp.getLoc();
+ Value source = packOp.getSource();
+ if (srcShape != packOp.getSourceType().getShape()) {
+ auto newSrcType = packOp.getSourceType().clone(srcShape);
+ source =
+ rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ }
+ Value dest = packOp.getDest();
+ if (destShape != packOp.getDestType().getShape()) {
+ auto newDestType = packOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::PackOp>(
+ loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+ packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ packOp, packOp.getResult().getType(), newOp);
----------------
hanhanW wrote:
This is common in [Linalg](https://github.com/llvm/llvm-project/blob/d57515bd107bc76df5a042ffee2b7dc6125ffef1/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp#L2259-L2262), and I think we can add the same functionality to tensor ops. Good point on pack/unpack folding. I think they can still be folded if the patterns are applied in the right order. To keep the resulting IR deterministic, we might need the same pattern for unpack. So yes, I will prepare a patch and send it out for review. Do you think it is better to land both patterns together? If so, I will add the update to this PR. If it does not matter, I will land it as a follow-up.
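For readers without the full patch at hand, the hunk above calls `inferStaticShape`, which is not shown here. Below is a rough, standalone sketch of the kind of inference such a helper could perform: copying a static size between source and dest for untiled dimensions when only one side is static. It is not the code from the PR. The names `kDyn` and `inferStaticShapeSketch` are hypothetical, plain `std::vector` stands in for MLIR types, and the sketch assumes there is no `outer_dims_perm`.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

// Hypothetical sentinel for a dynamic dimension; it only plays the role
// that a "dynamic size" marker plays in a shaped type.
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

// Propagates static sizes between source and dest for dimensions that are
// not tiled. Assumes no outer_dims_perm, so outer dim i of dest corresponds
// to dim i of source. Returns true if any dimension was refined.
bool inferStaticShapeSketch(const std::vector<int64_t> &innerDimsPos,
                            std::vector<int64_t> &srcShape,
                            std::vector<int64_t> &destShape) {
  std::set<int64_t> tiled(innerDimsPos.begin(), innerDimsPos.end());
  bool changed = false;
  for (std::size_t i = 0; i < srcShape.size(); ++i) {
    // Tiled dims change size under packing, so they are skipped here.
    if (tiled.count(static_cast<int64_t>(i)))
      continue;
    bool srcDyn = srcShape[i] == kDyn;
    bool destDyn = destShape[i] == kDyn;
    // Nothing to learn if both sides are static or both are dynamic.
    if (srcDyn == destDyn)
      continue;
    int64_t size = srcDyn ? destShape[i] : srcShape[i];
    srcShape[i] = size;
    destShape[i] = size;
    changed = true;
  }
  return changed;
}
```

With a source of shape `?x?`, a dest of shape `128x?x16`, and `inner_dims_pos = [1]`, the sketch refines the source to `128x?`. That refinement is the case in which the canonicalization above would insert a `tensor.cast` on the source.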
https://github.com/llvm/llvm-project/pull/80848