[Mlir-commits] [mlir] [MLIR][Tensor] Perform shape inference via in-place modification (NFC) (PR #111593)
Mehdi Amini
llvmlistbot at llvm.org
Wed Oct 9 00:41:30 PDT 2024
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/111593
>From d927131f0bed8f6855a345bdcda57c57850e4f76 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Tue, 8 Oct 2024 14:37:54 -0700
Subject: [PATCH 1/5] [MLIR][Tensor] Perform shape inference via in-place
modification (NFC)
This is more efficient because it avoids creating a clone that is immediately removed.
Also, only insert a cast on the result when the destination
type changed.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 23 +++++++++++++----------
1 file changed, 13 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e93880..0ac0899def21b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,24 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
+ Type originalResultType = dest.getType();
if (destShape != packOp.getDestType().getShape()) {
auto newDestType = packOp.getDestType().clone(destShape);
dest =
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
}
- auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
- Value res = clonedPackOp.getResult();
- rewriter.startOpModification(clonedPackOp);
- clonedPackOp.getSourceMutable().assign(source);
- clonedPackOp.getDestMutable().assign(dest);
- res.setType(dest.getType());
- rewriter.finalizeOpModification(clonedPackOp);
-
- rewriter.replaceOpWithNewOp<tensor::CastOp>(
- packOp, packOp.getResult().getType(), clonedPackOp);
+ rewriter.modifyOpInPlace(packOp, [&] {
+ packOp.getSourceMutable().assign(source);
+ packOp.getDestMutable().assign(dest);
+ packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ });
+ // Insert a cast if needed
+ if (originalResultType != dest.getType()) {
+ rewriter.setInsertionPointAfter(packOp);
+ auto castOp =
+ rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+ rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+ }
return success();
}
>From 8d8b16a894f93f1826b409d5b798ab6f109199b5 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 9 Oct 2024 00:31:07 +0200
Subject: [PATCH 2/5] Update TensorOps.cpp
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ac0899def21b5..8e79167ec9c7c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,8 +4332,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- Type originalResultType = dest.getType();
- if (destShape != packOp.getDestType().getShape()) {
+ bool needUpdateDestType = (destShape != packOp.getDestType().getShape());
+ if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
dest =
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
@@ -4344,7 +4344,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
});
// Insert a cast if needed
- if (originalResultType != dest.getType()) {
+ if (needUpdateDestType) {
rewriter.setInsertionPointAfter(packOp);
auto castOp =
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
>From 3403ef52a0c824426e7f69758f0c202b1889fe29 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 9 Oct 2024 01:41:28 +0200
Subject: [PATCH 3/5] Update TensorOps.cpp
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8e79167ec9c7c9..c781644477c83c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,7 +4332,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- bool needUpdateDestType = (destShape != packOp.getDestType().getShape());
+ Type originalResultType = packOp.getDestType();
+ bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
dest =
>From 362f7e7fc4bdc99758e0aa1cd210942d788c9005 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 9 Oct 2024 01:57:51 +0200
Subject: [PATCH 4/5] Update TensorOps.cpp
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c781644477c83c..931c565785ad72 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,7 +4332,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- Type originalResultType = packOp.getDestType();
+ auto originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
>From 26f1b3b5101c510ebb05d56a201e3ed2da9bcb63 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Wed, 9 Oct 2024 09:41:15 +0200
Subject: [PATCH 5/5] Update TensorOps.cpp
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 931c565785ad72..4d6c5965c4fcc3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,7 +4332,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- auto originalResultType = packOp.getDestType();
+ RankedTensorType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
More information about the Mlir-commits
mailing list