[Mlir-commits] [mlir] [MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern (PR #113551)

Vivek Khandelwal llvmlistbot at llvm.org
Thu Oct 24 04:49:39 PDT 2024


https://github.com/vivekkhandelwal1 created https://github.com/llvm/llvm-project/pull/113551

This commit fixes a crash in the chained tensor.cast canonicalization pattern. When sourceType and intermediateType both have a static size for the same dimension and those sizes differ, the joinShapes utility returns a null value. The pattern then passes that null value to the next joinShapes call, which crashes.

Such a tensor.cast is invalid, since the operand shape and result shape are incompatible, but the code should not crash in any case. This commit fixes that case specifically.
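
The failure mode can be illustrated outside MLIR. The following is a minimal sketch, not the code in TensorOps.cpp: `Shape`, `kDynamic`, and `canDropIntermediateCast` are hypothetical stand-ins, shapes are assumed to be ranked, and `std::nullopt` models the null type that joinShapes returns for incompatible shapes. It mirrors the patched control flow, where each join is checked before it is passed on.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for ranked tensor shapes; these are not MLIR types.
constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

// Simplified join of two ranked shapes. std::nullopt models the null value
// returned when the shapes are incompatible.
std::optional<Shape> joinShapes(const Shape &one, const Shape &two) {
  if (one.size() != two.size())
    return std::nullopt;
  Shape join;
  join.reserve(one.size());
  for (std::size_t i = 0; i < one.size(); ++i) {
    if (one[i] == kDynamic) {
      join.push_back(two[i]);
    } else if (two[i] == kDynamic || one[i] == two[i]) {
      join.push_back(one[i]);
    } else {
      return std::nullopt; // Two different static sizes for the same dim.
    }
  }
  return join;
}

// Mirrors the patched control flow: every join is checked before it is used
// as the input of the next join.
bool canDropIntermediateCast(const Shape &source, const Shape &intermediate,
                             const Shape &result) {
  std::optional<Shape> firstJoin = joinShapes(source, intermediate);
  if (!firstJoin)
    return false;
  std::optional<Shape> secondJoin = joinShapes(*firstJoin, result);
  if (!secondJoin)
    return false;
  // In this model the source/result join exists whenever the three-way join
  // exists; it may just carry less information.
  std::optional<Shape> newJoin = joinShapes(source, result);
  assert(newJoin && "source/result join must exist here");
  return *secondJoin == *newJoin;
}
```

In this sketch, a source of `{2, 3}` with an intermediate of `{4, 3}` makes the first join fail, so the function returns false instead of feeding an empty value into the second join. That corresponds to the pattern returning `failure()` in the patch.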

From cb008443422319bd55ad6e20e1e025bea687759b Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Date: Thu, 24 Oct 2024 11:42:19 +0000
Subject: [PATCH] [MLIR][Tensor] Fix Chained tensor.cast canonicalization
 pattern

This commit fixes a crash in the chained tensor.cast canonicalization
pattern. When sourceType and intermediateType both have a static size
for the same dimension and those sizes differ, the joinShapes utility
returns a null value. The pattern then passes that null value to the
next joinShapes call, which crashes.

Such a tensor.cast is invalid, since the operand shape and result shape
are incompatible, but the code should not crash in any case. This commit
fixes that case specifically.

Signed-Off-By: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..13af1497d3790e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -434,17 +434,23 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
     // We can remove the intermediate cast if joining all three produces the
     // same result as just joining the source and result shapes.
     auto firstJoin =
-        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+        joinShapes(sourceType, intermediateType);
 
     // The join might not exist if the cast sequence would fail at runtime.
     if (!firstJoin)
       return failure();
 
+    auto secondJoin = joinShapes(firstJoin, resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!secondJoin)
+      return failure();
+
     // The newJoin always exists if the above join exists, it might just contain
     // less information. If so, we cannot drop the intermediate cast, as doing
     // so would remove runtime checks.
     auto newJoin = joinShapes(sourceType, resultType);
-    if (firstJoin != newJoin)
+    if (secondJoin != newJoin)
       return failure();
 
     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,


