[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.pack static shapes inference. (PR #80848)

Quinn Dawkins llvmlistbot at llvm.org
Sat Feb 10 16:27:21 PST 2024


================
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.finalizeOpModification(packOp);
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(packOp, srcShape, destShape)) {
+    Location loc = packOp.getLoc();
+    Value source = packOp.getSource();
+    if (srcShape != packOp.getSourceType().getShape()) {
+      auto newSrcType = packOp.getSourceType().clone(srcShape);
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+    }
+    Value dest = packOp.getDest();
+    if (destShape != packOp.getDestType().getShape()) {
+      auto newDestType = packOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::PackOp>(
+        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        packOp, packOp.getResult().getType(), newOp);
----------------
qedawkins wrote:

Is this well suited for a canonicalization? I'm wondering about cases where a `pack` and `unpack` could have folded away, but this pattern introduces a `tensor.cast` between them and the fold no longer matches. Maybe we need the same pattern for `unpack` too?
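
A hypothetical sketch of the concern, for illustration only: the value names (`%src`, `%dest`, `%out`), the shapes, and the tile size are assumptions, not taken from the PR. It also assumes that `inferStaticShape` propagates the static size of the untiled dimension 0 from the source type to the destination type, and that the `unpack(pack(x))` fold requires the `unpack` source to be defined directly by a `pack`.

```mlir
// Before: the unpack consumes the pack result directly, so a fold of
// unpack(pack(%src)) to %src can match.
%pack = tensor.pack %src inner_dims_pos = [1] inner_tiles = [8]
    into %dest : tensor<16x?xf32> -> tensor<?x?x8xf32>
%unpack = tensor.unpack %pack inner_dims_pos = [1] inner_tiles = [8]
    into %out : tensor<?x?x8xf32> -> tensor<16x?xf32>

// After the new pattern runs on the pack first: the destination is cast to
// the inferred static shape, and the result is cast back to the original type.
%dest_cast = tensor.cast %dest : tensor<?x?x8xf32> to tensor<16x?x8xf32>
%pack = tensor.pack %src inner_dims_pos = [1] inner_tiles = [8]
    into %dest_cast : tensor<16x?xf32> -> tensor<16x?x8xf32>
%cast = tensor.cast %pack : tensor<16x?x8xf32> to tensor<?x?x8xf32>
// The unpack source is now defined by tensor.cast, not by tensor.pack.
%unpack = tensor.unpack %cast inner_dims_pos = [1] inner_tiles = [8]
    into %out : tensor<?x?x8xf32> -> tensor<16x?xf32>
```

Whether the fold is actually lost depends on the order in which the patterns fire. A matching pattern on `unpack` could presumably absorb the intermediate `tensor.cast` and let the pair fold again.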

https://github.com/llvm/llvm-project/pull/80848