[Mlir-commits] [mlir] [mlir][vector] Canonicalize/fold 'order preserving' transposes (PR #135841)
James Newling
llvmlistbot at llvm.org
Fri May 2 12:56:31 PDT 2025
================
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
if (getSource().getType() == resultType)
return getSource();
- // Y = shape_cast(shape_cast(X)))
- // -> X, if X and Y have same type
- // -> shape_cast(X) otherwise.
- if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
- VectorType srcType = otherOp.getSource().getType();
- if (resultType == srcType)
- return otherOp.getSource();
- setOperand(otherOp.getSource());
+ // shape_cast(shape_cast(x)) -> shape_cast(x)
+ if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
+ setOperand(precedingShapeCast.getSource());
return getResult();
}
+ // shape_cast(transpose(x)) -> shape_cast(x)
+ if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+ // This folder does
+ // shape_cast(transpose) -> shape_cast
+ // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+ // shape_cast -> shape_cast(transpose)
+ // i.e. the complete opposite. When paired, these 2 patterns can cause
+ // infinite cycles in pattern rewriting.
+ // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+ // vectors, so by disabling this folder for scalable vectors the
+ // cycle is avoided.
+ // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
----------------
newling wrote:
They're illegal according to an Arm-specific lowering target (ArmSME), which I'm not familiar with:
https://github.com/llvm/llvm-project/blob/1101b767329dd163d528fa5f667a6c0dbdde0ad5/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp#L747
I think @banach-space suspects that they're actually not illegal, and will investigate removing this constraint here (along with the ConvertIllegalShapeCastOpsToTransposes pattern).
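Going by the PR title, the `shape_cast(transpose(x)) -> shape_cast(x)` fold is presumably only valid when the transpose leaves the linearized (row-major) element order unchanged. The standalone C++ sketch below models that condition, assuming row-major linearization. It also models the scalable-vector bail-out described in the diff comment. `ToyVectorType`, `isOrderPreservingTranspose`, and `canFoldTransposeIntoShapeCast` are hypothetical names used for illustration; they are not MLIR APIs and are not the code in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR VectorType: a shape plus a per-dimension
// "scalable" flag. This is NOT the MLIR API; it only models the idea.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;

  // True if any dimension is scalable.
  bool isScalable() const {
    for (bool s : scalableDims)
      if (s)
        return true;
    return false;
  }
};

// Returns true if transposing a vector of `shape` by `perm` leaves the
// row-major (linearized) element order unchanged. Result dim i takes source
// dim perm[i]. Unit dims (size 1) do not affect linearization, so the
// transpose is order preserving iff the non-unit source dims appear in
// increasing order in `perm`.
bool isOrderPreservingTranspose(const std::vector<int64_t> &shape,
                                const std::vector<int64_t> &perm) {
  int64_t lastNonUnit = -1;
  for (int64_t srcDim : perm) {
    if (shape[static_cast<std::size_t>(srcDim)] == 1)
      continue; // Moving a unit dim never reorders elements.
    if (srcDim < lastNonUnit)
      return false; // Two non-unit dims were swapped.
    lastNonUnit = srcDim;
  }
  return true;
}

// Hypothetical guard mirroring the diff comment: skip the
// shape_cast(transpose(x)) -> shape_cast(x) fold for scalable vectors, so it
// cannot cycle with ConvertIllegalShapeCastOpsToTransposes, which only
// matches scalable vectors.
bool canFoldTransposeIntoShapeCast(const ToyVectorType &transposeSrc,
                                   const std::vector<int64_t> &perm) {
  if (transposeSrc.isScalable())
    return false;
  return isOrderPreservingTranspose(transposeSrc.shape, perm);
}
```

For example, transposing a `vector<1x4x1x8>` with permutation `[1, 0, 3, 2]` only moves unit dims, so the sketch treats it as order preserving. The permutation `[0, 3, 2, 1]` swaps the two non-unit dims, so it does not. A scalable variant such as `vector<1x4x1x[8]>` is rejected by the guard regardless of the permutation.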
https://github.com/llvm/llvm-project/pull/135841