[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Jan 21 02:29:12 PST 2025
================
@@ -937,7 +947,19 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
for (auto attr : contractionOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
continue;
- collapsedOp->setAttr(attr.getName(), attr.getValue());
+
+ // Update the indexing_maps attribute for the collapsed MatmulOp.
+ if (attr.getName() == "indexing_maps" &&
+ std::is_same<FromOpTy, BatchMatmulOp>::value &&
+ std::is_same<ToOpTy, MatmulOp>::value) {
+ SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
+ MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ collapsedOp->setAttr(attr.getName(),
+ rewriter.getArrayAttr(indexingMapsAttr));
----------------
adam-smnk wrote:
Hmm, I see, but that's really not a scalable solution.
I think we could instead ensure that only default indexing maps are present for all ops (see the other comment) and then drop them when copying attributes, since the default maps should be equivalent to the ones stored under `LinalgDialect::kMemoizedIndexingMapsAttrName`.
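The following is a minimal, standalone C++ sketch of why dropping default maps is sufficient. It assumes the collapsed batch dimension has unit extent and writes out the default maps by hand: `batch_matmul` iterates `(batch, m, n, k)` and `matmul` iterates `(m, n, k)`. `ProjectedMap`, `dropDim`, and `shouldCopyAttr` are hypothetical stand-ins, not MLIR's `AffineMap` API, and the attribute strings are illustrative, not the real spellings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of a projected-permutation affine map (NOT mlir::AffineMap):
// results[i] is the index of the loop dimension used by result i.
struct ProjectedMap {
  int numDims;
  std::vector<int> results;
  bool operator==(const ProjectedMap &other) const {
    return numDims == other.numDims && results == other.results;
  }
};

// batch_matmul defaults: (d0, d1, d2, d3) -> (d0, d1, d3), (d0, d3, d2), (d0, d1, d2)
std::vector<ProjectedMap> batchMatmulDefaultMaps() {
  return {{4, {0, 1, 3}}, {4, {0, 3, 2}}, {4, {0, 1, 2}}};
}

// matmul defaults: (d0, d1, d2) -> (d0, d2), (d2, d1), (d0, d1)
std::vector<ProjectedMap> matmulDefaultMaps() {
  return {{3, {0, 2}}, {3, {2, 1}}, {3, {0, 1}}};
}

// Drop loop dimension `dim` (assumed unit-extent): remove every result that
// refers to it and renumber the higher dimensions down by one.
ProjectedMap dropDim(const ProjectedMap &map, int dim) {
  assert(dim >= 0 && dim < map.numDims);
  ProjectedMap out{map.numDims - 1, {}};
  for (int d : map.results) {
    if (d == dim)
      continue;
    out.results.push_back(d > dim ? d - 1 : d);
  }
  return out;
}

// Collapse the batch dimension (d0) out of every default batch_matmul map.
std::vector<ProjectedMap> collapsedBatchMatmulMaps() {
  std::vector<ProjectedMap> collapsed;
  for (const ProjectedMap &m : batchMatmulDefaultMaps())
    collapsed.push_back(dropDim(m, /*dim=*/0));
  return collapsed;
}

// Hypothetical attribute filter: skip both the memoized maps and the explicit
// indexing maps, letting the new op fall back to its own defaults.
bool shouldCopyAttr(const std::string &name) {
  return name != "memoized_indexing_maps" && name != "indexing_maps";
}
```

Calling `collapsedBatchMatmulMaps()` yields exactly `matmulDefaultMaps()`, so when the source op only carries default maps the collapsed op needs no op-specific rewrite of `indexing_maps`; skipping the attribute is enough, whatever the source/target op pair.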
https://github.com/llvm/llvm-project/pull/122275