[Mlir-commits] [mlir] [mlir][vector] Avoid use of vector.splat in transforms (PR #150279)

Diego Caballero llvmlistbot at llvm.org
Tue Jul 29 09:20:45 PDT 2025


================
@@ -1007,26 +1020,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
 
     // Get the type of the lhs operand
-    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
-    if (!lhsBcastOrSplat ||
-        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
-      return failure();
-    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+    Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+    if (!lhsSource)
+      return rewriter.notifyMatchFailure(
+          op, "operand #0 not the result of a broadcast");
+    Type lhsBcastOrSplatType = lhsSource.getType();
 
     // Make sure that all operands are broadcast from identical types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
-          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          if (bcast)
-            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
-          auto splat = val.getDefiningOp<vector::SplatOp>();
-          if (splat)
-            return (splat.getOperand().getType() == lhsBcastOrSplatType);
+    if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+          if (auto source = getBroadcastLikeSource(val))
+            return source.getType() == lhsBcastOrSplatType;
           return false;
         })) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "not all operands are broadcasts from the sametype");
----------------
dcaballe wrote:

Typo: "sametype" should be "same type" in the match-failure message.

https://github.com/llvm/llvm-project/pull/150279


More information about the Mlir-commits mailing list