[Mlir-commits] [mlir] 275a2b0 - [MLIR][Tensor] Perform shape inference via in-place modification (NFC) (#111593)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 9 00:42:23 PDT 2024


Author: Mehdi Amini
Date: 2024-10-09T09:42:16+02:00
New Revision: 275a2b05813b2f10f403375abd72d1843e4544c3

URL: https://github.com/llvm/llvm-project/commit/275a2b05813b2f10f403375abd72d1843e4544c3
DIFF: https://github.com/llvm/llvm-project/commit/275a2b05813b2f10f403375abd72d1843e4544c3.diff

LOG: [MLIR][Tensor] Perform shape inference via in-place modification (NFC) (#111593)

This is more efficient, since it avoids creating a clone that is immediately
removed. Also, only insert a cast on the result when the destination type
has changed.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e93880..4d6c5965c4fcc3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,25 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    if (destShape != packOp.getDestType().getShape()) {
+    RankedTensorType originalResultType = packOp.getDestType();
+    bool needUpdateDestType = (destShape != originalResultType.getShape());
+    if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
-    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
-    Value res = clonedPackOp.getResult();
-    rewriter.startOpModification(clonedPackOp);
-    clonedPackOp.getSourceMutable().assign(source);
-    clonedPackOp.getDestMutable().assign(dest);
-    res.setType(dest.getType());
-    rewriter.finalizeOpModification(clonedPackOp);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        packOp, packOp.getResult().getType(), clonedPackOp);
+    rewriter.modifyOpInPlace(packOp, [&] {
+      packOp.getSourceMutable().assign(source);
+      packOp.getDestMutable().assign(dest);
+      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    });
+    // Insert a cast if needed
+    if (needUpdateDestType) {
+      rewriter.setInsertionPointAfter(packOp);
+      auto castOp =
+          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    }
     return success();
   }
 


        


More information about the Mlir-commits mailing list