[Mlir-commits] [mlir] [mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (PR #72423)
Yinying Li
llvmlistbot at llvm.org
Wed Nov 15 10:54:31 PST 2023
================
@@ -411,6 +401,178 @@ struct GenericOpReinterpretMap
}
};
+struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+ PatternRewriter &rewriter) const override {
+ if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+ hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
+ !hasAnySparseOperandOrResult(linalgOp)) {
+ return failure();
+ }
+
+ const StringRef sorted = "sorted";
+ if (linalgOp->hasAttr(sorted))
+ return failure();
+
+ auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+ bool isAdmissible = false;
+ AffineMap order;
+ // A list of all the masks used for iteration graph computation,
+ // ordered from most strict to least strict. Ideally (though not
+ // guaranteed), the earlier a constraint mask can be satisfied, the
+ // faster the generated kernel will be.
+ const auto allMasks = {
+ SortMask::kIncludeAll, SortMask::kIncludeDense,
+ SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
+ SortMask::kIncludeUndef, SortMask::kSparseOnly};
+ for (const SortMask mask : allMasks) {
+ order = scheduler.schedule(mask);
+ if (order) {
+ if (isAdmissibleOrder(linalgOp, order)) {
+ isAdmissible = true;
+ break;
+ }
+ // Otherwise, try a less strict set of constraints.
+ }
+ }
+
+ if (!order) {
+ // Cycles detected.
+ if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "the sparse kernel can not be scheduled: loop detected.");
+ }
+ return success();
+ }
+
+ if (!isAdmissible) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "the sparse kernel can not be scheduled.");
+ }
+
+ // Marks the GenericOp to avoid recursive matching.
+ linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+
+ // Already sorted.
+ if (order.isIdentity())
+ return failure();
+
+ assert(order.isPermutation());
+ // `order` is the original loop -> sorted loop map.
+ ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
+ SmallVector<Attribute> curItTypes;
+ curItTypes.reserve(preItTypes.size());
+ for (AffineExpr expr : order.getResults()) {
+ unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
+ curItTypes.push_back(preItTypes[loopID]);
+ }
+
+ // Invert `order` to get the sorted loop -> original loop map.
+ order = inversePermutation(order);
+ SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
+ for (AffineMap &idxMap : idxMaps)
+ idxMap = idxMap.compose(order); // sorted loop -> lvl map
+
+ rewriter.startRootUpdate(linalgOp);
+ linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
+ linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
+ rewriter.finalizeRootUpdate(linalgOp);
+
+ return success();
+ }
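  // [Editor's illustration, not part of the patch] A sketch of the map algebra
  // above on a hypothetical three-loop kernel. Suppose the scheduler returns
  //   order      = (d0, d1, d2) -> (d2, d0, d1)   // original loop -> sorted loop
  // then
  //   order^-1   = (d0, d1, d2) -> (d1, d2, d0)   // sorted loop -> original loop
  // and an indexing map m = (d0, d1, d2) -> (d0, d2) (original loop -> lvl)
  // composes as
  //   m.compose(order^-1) = (d0, d1, d2) -> (d1, d0),
  // i.e. each level is now addressed through the loop's sorted position,
  // which is the form handed to setIndexingMapsAttr above.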
+
+private:
+ /// Returns true if the loop order is admissible for sparsification.
+ static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
+ if (!hasAnySparseResult(linalgOp))
+ return true;
+
+ OpOperand *lhs = linalgOp.getDpsInitOperand(0);
+ unsigned nest = 0;
+ const auto iteratorTypes = linalgOp.getIteratorTypesArray();
+ for (const AffineExpr l : order.getResults()) {
+ unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
+ const auto itTp = iteratorTypes[loopId];
+ if (linalg::isReductionIterator(itTp))
+ break; // terminate at first reduction
+ nest++;
+ }
+ // Determine admissible dynamic insertion situations:
+ // (1) fully injective, since there are no reductions,
+ // (2) admissible 1-d expansion in innermost dimension.
+ return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
+ }
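  // [Editor's illustration, not part of the patch] A sketch of the test above
  // on a hypothetical kernel C(i,j) += A(i,k) * B(k,j) with a sparse output C,
  // so rank(lhs) = 2 and the admissibility bound is rank - 1 = 1:
  //   loop order (i, j, k): the reduction k comes last, so nest = 2 >= 1
  //     and the order is accepted;
  //   loop order (k, i, j): the reduction comes first, so nest = 0 < 1
  //     and the order is rejected.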
+
+ // Last resort cycle resolution.
+ static LogicalResult resolveCycle(LoopScheduler &scheduler,
+ linalg::LinalgOp linalgOp,
+ PatternRewriter &rewriter) {
+ // Compute topological sort while leaving out every sparse input tensor in
+ // succession until an acyclic iteration graph results.
----------------
yinying-lisa-li wrote:
Do you mean resolves?
https://github.com/llvm/llvm-project/pull/72423