[Mlir-commits] [mlir] [mlir][vector] Canonicalize/fold 'order preserving' transposes (PR #135841)

Diego Caballero llvmlistbot at llvm.org
Fri May 2 09:37:26 PDT 2025


================
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   if (getSource().getType() == resultType)
     return getSource();
 
-  // Y = shape_cast(shape_cast(X)))
-  //      -> X, if X and Y have same type
-  //      -> shape_cast(X) otherwise.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    VectorType srcType = otherOp.getSource().getType();
-    if (resultType == srcType)
-      return otherOp.getSource();
-    setOperand(otherOp.getSource());
+  // shape_cast(shape_cast(x)) -> shape_cast(x)
+  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
+    setOperand(precedingShapeCast.getSource());
     return getResult();
   }
 
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    // This folder does
+    //    shape_cast(transpose) -> shape_cast
+    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+    //    shape_cast -> shape_cast(transpose)
+    // i.e. the complete opposite. When paired, these 2 patterns can cause
+    // infinite cycles in pattern rewriting.
+    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+    // vectors, so by disabling this folder for scalable vectors the
+    // cycle is avoided.
+    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
----------------
dcaballe wrote:

Do you know why we generate illegal shape cast ops in the first place? It sounds like something that shouldn't happen...

https://github.com/llvm/llvm-project/pull/135841
