[Mlir-commits] [mlir] [mlir][Linalg] Refine how broadcast dims are treated (PR #99015)

Diego Caballero llvmlistbot at llvm.org
Fri Jul 19 15:33:43 PDT 2024


================
@@ -629,7 +629,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
         loc, value, outputOperand->get(), ValueRange{});
   }
 
-  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+  // The operand map may contain "zero" results, e.g.:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // When applied to canonical vector shapes like these:
+  //    (1, 16, 16, 4)
+  // we would get:
+  //    (1, 16, 16, 0)
+  // Instead, we should extract the following map:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //    vector<1x16x16xty>
+  // rather than:
+  //    vector<1x16x16x0xty>
+  auto opOperandMapWithoutZeros = opOperandMap.dropZeros();
+  write =
+      state.maskOperation(rewriter, write, linalgOp, opOperandMapWithoutZeros);
----------------
dcaballe wrote:

What happens with the xfer read counterpart? Should we move this logic into `maskOperation`?
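Below is a minimal, self-contained C++ sketch of the shape computation described in the diff comment. It does not use MLIR. `MapResult`, `SimpleMap`, `dimResult`, `zeroResult`, `dropZeros` and `applyToShape` are hypothetical stand-ins. They model only the results of an affine map, where each result is either a dimension position or the constant 0. They are not the real `AffineMap` API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of one affine-map result: a dim position (d0, d1, ...)
// when `dim` is set, or the constant 0 when it is empty. Not MLIR's API.
struct MapResult {
  std::optional<unsigned> dim;
};

// Hypothetical model of a map's result list, e.g. (d0, d1, d2, 0).
using SimpleMap = std::vector<MapResult>;

MapResult dimResult(unsigned pos) { return MapResult{pos}; }
MapResult zeroResult() { return MapResult{std::nullopt}; }

// Keep only the dimension results and drop constant-zero results. This
// mirrors the intent of the `dropZeros()` call in the diff.
SimpleMap dropZeros(const SimpleMap &map) {
  SimpleMap result;
  for (const MapResult &r : map)
    if (r.dim)
      result.push_back(r);
  return result;
}

// Apply the map to a canonical vector shape. A dN result selects shape[N],
// and a constant-zero result produces a size of 0.
std::vector<int64_t> applyToShape(const SimpleMap &map,
                                  const std::vector<int64_t> &shape) {
  std::vector<int64_t> out;
  for (const MapResult &r : map) {
    if (r.dim) {
      assert(*r.dim < shape.size() && "dim position out of range");
      out.push_back(shape[*r.dim]);
    } else {
      out.push_back(0);
    }
  }
  return out;
}
```

Take the map `(d0, d1, d2, 0)` and the shape `(1, 16, 16, 4)`. Applying the map as-is yields `(1, 16, 16, 0)`, which corresponds to the degenerate `vector<1x16x16x0xty>`. Applying it after `dropZeros` yields `(1, 16, 16)`, which corresponds to `vector<1x16x16xty>`.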

https://github.com/llvm/llvm-project/pull/99015

