[Mlir-commits] [mlir] [mlir][vector] Clarify the semantics of BroadcastOp (PR #101928)

Jakub Kuderski llvmlistbot at llvm.org
Tue Aug 6 07:35:44 PDT 2024


================
@@ -2390,13 +2390,29 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
   // Source has an exact match or singleton value for all trailing dimensions
   // (all leading dimensions are simply duplicated).
   int64_t lead = dstRank - srcRank;
-  for (int64_t r = 0; r < srcRank; ++r) {
-    int64_t srcDim = srcVectorType.getDimSize(r);
-    int64_t dstDim = dstVectorType.getDimSize(lead + r);
-    if (srcDim != 1 && srcDim != dstDim) {
-      if (mismatchingDims) {
-        mismatchingDims->first = srcDim;
-        mismatchingDims->second = dstDim;
+  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
+    bool mismatch = false;
+
+    // Check fixed-width dims.
+    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
+    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
+    if (srcDim != 1 && srcDim != dstDim)
+      mismatch = true;
+
+    // Check scalable flags.
+    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
+    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
+    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+        (srcDimScalableFlag != dstDimScalableFlag))
+      mismatch = true;
+
+    if (mismatch) {
+      if (mismatchingDims != nullptr) {
----------------
kuhar wrote:

This helps, thanks!
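To make the per-dimension rule in this hunk concrete, here is a minimal standalone C++ sketch of the loop, assuming no MLIR dependency. `VecShape` and `isBroadcastableSketch` are hypothetical stand-ins for `VectorType` and `isBroadcastableTo`. The early rank check is also an assumption, because that part of the function is not shown. The quoted diff stops at `if (mismatchingDims != nullptr) {`, so the mismatch branch (record the pair, then reject) is inferred from the removed lines rather than copied.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: one size and one
// "is scalable" flag per dimension.
struct VecShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;
};

// Sketch of the trailing-dimension loop from the hunk. Returns true if
// `src` can be broadcast to `dst`; on failure, optionally reports the
// first mismatching (srcDim, dstDim) pair.
bool isBroadcastableSketch(const VecShape &src, const VecShape &dst,
                           std::pair<int64_t, int64_t> *mismatchingDims = nullptr) {
  assert(src.dims.size() == src.scalable.size());
  assert(dst.dims.size() == dst.scalable.size());
  const int64_t srcRank = static_cast<int64_t>(src.dims.size());
  const int64_t dstRank = static_cast<int64_t>(dst.dims.size());
  // Assumption: a source of higher rank is rejected before the loop.
  if (srcRank > dstRank)
    return false;

  // Leading dims of `dst` are simply duplicated; compare trailing dims only.
  const int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    bool mismatch = false;

    // Fixed-width check: a dim broadcasts if it is 1 or matches exactly.
    int64_t srcDim = src.dims[dimIdx];
    int64_t dstDim = dst.dims[lead + dimIdx];
    if (srcDim != 1 && srcDim != dstDim)
      mismatch = true;

    // Scalable check: a scalable unit dim ([1]) may not be stretched,
    // and the scalable flags of the two dims must agree.
    bool srcScalable = src.scalable[dimIdx];
    bool dstScalable = dst.scalable[lead + dimIdx];
    if ((srcDim == 1 && srcScalable && dstDim != 1) ||
        (srcScalable != dstScalable))
      mismatch = true;

    if (mismatch) {
      if (mismatchingDims != nullptr)
        *mismatchingDims = {srcDim, dstDim};
      return false;
    }
  }
  return true;
}
```

Under the check as written in this hunk, `1` to `4x8` and `[4]` to `2x[4]` pass, while `[1]` to `[4]` is rejected by the first scalable condition and a fixed `4` to `[4]` is rejected because the flags differ. Folding both checks into a single `mismatch` flag keeps one reporting path for `mismatchingDims`, which is the shape the new code in the hunk takes.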

https://github.com/llvm/llvm-project/pull/101928

