[Mlir-commits] [mlir] [MLIR][Tensor] Perform shape inference via in-place modification (NFC) (PR #111593)

Mehdi Amini llvmlistbot at llvm.org
Tue Oct 8 15:31:16 PDT 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/111593

>From d927131f0bed8f6855a345bdcda57c57850e4f76 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Tue, 8 Oct 2024 14:37:54 -0700
Subject: [PATCH 1/2] [MLIR][Tensor] Perform shape inference via in-place
 modification (NFC)

Modifying the op in place is more efficient than cloning it and then
immediately erasing the original. Also, only insert a cast on the result
when the destination type has changed.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e93880..0ac0899def21b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,24 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
+    Type originalResultType = dest.getType();
     if (destShape != packOp.getDestType().getShape()) {
       auto newDestType = packOp.getDestType().clone(destShape);
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
-    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
-    Value res = clonedPackOp.getResult();
-    rewriter.startOpModification(clonedPackOp);
-    clonedPackOp.getSourceMutable().assign(source);
-    clonedPackOp.getDestMutable().assign(dest);
-    res.setType(dest.getType());
-    rewriter.finalizeOpModification(clonedPackOp);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        packOp, packOp.getResult().getType(), clonedPackOp);
+    rewriter.modifyOpInPlace(packOp, [&] {
+      packOp.getSourceMutable().assign(source);
+      packOp.getDestMutable().assign(dest);
+      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    });
+    // Insert a cast if needed
+    if (originalResultType != dest.getType()) {
+      rewriter.setInsertionPointAfter(packOp);
+      auto castOp =
+          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    }
     return success();
   }
 

>From 8d8b16a894f93f1826b409d5b798ab6f109199b5 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 9 Oct 2024 00:31:07 +0200
Subject: [PATCH 2/2] Update TensorOps.cpp

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ac0899def21b5..8e79167ec9c7c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,8 +4332,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    Type originalResultType = dest.getType();
-    if (destShape != packOp.getDestType().getShape()) {
+    bool needUpdateDestType = (destShape != packOp.getDestType().getShape());
+    if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
@@ -4344,7 +4344,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
       packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
     });
     // Insert a cast if needed
-    if (originalResultType != dest.getType()) {
+    if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
       auto castOp =
           rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);



More information about the Mlir-commits mailing list