[Mlir-commits] [mlir] [MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern (PR #113551)

llvmlistbot at llvm.org
Thu Oct 24 04:50:23 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Vivek Khandelwal (vivekkhandelwal1)

<details>
<summary>Changes</summary>

This commit fixes a crash in the chained tensor.cast canonicalization pattern. When sourceType and intermediateType both have a static dimension and the two sizes differ, the joinShapes utility returns a null type. That null type was then passed directly into a second joinShapes call, which crashed.

Such a tensor.cast chain is invalid anyway, because the operand and result shapes are incompatible, but the canonicalizer should still not crash on it. With this commit, the pattern checks each join separately and returns failure() when either one is null.

---
Full diff: https://github.com/llvm/llvm-project/pull/113551.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..13af1497d3790e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -434,17 +434,23 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
     // We can remove the intermediate cast if joining all three produces the
     // same result as just joining the source and result shapes.
     auto firstJoin =
-        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+        joinShapes(sourceType, intermediateType);
 
     // The join might not exist if the cast sequence would fail at runtime.
     if (!firstJoin)
       return failure();
 
+    auto secondJoin = joinShapes(firstJoin, resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!secondJoin)
+      return failure();
+
     // The newJoin always exists if the above join exists, it might just contain
     // less information. If so, we cannot drop the intermediate cast, as doing
     // so would remove runtime checks.
     auto newJoin = joinShapes(sourceType, resultType);
-    if (firstJoin != newJoin)
+    if (secondJoin != newJoin)
       return failure();
 
     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,

``````````

</details>


https://github.com/llvm/llvm-project/pull/113551

