[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Jan 21 02:04:15 PST 2025
================
@@ -937,7 +947,19 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
for (auto attr : contractionOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
continue;
- collapsedOp->setAttr(attr.getName(), attr.getValue());
+
+ // Update the indexing_maps attribute for the collapsed MatmulOp.
+ if (attr.getName() == "indexing_maps" &&
+ std::is_same<FromOpTy, BatchMatmulOp>::value &&
+ std::is_same<ToOpTy, MatmulOp>::value) {
+ SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
+ MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ collapsedOp->setAttr(attr.getName(),
+ rewriter.getArrayAttr(indexingMapsAttr));
----------------
adam-smnk wrote:
Why is this special case needed at all when user-defined maps are disallowed?
If the pass breaks without this extra snippet, I think it's just treating a symptom instead of the root cause.
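
To make the behavior in question concrete, here is a minimal, self-contained sketch of what the added branch in the hunk above appears to do. It is not the MLIR code: the op structs, the string-valued `AttrDict`, the `copyAttrsToCollapsedOp` helper, and the `kMemoizedName` key are hypothetical stand-ins chosen so the example compiles without MLIR. Only the control flow mirrors the diff.

```cpp
#include <map>
#include <string>
#include <type_traits>

// Hypothetical stand-ins for the linalg op classes (assumption, not MLIR types).
struct BatchMatmulOp {};
struct MatmulOp {
  // Stand-in for MatmulOp::getDefaultIndexingMaps(ctx) from the diff.
  static std::string getDefaultIndexingMaps() { return "matmul-default-maps"; }
};
struct BatchMatvecOp {};
struct MatvecOp {};

// Attributes modeled as name -> value strings purely for illustration.
using AttrDict = std::map<std::string, std::string>;

// Placeholder key standing in for LinalgDialect::kMemoizedIndexingMapsAttrName.
static const char *const kMemoizedName = "memoized_indexing_maps";

// Copies attributes from the original op to the collapsed op, mirroring the
// loop in the hunk: skip the memoized maps, reset indexing_maps only for the
// BatchMatmulOp -> MatmulOp pair, and copy everything else unchanged.
template <typename FromOpTy, typename ToOpTy>
AttrDict copyAttrsToCollapsedOp(const AttrDict &fromAttrs) {
  AttrDict collapsed;
  for (const auto &kv : fromAttrs) {
    if (kv.first == kMemoizedName)
      continue;
    if (kv.first == "indexing_maps" &&
        std::is_same<FromOpTy, BatchMatmulOp>::value &&
        std::is_same<ToOpTy, MatmulOp>::value) {
      collapsed[kv.first] = MatmulOp::getDefaultIndexingMaps();
      continue;
    }
    collapsed[kv.first] = kv.second;
  }
  return collapsed;
}
```

In this sketch, every op pair other than `BatchMatmulOp`/`MatmulOp` still receives the original op's `indexing_maps` value verbatim, so the branch only overrides the attribute for the one pair it names.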
https://github.com/llvm/llvm-project/pull/122275