[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Aug 10 09:00:31 PDT 2025


================
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() {
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+  if (!srcShapeCast)
+    return failure();
+
+  VectorType srcType = srcShapeCast.getSourceVectorType();
+  VectorType destType = broadcastOp.getResultVectorType();
+  // Check type compatibility.
+  if (vector::isBroadcastableTo(srcType, destType) !=
+      BroadcastableToResult::Success)
+    return failure();
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> shapecastShape =
+      srcShapeCast.getResultVectorType().getShape();
+  // Trailing dimensions should be the same if shape_cast only alters the
+  // leading dimensions.
+  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+  if (!llvm::equal(srcShape.take_back(numTrailingDims),
+                   shapecastShape.take_back(numTrailingDims)))
+    return failure();
+
+  assert(all_of(srcShape.drop_back(numTrailingDims),
+                [](int64_t E) { return E == 1; }) &&
+         all_of(shapecastShape.drop_back(numTrailingDims),
+                [](int64_t E) { return E == 1; }) &&
+         "ill-formed shape_cast");
----------------
banach-space wrote:

[nit] Unlike LLVM, we use `camelCase` in MLIR for variable names. So, `E` -> `e` (rather confusing, I know). If you want to avoid `e` (less readable than `E` IMHO), you could try `E` -> `dim` 🤷🏻 
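The shape test at the heart of the fold can be tried outside MLIR: the trailing dimensions must match, and any extra leading dimensions must be 1. Below is a minimal, hypothetical C++ sketch that uses `std::vector` in place of `llvm::ArrayRef`. The helper name `onlyLeadingUnitDimsDiffer` is made up for illustration, and the lambda parameter is named `dim` as suggested in the comment above. The PR only asserts the leading-ones condition. That is presumably because a valid `shape_cast` preserves the element count, so matching trailing dimensions force the remaining leading dimensions to be 1. Arbitrary inputs carry no such guarantee, so the sketch checks the condition explicitly.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not part of the PR). It mirrors the shape
// check in foldBroadcastOfShapeCast and returns true if the two shapes differ
// only by leading unit dimensions, i.e. the shape_cast only adds or removes
// leading ones.
bool onlyLeadingUnitDimsDiffer(const std::vector<int64_t> &srcShape,
                               const std::vector<int64_t> &shapeCastShape) {
  auto numTrailingDims = static_cast<std::ptrdiff_t>(
      std::min(srcShape.size(), shapeCastShape.size()));

  // The trailing dimensions must match exactly.
  if (!std::equal(srcShape.end() - numTrailingDims, srcShape.end(),
                  shapeCastShape.end() - numTrailingDims))
    return false;

  // Whatever leading dimensions remain (in the longer shape) must all be 1.
  // The PR asserts this; here it is checked because the inputs are arbitrary.
  auto isUnit = [](int64_t dim) { return dim == 1; };
  return std::all_of(srcShape.begin(), srcShape.end() - numTrailingDims,
                     isUnit) &&
         std::all_of(shapeCastShape.begin(),
                     shapeCastShape.end() - numTrailingDims, isUnit);
}
```

For example, `{1, 1, 4}` against `{4}` passes because only leading ones are dropped. `{2, 4}` against `{8}` fails because the trailing dimensions differ (4 vs 8).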

https://github.com/llvm/llvm-project/pull/150523

