[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.pack static shapes inference. (PR #80848)
Quinn Dawkins
llvmlistbot at llvm.org
Sat Feb 10 16:27:21 PST 2024
================
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.finalizeOpModification(packOp);
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(packOp, srcShape, destShape)) {
+ Location loc = packOp.getLoc();
+ Value source = packOp.getSource();
+ if (srcShape != packOp.getSourceType().getShape()) {
+ auto newSrcType = packOp.getSourceType().clone(srcShape);
+ source =
+ rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ }
+ Value dest = packOp.getDest();
+ if (destShape != packOp.getDestType().getShape()) {
+ auto newDestType = packOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::PackOp>(
+ loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+ packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ packOp, packOp.getResult().getType(), newOp);
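The hunk calls `inferStaticShape`, whose body is not shown in this diff. The C++ below is a rough, self-contained model of that kind of inference, not the PR's code. It assumes sizes are propagated only between untiled source dimensions and their corresponding outer destination dimensions, honoring `outer_dims_perm`. The names `inferPackStaticShape` and `kDynamic` are hypothetical, and plain `std::vector` shapes stand in for MLIR types.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for MLIR's dynamic-size marker (ShapedType::kDynamic); the value
// used here is arbitrary and only meaningful inside this sketch.
constexpr int64_t kDynamic = -1;

// Hypothetical helper modeling static shape inference for a pack-like op.
// For an untiled source dimension, the source size and the matching outer
// destination size must be equal, so a static size on either side can be
// copied to the other. Tiled dimensions are skipped in this sketch: their
// outer size is a ceil-division by the tile size, which does not determine
// the source size uniquely.
// Returns true if any entry of srcShape or destShape was refined.
bool inferPackStaticShape(std::vector<int64_t> &srcShape,
                          std::vector<int64_t> &destShape,
                          const std::vector<int64_t> &innerDimsPos,
                          const std::vector<int64_t> &outerDimsPerm) {
  bool changed = false;
  const int64_t srcRank = static_cast<int64_t>(srcShape.size());
  for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
    if (std::find(innerDimsPos.begin(), innerDimsPos.end(), srcDim) !=
        innerDimsPos.end())
      continue; // tiled dimension: not handled here
    // outer_dims_perm[i] names the source dim placed at outer position i of
    // the destination, so look up the position of srcDim (inverse perm).
    int64_t destDim = srcDim;
    if (!outerDimsPerm.empty()) {
      auto it = std::find(outerDimsPerm.begin(), outerDimsPerm.end(), srcDim);
      assert(it != outerDimsPerm.end() && "outer_dims_perm must be a permutation");
      destDim = static_cast<int64_t>(it - outerDimsPerm.begin());
    }
    int64_t &srcSize = srcShape[srcDim];
    int64_t &destSize = destShape[destDim];
    const bool srcDynamic = srcSize == kDynamic;
    const bool destDynamic = destSize == kDynamic;
    if (srcDynamic == destDynamic)
      continue; // both static or both dynamic: nothing to infer
    if (srcDynamic)
      srcSize = destSize;
    else
      destSize = srcSize;
    changed = true;
  }
  return changed;
}
```

For example, with a source shape of `{kDynamic, 64}`, `innerDimsPos = {1}`, no `outer_dims_perm`, and a destination of `{10, 4, 16}`, the call refines the source to `{10, 64}` and returns true. The pattern in the hunk would then cast the source to the refined type, build a new pack, and cast the result back to the original type.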
----------------
qedawkins wrote:
Is this well suited for a canonicalization? I'm wondering about cases where a `pack` and `unpack` could have folded away but this pattern introduces a `tensor.cast` in the middle. Maybe we need the same pattern for unpack too?
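To make the concern concrete, here is a toy C++ model of an `unpack(pack(x)) -> x` fold that only fires when the unpack's operand is produced directly by a pack. `ToyOp` and `foldUnpackOfPack` are hypothetical and deliberately simplified; they are not MLIR's classes or its actual unpack folding logic. Once the new pattern rewrites the pack, the unpack's operand becomes the result-side `tensor.cast` from the hunk, and a fold of this shape no longer matches.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR node; not MLIR's Operation/Value classes.
struct ToyOp {
  std::string kind;                  // "arg", "pack", "unpack", or "cast"
  std::vector<int64_t> shape;        // result shape, -1 marks a dynamic size
  std::shared_ptr<ToyOp> operand;    // single input; null for "arg"
  std::vector<int64_t> innerDimsPos; // only meaningful for pack/unpack
};

using OpPtr = std::shared_ptr<ToyOp>;

// Simplified unpack(pack(x)) -> x fold. It fires only when the unpack's
// operand is produced directly by a pack with the same inner_dims_pos and
// the pack's source shape equals the unpack's result shape. Returns the
// value to replace the unpack with, or nullptr if the fold does not apply.
OpPtr foldUnpackOfPack(const OpPtr &unpack) {
  assert(unpack && "expected a non-null op");
  if (unpack->kind != "unpack" || !unpack->operand)
    return nullptr;
  const OpPtr &producer = unpack->operand;
  if (producer->kind != "pack") // e.g. a cast in between blocks the match
    return nullptr;
  const OpPtr &source = producer->operand;
  if (!source || producer->innerDimsPos != unpack->innerDimsPos ||
      source->shape != unpack->shape)
    return nullptr;
  return source;
}
```

In this model, `x -> pack -> unpack` folds back to `x`, while `x -> cast -> pack -> cast -> unpack` does not. Applying a similar cast-inserting pattern to unpack, or letting the fold look through `tensor.cast`, are two options one might explore to recover the fold.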
https://github.com/llvm/llvm-project/pull/80848