[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.pack static shapes inference. (PR #80848)

Han-Chung Wang llvmlistbot at llvm.org
Tue Feb 13 00:28:32 PST 2024


================
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.finalizeOpModification(packOp);
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(packOp, srcShape, destShape)) {
+    Location loc = packOp.getLoc();
+    Value source = packOp.getSource();
+    if (srcShape != packOp.getSourceType().getShape()) {
+      auto newSrcType = packOp.getSourceType().clone(srcShape);
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+    }
+    Value dest = packOp.getDest();
+    if (destShape != packOp.getDestType().getShape()) {
+      auto newDestType = packOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::PackOp>(
+        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        packOp, packOp.getResult().getType(), newOp);
----------------
hanhanW wrote:

This is common in [Linalg](https://github.com/llvm/llvm-project/blob/d57515bd107bc76df5a042ffee2b7dc6125ffef1/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp#L2259-L2262), and I think we can add the same functionality to tensor ops. Good point on pack/unpack folding. I think they can be folded if the patterns are applied in the right order. To make the resulting IR deterministic, we might need the same pattern for unpack as well. So yes, I will prepare a patch and send it out for review. Do you think it is better to land both patterns together? If so, I will add the update to this PR. If it does not matter, I will land it as a follow-up.
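
For context, here is a rough plain C++ sketch of the kind of inference the hunk above relies on. It is not the PR's `inferStaticShape`. The helper name `inferUntiledStaticDims`, the `kDynamic` marker, and the use of `std::vector` are assumptions made for illustration. The idea is that a source dimension that is not tiled has the same size as its matching outer dimension in the destination, so a static size on one side can be copied to the other.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for a "dynamic size" marker; an assumption of this sketch.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: propagates static sizes between untiled source dims
// and their matching destination outer dims. Returns true if either shape
// was refined. outerDimsPerm may be empty (identity permutation).
bool inferUntiledStaticDims(std::vector<int64_t> &srcShape,
                            std::vector<int64_t> &destShape,
                            const std::vector<int64_t> &innerDimsPos,
                            const std::vector<int64_t> &outerDimsPerm) {
  const std::size_t srcRank = srcShape.size();
  assert(destShape.size() >= srcRank && "dest holds outer dims plus tile dims");
  assert(outerDimsPerm.empty() || outerDimsPerm.size() == srcRank);

  // Mark the source dims that are tiled; their outer size differs from the
  // source size, so they are skipped here.
  std::vector<bool> isTiled(srcRank, false);
  for (int64_t pos : innerDimsPos)
    isTiled[static_cast<std::size_t>(pos)] = true;

  // inverse[s] is the destination outer position that holds source dim s.
  std::vector<std::size_t> inverse(srcRank);
  for (std::size_t i = 0; i < srcRank; ++i)
    inverse[i] = i;
  for (std::size_t i = 0; i < outerDimsPerm.size(); ++i)
    inverse[static_cast<std::size_t>(outerDimsPerm[i])] = i;

  bool changed = false;
  for (std::size_t s = 0; s < srcRank; ++s) {
    if (isTiled[s])
      continue;
    const std::size_t d = inverse[s];
    const bool srcDynamic = srcShape[s] == kDynamic;
    const bool destDynamic = destShape[d] == kDynamic;
    // Nothing to learn if both sides are static or both are dynamic.
    if (srcDynamic == destDynamic)
      continue;
    const int64_t size = srcDynamic ? destShape[d] : srcShape[s];
    srcShape[s] = size;
    destShape[d] = size;
    changed = true;
  }
  return changed;
}
```

When the helper returns true, the refined shapes are what the canonicalization would cast the source and destination to before rebuilding the op.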



https://github.com/llvm/llvm-project/pull/80848

