[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jan 21 02:29:12 PST 2025


================
@@ -937,7 +947,19 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
     for (auto attr : contractionOp->getAttrs()) {
       if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
         continue;
-      collapsedOp->setAttr(attr.getName(), attr.getValue());
+
+      // Update the indexing_maps attribute for the collapsed MatmulOp.
+      if (attr.getName() == "indexing_maps" &&
+          std::is_same<FromOpTy, BatchMatmulOp>::value &&
+          std::is_same<ToOpTy, MatmulOp>::value) {
+        SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
+            MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
+            [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+        collapsedOp->setAttr(attr.getName(),
+                             rewriter.getArrayAttr(indexingMapsAttr));
----------------
adam-smnk wrote:

Hmm, I see, but that's really not a scalable solution.

I think we could ensure that only default indexing maps are present for all ops (see the other comment) and then drop the `indexing_maps` attribute entirely, since the default maps should be equivalent to the maps cached under `LinalgDialect::kMemoizedIndexingMapsAttrName`.
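A minimal sketch of that idea, modeled outside MLIR in plain C++, follows. It assumes a simplified representation in which each projected-permutation affine map is stored as the list of loop dimensions it returns, and attributes are a name-to-string map. `dropBatchDim`, `collapseAttrs`, `DimMap`, and `Attrs` are hypothetical names, and the attribute-name strings stand in for the real constants. The sketch shows why dropping the attribute is sound when the maps are default: removing the unit batch loop `d0` from the default `linalg.batch_matmul` maps yields exactly the default `linalg.matmul` maps, so the collapsed op can rely on its own defaults.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Simplified model (assumption): a projected-permutation affine map is the
// list of loop-dimension indices it returns, e.g. (d0, d1, d3) -> {0, 1, 3}.
using DimMap = std::vector<int>;
using IndexingMaps = std::vector<DimMap>;
using Attrs = std::map<std::string, std::string>;

// Default batch_matmul maps over loops (d0=b, d1=m, d2=n, d3=k):
//   A: (b, m, k), B: (b, k, n), C: (b, m, n)
const IndexingMaps kDefaultBatchMatmulMaps = {{0, 1, 3}, {0, 3, 2}, {0, 1, 2}};
// Default matmul maps over loops (d0=m, d1=n, d2=k):
//   A: (m, k), B: (k, n), C: (m, n)
const IndexingMaps kDefaultMatmulMaps = {{0, 2}, {2, 1}, {0, 1}};

// Stand-ins for the attribute names used in the pattern (assumed values).
const std::string kIndexingMaps = "indexing_maps";
const std::string kMemoized = "linalg.memoized_indexing_maps";

// Drop the unit batch loop d0 from every map and shift remaining dims down.
IndexingMaps dropBatchDim(const IndexingMaps &maps) {
  IndexingMaps result;
  for (const DimMap &map : maps) {
    DimMap reduced;
    for (int dim : map)
      if (dim != 0)
        reduced.push_back(dim - 1);
    result.push_back(reduced);
  }
  return result;
}

// Hypothetical attribute copy for the collapsed op: bail out unless the
// source maps are the defaults, then skip both the memoized maps and
// "indexing_maps" so the collapsed op falls back to its own defaults.
std::optional<Attrs> collapseAttrs(const Attrs &attrs,
                                   const IndexingMaps &maps) {
  if (maps != kDefaultBatchMatmulMaps)
    return std::nullopt; // non-default maps: the pattern does not apply
  Attrs collapsed;
  for (const auto &[name, value] : attrs) {
    if (name == kMemoized || name == kIndexingMaps)
      continue;
    collapsed[name] = value;
  }
  return collapsed;
}
```

With this approach no per-op-pair special case is needed: any source op whose maps are default simply omits `indexing_maps`, and ops carrying broadcast/transpose maps are rejected up front instead of having their maps rewritten.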

https://github.com/llvm/llvm-project/pull/122275