[Mlir-commits] [mlir] [mlir][Linalg] Refine how broadcast dims are treated (PR #99015)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 25 13:12:39 PDT 2024


================
@@ -629,7 +629,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
         loc, value, outputOperand->get(), ValueRange{});
   }
 
-  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+  // The operand map may contain "zero" results, e.g.:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // When applied to canonical vector shapes like these:
+  //    (1, 16, 16, 4)
+  // we would get:
+  //    (1, 16, 16, 0)
+  // Instead, we should extract the following map:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //    vector<1x16x16xty>
+  // rather than:
+  //    vector<1x16x16x0xty>
+  auto opOperantMapWithoutZeros = opOperandMap.dropZeros();
+  write =
+      state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
----------------
banach-space wrote:

> What happens with the xfer read counterpart? 

It turns out you've already implemented that :) 

https://github.com/llvm/llvm-project/blob/2ba3fe7356f065757a2279f65e4ef5c8f1476293/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1327-L1335

> Should we move this logic into maskOperation?

Yes! Sending an update shortly. 

https://github.com/llvm/llvm-project/pull/99015
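The comment in the diff above describes dropping constant-zero results from the operand map before the map is used to build the mask type. The following is a minimal, self-contained C++ sketch of that idea. It assumes a deliberately simplified model of an affine map, where each result is either a dimension (d0, d1, ...) or the constant 0. `MapResult`, `SimpleMap`, `dropZeroResults` and `applyToShape` are hypothetical names used for illustration only. They are not MLIR's `AffineMap` API, and the diff itself calls `opOperandMap.dropZeros()`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified model of one affine map result (NOT mlir::AffineMap):
// `dim` holds K for a result dK and is empty for the constant 0.
struct MapResult {
  std::optional<unsigned> dim;
};

// A map such as (d0, d1, d2, d3) -> (d0, d1, d2, 0) is modelled as a list of
// its results.
using SimpleMap = std::vector<MapResult>;

// Keep only the dimension results and drop every constant-zero result,
// mirroring the intent of the dropZeros() call in the diff.
SimpleMap dropZeroResults(const SimpleMap &map) {
  SimpleMap out;
  for (const MapResult &r : map)
    if (r.dim)
      out.push_back(r);
  return out;
}

// Apply the map to a canonical vector shape: dK selects shape[K], while the
// constant 0 produces a 0-sized dimension.
std::vector<int64_t> applyToShape(const SimpleMap &map,
                                  const std::vector<int64_t> &shape) {
  std::vector<int64_t> out;
  for (const MapResult &r : map) {
    if (r.dim) {
      assert(*r.dim < shape.size() && "dimension out of range");
      out.push_back(shape[*r.dim]);
    } else {
      out.push_back(0);
    }
  }
  return out;
}
```

With the map (d0, d1, d2, d3) -> (d0, d1, d2, 0) and the canonical shape (1, 16, 16, 4), `applyToShape` returns {1, 16, 16, 0}, which corresponds to the unwanted vector<1x16x16x0xty>. Applying the map after `dropZeroResults` returns {1, 16, 16}, which matches the intended vector<1x16x16xty>.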

