[Mlir-commits] [mlir] [mlir][vector] Fold broadcast(poison) -> poison (PR #135677)

James Newling llvmlistbot at llvm.org
Mon Apr 14 13:45:11 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/135677

I've also added a test for broadcast(splat) -> splat, which I think was missing.

>From 401d54dc215aee4ff7bc4b7aad2ce49d7b6d2584 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 14 Apr 2025 13:38:40 -0700
Subject: [PATCH] fold broadcast(splat) -> splat

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  2 ++
 mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..4b7ff757c150a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2590,6 +2590,8 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   }
   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
+  if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
+    return ub::PoisonAttr::get(getContext());
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..420244c5e734a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1151,6 +1151,28 @@ func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
 
 // -----
 
+// CHECK-LABEL: broadcast_poison
+//       CHECK:  %[[POISON:.*]] = ub.poison : vector<4x6xi8>
+//       CHECK:  return %[[POISON]] : vector<4x6xi8>
+func.func @broadcast_poison() -> vector<4x6xi8> {
+  %poison = ub.poison : vector<6xi8>
+  %broadcast = vector.broadcast %poison : vector<6xi8> to vector<4x6xi8>
+  return %broadcast : vector<4x6xi8>
+}
+
+// ----- 
+
+// CHECK-LABEL:  broadcast_splat_constant
+//       CHECK:  %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
+//       CHECK:  return %[[CONST]] : vector<4x6xi8>
+func.func @broadcast_splat_constant() -> vector<4x6xi8> {
+  %cst = arith.constant dense<1> : vector<6xi8>
+  %broadcast = vector.broadcast %cst : vector<6xi8> to vector<4x6xi8>
+  return %broadcast : vector<4x6xi8>
+}
+
+// -----
+
 // CHECK-LABEL: broadcast_folding1
 //       CHECK: %[[CST:.*]] = arith.constant dense<42> : vector<4xi32>
 //   CHECK-NOT: vector.broadcast



More information about the Mlir-commits mailing list