[Mlir-commits] [mlir] [MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern (PR #113551)
Vivek Khandelwal
llvmlistbot at llvm.org
Thu Oct 24 04:53:38 PDT 2024
https://github.com/vivekkhandelwal1 updated https://github.com/llvm/llvm-project/pull/113551
From cb008443422319bd55ad6e20e1e025bea687759b Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Date: Thu, 24 Oct 2024 11:42:19 +0000
Subject: [PATCH 1/2] [MLIR][Tensor] Fix Chained tensor.cast canonicalization
pattern
This commit fixes a crash in the chained tensor.cast canonicalization
pattern. When sourceType and intermediateType both contain a static dim
at the same position and the two sizes are not equal, the joinShapes
utility returns a null value. Passing that null value to the next
joinShapes call results in a crash.
Such a tensor.cast is invalid, since the operand shape and the result
shape are incompatible, but the code should not crash in any case. This
commit fixes that specific case.
Signed-Off-By: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
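The failure mode described in the commit message can be sketched outside of MLIR. The following is a minimal standalone model, not the code in TensorOps.cpp: `Shape`, `kDynamic`, and `canDropIntermediateCast` are hypothetical names introduced for illustration, and this `joinShapes` is a simplified stand-in that uses `std::optional` where the real utility returns a null type. It only aims to show why each join is checked before being passed to the next call.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a ranked tensor shape; kDynamic marks a `?` dim.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Simplified model of a shape join (an assumption, not MLIR's implementation):
// returns std::nullopt when ranks differ or two static dims disagree.
std::optional<Shape> joinShapes(const Shape &a, const Shape &b) {
  if (a.size() != b.size())
    return std::nullopt;
  Shape join;
  join.reserve(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == kDynamic) {
      join.push_back(b[i]);
      continue;
    }
    if (b[i] == kDynamic) {
      join.push_back(a[i]);
      continue;
    }
    if (a[i] != b[i])
      return std::nullopt; // Conflicting static sizes: no join exists.
    join.push_back(a[i]);
  }
  return join;
}

// Mirrors the control flow of the patch: every join is checked before it is
// used, so an empty result is never forwarded to the next joinShapes call.
bool canDropIntermediateCast(const Shape &source, const Shape &intermediate,
                             const Shape &result) {
  auto firstJoin = joinShapes(source, intermediate);
  if (!firstJoin)
    return false;
  auto secondJoin = joinShapes(*firstJoin, result);
  if (!secondJoin)
    return false;
  // Dropping the intermediate cast is only safe if it adds no information.
  auto newJoin = joinShapes(source, result);
  return secondJoin == newJoin;
}
```

In this model, a 4x8 source with a 4x16 intermediate has no first join, so the function bails out early instead of handing an empty value to the second join, which is the situation the nested call in the original code did not guard against.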
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..13af1497d3790e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -434,17 +434,23 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
- joinShapes(joinShapes(sourceType, intermediateType), resultType);
+ joinShapes(sourceType, intermediateType);
// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();
+ auto secondJoin = joinShapes(firstJoin, resultType);
+
+ // The join might not exist if the cast sequence would fail at runtime.
+ if (!secondJoin)
+ return failure();
+
// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
- if (firstJoin != newJoin)
+ if (secondJoin != newJoin)
return failure();
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
From e882bcbb5bdd2deae771e839ce9faa1906f1ff3a Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Date: Thu, 24 Oct 2024 11:53:25 +0000
Subject: [PATCH 2/2] Fix code formatting
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 13af1497d3790e..d1b73ff2dbd0c7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -433,8 +433,7 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
- auto firstJoin =
- joinShapes(sourceType, intermediateType);
+ auto firstJoin = joinShapes(sourceType, intermediateType);
// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)