[Mlir-commits] [mlir] [mlir][vector] Avoid use of vector.splat in transforms (PR #150279)
Diego Caballero
llvmlistbot at llvm.org
Tue Jul 29 09:20:45 PDT 2025
================
@@ -1007,26 +1020,23 @@ struct ReorderElementwiseOpsOnBroadcast final
}
// Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
- return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+ if (!lhsSource)
+ return rewriter.notifyMatchFailure(
+ op, "operand #0 not the result of a broadcast");
+ Type lhsBcastOrSplatType = lhsSource.getType();
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
+ if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+ if (auto source = getBroadcastLikeSource(val))
+ return source.getType() == lhsBcastOrSplatType;
return false;
})) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "not all operands are broadcasts from the sametype");
----------------
dcaballe wrote:
typo: "sametype" in the notifyMatchFailure message should be "same type"
https://github.com/llvm/llvm-project/pull/150279
More information about the Mlir-commits
mailing list