[Mlir-commits] [mlir] [mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (PR #95710)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri Jun 21 01:04:48 PDT 2024
================
@@ -833,4 +833,245 @@ struct LinalgFoldUnitExtentDimsPass
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
+
+} // namespace
+
+namespace {
+
+ /// Builds the reassociation that folds dimension `pos` of a rank-`rank`
+ /// shape into an adjacent dimension. Normally this is the dimension right
+ /// after `pos`; when `pos` is the last dimension it is the one before it.
+ /// Every other dimension keeps a single-element group, giving `rank - 1`
+ /// groups in total.
+ static SmallVector<ReassociationIndices>
+ getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
+   // With at most two dimensions, everything lands in one `{0, 1}` group.
+   if (rank <= 2)
+     return SmallVector<ReassociationIndices>(rank - 1, {0, 1});
+
+   // Index of the first dimension of the two-element group containing `pos`.
+   const int64_t mergedStart = (pos == rank - 1) ? pos - 1 : pos;
+
+   SmallVector<ReassociationIndices> groups;
+   groups.reserve(rank - 1);
+   for (int64_t group = 0; group < rank - 1; ++group) {
+     if (group == mergedStart)
+       groups.push_back(ReassociationIndices{group, group + 1});
+     else if (group < mergedStart)
+       groups.push_back(ReassociationIndices{group});
+     else
+       // Past the merged pair, source indices are shifted by one.
+       groups.push_back(ReassociationIndices{group + 1});
+   }
+   return groups;
+ }
+
+ /// Removes dimension `pos` of `val` by merging it into an adjacent
+ /// dimension with a reassociative reshape (see
+ /// getReassociationForReshapeAtDim). Returns `val` unchanged when `pos` is
+ /// negative.
+ ///
+ /// Precondition: `pos` indexes a dimension of static size 1. The collapsed
+ /// type is computed by simply erasing that entry from the shape. That is
+ /// only correct for a unit dimension; for any other size the reshape would
+ /// multiply extents. An out-of-range `pos` would make the erase below UB.
+ static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
+                                     int64_t pos) {
+   if (pos < 0)
+     return val;
+   auto valType = cast<ShapedType>(val.getType());
+   assert(pos < valType.getRank() && valType.getDimSize(pos) == 1 &&
+          "expected a static unit dimension at `pos`");
+   SmallVector<int64_t> collapsedShape(valType.getShape());
+   collapsedShape.erase(collapsedShape.begin() + pos);
+   return collapseValue(
+       rewriter, val.getLoc(), val, collapsedShape,
+       getReassociationForReshapeAtDim(valType.getRank(), pos),
+       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
+ }
+
+template <typename FromOpTy, typename ToOpTy>
+struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
+ using OpRewritePattern<FromOpTy>::OpRewritePattern;
+
+   /// Applies collapseSingletonDimAt to each of the three contraction operands
+   /// (lhs, rhs, init, in that order). Operand `k` is collapsed at position
+   /// `operandCollapseDims[k]`; a negative position leaves the operand as is.
+   SmallVector<Value, 3>
+   collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
+                    ArrayRef<int64_t> operandCollapseDims) const {
+     assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
+            "expected 3 operands and dims");
+     SmallVector<Value, 3> collapsed;
+     for (auto [operand, collapseDim] :
+          llvm::zip(operands, operandCollapseDims))
+       collapsed.push_back(
+           collapseSingletonDimAt(rewriter, operand, collapseDim));
+     return collapsed;
+   }
+
+   /// Re-inserts the dimension `dim` that was collapsed away, by reshaping
+   /// `result` back up to `expandedType` with a tensor.expand_shape. The same
+   /// reassociation helper is used as for the collapse, so the two reshapes
+   /// mirror each other.
+   Value expandResult(PatternRewriter &rewriter, Value result,
+                      RankedTensorType expandedType, int64_t dim) const {
+     SmallVector<ReassociationIndices> reassociation =
+         getReassociationForReshapeAtDim(expandedType.getRank(), dim);
+     auto expandOp = rewriter.create<tensor::ExpandShapeOp>(
+         result.getLoc(), expandedType, result, reassociation);
+     return expandOp;
+   }
+
+ LogicalResult matchAndRewrite(FromOpTy contractionOp,
+ PatternRewriter &rewriter) const override {
+
+ auto loc = contractionOp.getLoc();
+ auto inputs = contractionOp.getDpsInputs();
+ auto inits = contractionOp.getDpsInits();
+ if (inputs.size() != 2 || inits.size() != 1)
+ return rewriter.notifyMatchFailure(contractionOp,
+ "expected 2 inputs and 1 init");
+ auto lhs = inputs[0];
+ auto rhs = inputs[1];
+ auto init = inits[0];
+ SmallVector<Value> operands{lhs, rhs, init};
+
+ auto maybeContractionDims = inferContractionDims(contractionOp);
+ if (failed(maybeContractionDims))
+ return rewriter.notifyMatchFailure(contractionOp,
+ "could not infer contraction dims");
+
+ auto contractionDims = maybeContractionDims.value();
+ SmallVector<int64_t> operandUnitDims;
+ if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
+ return rewriter.notifyMatchFailure(contractionOp,
+ "no reducable dims found");
+
+ auto collapsedOperands =
+ collapseOperands(rewriter, operands, operandUnitDims);
+ auto collapsedLhs = collapsedOperands[0];
+ auto collapsedRhs = collapsedOperands[1];
+ auto collapsedInit = collapsedOperands[2];
+ SmallVector<Type, 1> collapsedResultTy;
+ if (isa<RankedTensorType>(collapsedInit.getType()))
+ collapsedResultTy.push_back(collapsedInit.getType());
+ auto collapsedOp = rewriter.create<ToOpTy>(
+ loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
+ ValueRange{collapsedInit});
+ for (auto attr : contractionOp->getAttrs()) {
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ continue;
+ collapsedOp->setAttr(attr.getName(), attr.getValue());
+ }
+
+ auto results = contractionOp.getResults();
+ assert(results.size() < 2 && "expected at most one result");
+ if (results.size() < 1)
+ rewriter.replaceOp(contractionOp, collapsedOp);
+ else
----------------
ftynse wrote:
Nit: use braces for multi-line if/else bodies.
https://github.com/llvm/llvm-project/pull/95710
More information about the Mlir-commits
mailing list