[Mlir-commits] [mlir] [mlir][vector] Canonicalize/fold 'order preserving' transposes (PR #135841)
Diego Caballero
llvmlistbot at llvm.org
Fri May 2 09:37:26 PDT 2025
================
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   if (getSource().getType() == resultType)
     return getSource();
-  // Y = shape_cast(shape_cast(X)))
-  // -> X, if X and Y have same type
-  // -> shape_cast(X) otherwise.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    VectorType srcType = otherOp.getSource().getType();
-    if (resultType == srcType)
-      return otherOp.getSource();
-    setOperand(otherOp.getSource());
+  // shape_cast(shape_cast(x)) -> shape_cast(x)
+  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
+    setOperand(precedingShapeCast.getSource());
     return getResult();
   }
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    // This folder does
+    // shape_cast(transpose) -> shape_cast
+    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+    // shape_cast -> shape_cast(transpose)
+    // i.e. the complete opposite. When paired, these 2 patterns can cause
+    // infinite cycles in pattern rewriting.
+    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+    // vectors, so by disabling this folder for scalable vectors the
+    // cycle is avoided.
+    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
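For readers following the hunk: `vector.shape_cast` reinterprets a vector's elements in their row-major (linearized) order, so a transpose that only moves unit (size-1) dimensions leaves that order unchanged, and `shape_cast(transpose(x))` then equals `shape_cast(x)`. The hunk is cut off before the folder's guard condition, but the PR title suggests this "order preserving" case is the one being targeted. The following is a minimal standalone C++ model of that reasoning, assuming vectors are represented as a shape plus a flat row-major `std::vector<int>`. `Vec`, `shapeCast`, `transpose` and `isOrderPreserving` are hypothetical names, not MLIR APIs, and this is not the PR's implementation. It follows the `vector.transpose` convention that result dimension `i` is source dimension `perm[i]`, and it ignores scalable dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of a fixed-size vector (not an MLIR type):
// a shape plus its elements flattened in row-major order.
struct Vec {
  std::vector<int64_t> shape;
  std::vector<int> data;
};

// Model of shape_cast: keeps the elements in the same linear order and
// only changes the shape. The element count must not change.
Vec shapeCast(const Vec &v, std::vector<int64_t> newShape) {
  int64_t n = 1;
  for (int64_t d : newShape)
    n *= d;
  assert(n == static_cast<int64_t>(v.data.size()));
  return {std::move(newShape), v.data};
}

// Model of transpose: result dim i is source dim perm[i].
Vec transpose(const Vec &v, const std::vector<int64_t> &perm) {
  const std::size_t rank = v.shape.size();
  assert(perm.size() == rank);
  std::vector<int64_t> resShape(rank);
  for (std::size_t i = 0; i < rank; ++i)
    resShape[i] = v.shape[perm[i]];
  // Row-major strides of the source vector.
  std::vector<int64_t> srcStrides(rank, 1);
  for (int64_t i = static_cast<int64_t>(rank) - 2; i >= 0; --i)
    srcStrides[i] = srcStrides[i + 1] * v.shape[i + 1];
  std::vector<int> out(v.data.size());
  std::vector<int64_t> idx(rank, 0); // current result multi-index
  for (std::size_t lin = 0; lin < out.size(); ++lin) {
    int64_t srcLin = 0;
    for (std::size_t i = 0; i < rank; ++i)
      srcLin += idx[i] * srcStrides[perm[i]];
    out[lin] = v.data[srcLin];
    // Advance the result multi-index in row-major order.
    for (int64_t i = static_cast<int64_t>(rank) - 1; i >= 0; --i) {
      if (++idx[i] < resShape[i])
        break;
      idx[i] = 0;
    }
  }
  return {resShape, out};
}

// Hypothetical helper: true if `perm` keeps all non-unit dimensions of
// `shape` in their original relative order (only size-1 dims move).
// Such a transpose does not change the row-major element order.
bool isOrderPreserving(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &perm) {
  int64_t lastNonUnit = -1;
  for (int64_t srcDim : perm) {
    if (shape[srcDim] == 1)
      continue; // unit dims can move freely
    if (srcDim < lastNonUnit)
      return false; // a non-unit dim moved ahead of an earlier one
    lastNonUnit = srcDim;
  }
  return true;
}
```

With `x` of shape `2x1x3` and `perm = {1, 0, 2}`, only the unit dimension moves, so `transpose` returns the elements in their original order and casting either value to shape `{6}` yields the same data. With shape `2x3` and `perm = {1, 0}`, the helper returns false and the transposed data becomes `{0, 3, 1, 4, 2, 5}`, so the fold would be invalid there. The same model shows why `shape_cast(shape_cast(x))` can collapse to a single `shape_cast`: `shapeCast` never reorders `data`, so only the final shape matters.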
----------------
dcaballe wrote:
Do you know why we generate illegal shape cast ops in the first place? It sounds like something that shouldn't happen...
https://github.com/llvm/llvm-project/pull/135841