[Mlir-commits] [mlir] [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (PR #95710)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jun 21 01:04:48 PDT 2024


================
@@ -833,4 +833,245 @@ struct LinalgFoldUnitExtentDimsPass
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
+
+} // namespace
+
+namespace {
+
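+/// Returns the reassociation indices that collapse away a unit dimension
+/// at `pos` in a shape of rank `rank` by grouping it with a neighboring
+/// dimension, e.g. rank = 3, pos = 0 gives [[0, 1], [2]]; rank = 3,
+/// pos = 2 gives [[0], [1, 2]].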
+static SmallVector<ReassociationIndices>
+getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
+  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
+  bool lastDim = pos == rank - 1;
+  if (rank > 2) {
+    for (int64_t i = 0; i < rank - 1; i++) {
+      if (i == pos || (lastDim && i == pos - 1))
+        reassociation[i] = ReassociationIndices{i, i + 1};
+      else if (i < pos)
+        reassociation[i] = ReassociationIndices{i};
+      else
+        reassociation[i] = ReassociationIndices{i + 1};
+    }
+  }
+  return reassociation;
+}
+
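+/// Collapses the unit dimension of `val` at `pos` with a reassociative
+/// reshape; a negative `pos` is a sentinel for "no unit dim to drop" and
+/// returns `val` unchanged.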
+static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+                                    int64_t pos) {
+  if (pos < 0)
+    return val;
+  auto valType = cast<ShapedType>(val.getType());
+  SmallVector<int64_t> collapsedShape(valType.getShape());
+  collapsedShape.erase(collapsedShape.begin() + pos);
+  return collapseValue(
+      rewriter, val.getLoc(), val, collapsedShape,
+      getReassociationForReshapeAtDim(valType.getRank(), pos),
+      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+}
+
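+/// Base pattern that rewrites the named contraction op `FromOpTy` into the
+/// lower-rank op `ToOpTy` by collapsing one unit dimension from each
+/// operand and expanding the result back to the original type.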
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+  using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
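+  /// Collapse each operand at the unit dim given by `operandCollapseDims`
+  /// (in lhs, rhs, init order); a negative entry leaves the operand
+  /// untouched.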
+  SmallVector<Value, 3>
+  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+                   ArrayRef<int64_t> operandCollapseDims) const {
+    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+           "expected 3 operands and dims");
+    return llvm::to_vector(llvm::map_range(
+        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
+          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
+                                        std::get<1>(pair));
+        }));
+  }
+
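+  /// Expand the lower-rank `result` back to `expandedType` by reinserting
+  /// the unit dimension at `dim`.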
+  Value expandResult(PatternRewriter &rewriter, Value result,
+                     RankedTensorType expandedType, int64_t dim) const {
+    return rewriter.create<tensor::ExpandShapeOp>(
+        result.getLoc(), expandedType, result,
+        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
+  }
+
+  LogicalResult matchAndRewrite(FromOpTy contractionOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = contractionOp.getLoc();
+    auto inputs = contractionOp.getDpsInputs();
+    auto inits = contractionOp.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "expected 2 inputs and 1 init");
+    auto lhs = inputs[0];
+    auto rhs = inputs[1];
+    auto init = inits[0];
+    SmallVector<Value> operands{lhs, rhs, init};
+
+    auto maybeContractionDims = inferContractionDims(contractionOp);
+    if (failed(maybeContractionDims))
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "could not infer contraction dims");
+
+    SmallVector<int64_t> operandUnitDims;
+    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+      return rewriter.notifyMatchFailure(contractionOp,
+                                         "no reducible dims found");
+
+    auto collapsedOperands =
+        collapseOperands(rewriter, operands, operandUnitDims);
+    auto collapsedLhs = collapsedOperands[0];
+    auto collapsedRhs = collapsedOperands[1];
+    auto collapsedInit = collapsedOperands[2];
+    SmallVector<Type, 1> collapsedResultTy;
+    if (isa<RankedTensorType>(collapsedInit.getType()))
+      collapsedResultTy.push_back(collapsedInit.getType());
+    auto collapsedOp = rewriter.create<ToOpTy>(
+        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+        ValueRange{collapsedInit});
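+    // Propagate the original op's attributes, except the memoized indexing
+    // maps, which no longer apply to the lower-rank op.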
+    for (auto attr : contractionOp->getAttrs()) {
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+        continue;
+      collapsedOp->setAttr(attr.getName(), attr.getValue());
+    }
+
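+    // Ops with tensor semantics yield a result that must be expanded back
+    // to the original rank; ops with no results (buffer semantics) are
+    // simply replaced by the collapsed op.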
+    auto results = contractionOp.getResults();
+    assert(results.size() < 2 && "expected at most one result");
+    if (results.empty())
+      rewriter.replaceOp(contractionOp, collapsedOp);
+    else
+      rewriter.replaceOp(
+          contractionOp,
+          expandResult(rewriter, collapsedOp.getResultTensors()[0],
+                       cast<RankedTensorType>(results[0].getType()),
+                       operandUnitDims[2]));
+
+    return success();
+  }
+
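+  /// Populate `operandUnitDims` with the position of the unit dimension
+  /// to collapse in each of lhs, rhs and init, or fail if the op has no
+  /// reducible unit dim.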
+  virtual LogicalResult
+  getOperandUnitDims(LinalgOp op,
+                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
+};
+
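+/// Rewrites the batched variant of a contraction op to its unbatched
+/// counterpart when the batch dimension is a unit dim in all operands,
+/// e.g. `linalg.batch_matmul` with batch size 1 to `linalg.matmul`.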
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  LogicalResult getOperandUnitDims(
+      LinalgOp op,
+      SmallVectorImpl<int64_t> &operandUnitDims) const override {
+    auto inputs = op.getDpsInputs();
+    auto inits = op.getDpsInits();
+    if (inputs.size() != 2 || inits.size() != 1)
+      return failure();
+
+    auto maybeContractionDims = inferContractionDims(op);
+    if (failed(maybeContractionDims))
+      return failure();
+    auto contractionDims = maybeContractionDims.value();
+
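+    // Exactly one batch dim, present as a unit dim in all three operands,
+    // is required for the rewrite.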
+    if (contractionDims.batch.size() != 1)
+      return failure();
+    auto batchDim = contractionDims.batch[0];
+    SmallVector<std::pair<Value, unsigned>, 2> bOperands;
+    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
+    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
+          return cast<ShapedType>(std::get<0>(pair).getType())
+                     .getShape()[std::get<1>(pair)] != 1;
+        }))
+      return failure();
+
+    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
+                                           std::get<1>(bOperands[1]),
+                                           std::get<1>(bOperands[2])};
+    return success();
+  }
+};
+
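+/// Rewrites a matmul-like op to a lower-rank op (e.g. matvec/vecmat
+/// variants) when one of its non-batch parallel dimensions is a unit dim.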
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
+  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
+
+  static bool constexpr reduceLeft =
+      (std::is_same<FromOpTy, BatchMatmulOp>::value &&
----------------
ftynse wrote:

Nit: `std::is_same_v`
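
i.e. `std::is_same_v<FromOpTy, BatchMatmulOp>` instead of
`std::is_same<FromOpTy, BatchMatmulOp>::value` (C++17).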

https://github.com/llvm/llvm-project/pull/95710

