[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jan 21 02:04:15 PST 2025


================
@@ -937,7 +947,19 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
     for (auto attr : contractionOp->getAttrs()) {
       if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
         continue;
-      collapsedOp->setAttr(attr.getName(), attr.getValue());
+
+      // Update the indexing_maps attribute for the collapsed MatmulOp.
+      if (attr.getName() == "indexing_maps" &&
+          std::is_same<FromOpTy, BatchMatmulOp>::value &&
+          std::is_same<ToOpTy, MatmulOp>::value) {
+        SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
+            MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
+            [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+        collapsedOp->setAttr(attr.getName(),
+                             rewriter.getArrayAttr(indexingMapsAttr));
----------------
adam-smnk wrote:

Why is this special case needed at all when user-defined maps are disallowed?
If the pass breaks without this extra snippet, I think it's just treating the symptom instead of the root cause.
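The diff above rewrites `indexing_maps` when a `BatchMatmulOp` with a unit batch dimension is rank-reduced into a `MatmulOp`. The standalone sketch below illustrates the underlying map arithmetic. It uses no MLIR APIs, and `DimMap`, `dropUnitDim`, and `dropUnitDimFromAll` are hypothetical names for illustration only. Assuming only the default maps are permitted, dropping the unit batch dimension from `linalg.batch_matmul`'s defaults reproduces exactly `linalg.matmul`'s defaults. A verbatim copy of the batch maps, by contrast, keeps four loop dimensions, which a three-loop op cannot use.

```cpp
#include <vector>

// Hypothetical model (not MLIR): a projected-permutation affine map stored as
// its number of input dims plus the input dims its results select.
// For example, (d0, d1, d2) -> (d0, d2) is {3, {0, 2}}.
struct DimMap {
  int numDims;
  std::vector<int> results;
  bool operator==(const DimMap &o) const {
    return numDims == o.numDims && results == o.results;
  }
};

// Default maps of linalg.batch_matmul over (b, m, n, k) = (d0, d1, d2, d3):
// A: (b, m, k), B: (b, k, n), C: (b, m, n).
std::vector<DimMap> batchMatmulDefaultMaps() {
  return {{4, {0, 1, 3}}, {4, {0, 3, 2}}, {4, {0, 1, 2}}};
}

// Default maps of linalg.matmul over (m, n, k) = (d0, d1, d2):
// A: (m, k), B: (k, n), C: (m, n).
std::vector<DimMap> matmulDefaultMaps() {
  return {{3, {0, 2}}, {3, {2, 1}}, {3, {0, 1}}};
}

// Drop loop dimension `dim` (assumed to have unit extent): remove results
// that select it and renumber every higher dimension down by one.
DimMap dropUnitDim(const DimMap &map, int dim) {
  DimMap out{map.numDims - 1, {}};
  for (int d : map.results) {
    if (d == dim)
      continue;
    out.results.push_back(d > dim ? d - 1 : d);
  }
  return out;
}

// Apply dropUnitDim to every operand map of an op.
std::vector<DimMap> dropUnitDimFromAll(const std::vector<DimMap> &maps,
                                       int dim) {
  std::vector<DimMap> out;
  for (const DimMap &m : maps)
    out.push_back(dropUnitDim(m, dim));
  return out;
}
```

For example, `dropUnitDimFromAll(batchMatmulDefaultMaps(), 0)` compares equal to `matmulDefaultMaps()`, whereas `batchMatmulDefaultMaps()` itself does not, since every one of its maps still has `numDims == 4`.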

https://github.com/llvm/llvm-project/pull/122275

