[Mlir-commits] [mlir] [MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern (PR #113551)

Vivek Khandelwal llvmlistbot at llvm.org
Thu Oct 24 05:49:29 PDT 2024


================
@@ -433,18 +433,23 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
 
     // We can remove the intermediate cast if joining all three produces the
     // same result as just joining the source and result shapes.
-    auto firstJoin =
-        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+    auto firstJoin = joinShapes(sourceType, intermediateType);
 
     // The join might not exist if the cast sequence would fail at runtime.
     if (!firstJoin)
       return failure();
 
+    auto secondJoin = joinShapes(firstJoin, resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!secondJoin)
----------------
vivekkhandelwal1 wrote:

The `!firstJoin` condition is already checked in the code above. We can't merge the two conditions, because `secondJoin` is computed only when `firstJoin` is non-null.
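
To illustrate why the two checks have to stay separate, here is a minimal standalone sketch. It is not the MLIR implementation: `Shape`, `kDynamic`, and `chainedJoin` are hypothetical names, and this `joinShapes` is a simplified stand-in that works on plain dimension vectors (with `-1` marking a dynamic dimension) rather than on `TensorType`. The point is only that the second join takes the result of the first as an input, so the first result must be checked before it is used.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a ranked tensor shape; -1 marks a dynamic dim.
using Shape = std::vector<int64_t>;
constexpr int64_t kDynamic = -1;

// Simplified join: merge the static information of two shapes.
// Returns std::nullopt when the shapes conflict, i.e. no join exists.
std::optional<Shape> joinShapes(const Shape &a, const Shape &b) {
  if (a.size() != b.size())
    return std::nullopt;
  Shape join;
  join.reserve(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    assert(a[i] >= kDynamic && b[i] >= kDynamic && "invalid dimension size");
    if (a[i] == kDynamic) {
      join.push_back(b[i]); // Take whatever the other side knows.
    } else if (b[i] == kDynamic || a[i] == b[i]) {
      join.push_back(a[i]);
    } else {
      return std::nullopt; // Two different static sizes: no join.
    }
  }
  return join;
}

// Join source, intermediate, and result shapes in two guarded steps.
std::optional<Shape> chainedJoin(const Shape &source,
                                 const Shape &intermediate,
                                 const Shape &result) {
  std::optional<Shape> firstJoin = joinShapes(source, intermediate);
  // This check must come first: the second join consumes *firstJoin.
  if (!firstJoin)
    return std::nullopt;
  // The second join can fail independently and needs its own check
  // at the call site.
  return joinShapes(*firstJoin, result);
}
```

In this sketch, `chainedJoin({2, kDynamic}, {kDynamic, 3}, {2, 3})` succeeds, while a conflict in either step yields an empty optional. Folding both checks into one condition would require evaluating the second join even when the first one has no value.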

https://github.com/llvm/llvm-project/pull/113551

