[Mlir-commits] [mlir] [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (PR #95710)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jun 21 01:04:48 PDT 2024


================
@@ -833,4 +833,245 @@ struct LinalgFoldUnitExtentDimsPass
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
+
+} // namespace
+
+namespace {
+
+/// Builds the reassociation used to collapse away dimension `pos` of a value of
+/// rank `rank`. The result has `rank - 1` groups:
+///   - rank == 1: no groups;
+///   - rank == 2: the single group {0, 1};
+///   - rank  > 2: dimension `pos` is grouped with `pos + 1`, or with `pos - 1`
+///     when `pos` is the innermost dimension. Every other source dimension
+///     forms a group on its own.
+/// Callers are expected to pass a non-negative `pos` smaller than `rank`.
+static SmallVector<ReassociationIndices>
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+  if (rank <= 2)
+    return reassociation;
+
+  // The group that absorbs `pos`: the one starting at `pos`, except when `pos`
+  // is the innermost dimension, where it has no successor to merge with and
+  // joins the preceding group instead.
+  const bool isInnermost = pos == rank - 1;
+  const int64_t mergedGroup = isInnermost ? pos - 1 : pos;
+  for (int64_t group = 0, numGroups = rank - 1; group < numGroups; ++group) {
+    if (group == mergedGroup)
+      reassociation[group] = ReassociationIndices{group, group + 1};
+    else if (group < pos)
+      reassociation[group] = ReassociationIndices{group};
+    else
+      reassociation[group] = ReassociationIndices{group + 1};
+  }
+  return reassociation;
+}
+
+/// Drops dimension `pos` from the shaped value `val` by emitting a collapse
+/// through `collapseValue` with the ReassociativeReshape strategy.
+/// A negative `pos` means "nothing to collapse", and `val` is returned as-is.
+/// NOTE(review): as the name suggests, dimension `pos` is expected to have unit
+/// extent, but this function does not check it.
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+                                    int64_t pos) {
+  if (pos >= 0) {
+    auto shapedTy = cast<ShapedType>(val.getType());
+    SmallVector<int64_t> reducedShape = llvm::to_vector(shapedTy.getShape());
+    reducedShape.erase(reducedShape.begin() + pos);
+    SmallVector<ReassociationIndices> reassociation =
+        getReassociationForReshapeAtDim(shapedTy.getRank(), pos);
+    return collapseValue(
+        rewriter, val.getLoc(), val, reducedShape, reassociation,
+        ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+  }
+  return val;
+}
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+  using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+  SmallVector<Value, 3>
+  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+                   ArrayRef<int64_t> operandCollapseDims) const {
+    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+           "expected 3 operands and dims");
+    return llvm::to_vector(llvm::map_range(
----------------
ftynse wrote:

Nit: use `llvm::map_to_vector` instead of `llvm::to_vector(llvm::map_range(...))`.

Also note that `map_to_vector` (like `to_vector`) picks the number of inline (stack) elements of the resulting vector from the `sizeof` of the element type, which is 6 in this case. Because the function returns a `SmallVector<Value, 3>`, which has a different number of stack elements, the result must be copied into the return value, whereas RVO could otherwise have taken place. Avoid specifying an explicit number of stack elements in vectors unless there is a strong reason.

https://github.com/llvm/llvm-project/pull/95710


More information about the Mlir-commits mailing list