[Mlir-commits] [mlir] [mlir][tensor] Introduce `FoldTensorCastUnPackOp` (PR #121393)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 2 07:13:17 PST 2025


================
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
 
+/// Folds a `tensor.cast` op into a consuming `tensor::UnPackOp` if the
+/// `tensor.cast` has a source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      // If the current source shape is dynamic, just preserve this mixed
+      // size.
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      // If the current source is static, update the dynamic mixed-size
+      // (provided the original value is dynamic).
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
----------------
Max191 wrote:

nit: Can you add a local variable (similar to what you did above with `shape`) for the second element of the zipped pair (e.g., something like `tile`)? I think it makes it clearer what that value is when reading the code.

https://github.com/llvm/llvm-project/pull/121393


More information about the Mlir-commits mailing list