[Mlir-commits] [mlir] [mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (PR #72423)

Yinying Li llvmlistbot at llvm.org
Wed Nov 15 10:54:31 PST 2023


================
@@ -411,6 +401,178 @@ struct GenericOpReinterpretMap
   }
 };
 
+struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern::OpRewritePattern;
+  /// Computes a loop order for a sparse linalg.generic op and, if one is
+  /// admissible, permutes the op's loops (indexing maps and iterator types)
+  /// to that order. Ops already marked with the "sorted" attribute are
+  /// skipped so the pattern does not match its own result again. When no
+  /// order can be computed under the last tried mask, falls back to
+  /// `resolveCycle`.
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
+        !hasAnySparseOperandOrResult(linalgOp)) {
+      return failure();
+    }
+
+    const StringRef sorted = "sorted";
+    if (linalgOp->hasAttr(sorted))
+      return failure();
+
+    auto scheduler = LoopScheduler::fromGenericOp(linalgOp);
+    bool isAdmissible = false;
+    AffineMap order;
+    // A const list of all masks that we used for iteration graph
+    // computation. Must be ordered from more strict to less strict.
+    // Ideally (though might not be guaranteed), the earlier a constraint mask
+    // can be satisfied, the faster the generated kernel will be.
+    const auto allMasks = {
+        SortMask::kIncludeAll,        SortMask::kIncludeDense,
+        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
+        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
+    for (const SortMask mask : allMasks) {
+      order = scheduler.schedule(mask);
+      if (order) {
+        if (isAdmissibleOrder(linalgOp, order)) {
+          isAdmissible = true;
+          break;
+        }
+        // else try a set of less strict constraints.
+      }
+    }
+
+    // NOTE: after the loop, `order` holds the result of the last mask tried
+    // (the admissible one if `isAdmissible`, otherwise the final mask's).
+    if (!order) {
+      // Cycles detected.
+      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
+        return rewriter.notifyMatchFailure(
+            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
+      }
+      return success();
+    }
+
+    if (!isAdmissible) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "the sparse kernel can not be scheduled.");
+    }
+
+    // Marks the GenericOp to avoid recursive matching. The in-place change is
+    // made through the rewriter so the pattern driver is notified of it.
+    rewriter.startRootUpdate(linalgOp);
+    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
+    rewriter.finalizeRootUpdate(linalgOp);
+
+    // Already sorted. The op has been modified (marked) above, so report
+    // success: a pattern must not mutate the IR and then return failure.
+    if (order.isIdentity())
+      return success();
+
+    assert(order.isPermutation());
+    // `order` is the original loop -> sorted loop map; permute the iterator
+    // types into sorted-loop order.
+    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
+    SmallVector<Attribute> curItTypes;
+    curItTypes.reserve(preItTypes.size());
+    for (AffineExpr expr : order.getResults()) {
+      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
+      curItTypes.push_back(preItTypes[loopID]);
+    }
+
+    // Inverse `order` to get sorted loop -> original loop map
+    order = inversePermutation(order);
+    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
+    for (AffineMap &idxMap : idxMaps)
+      idxMap = idxMap.compose(order); // sorted loop -> lvl map
+
+    rewriter.startRootUpdate(linalgOp);
+    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
+    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
+    rewriter.finalizeRootUpdate(linalgOp);
+
+    return success();
+  }
+
+private:
+  /// Whether the loop order is admissible by sparsification.
+  ///
+  /// Always admissible when the op has no sparse result. Otherwise, counts
+  /// the leading non-reduction loops of `order` (stopping at the first
+  /// reduction loop). The order is admissible when that count is at least the
+  /// rank of the (single) output operand minus one.
+  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
+    if (!hasAnySparseResult(linalgOp))
+      return true;
+
+    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
+    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
+    unsigned nest = 0;
+    for (const AffineExpr l : order.getResults()) {
+      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
+      if (linalg::isReductionIterator(iteratorTypes[loopId]))
+        break; // terminate at first reduction
+      nest++;
+    }
+    // Determine admissible dynamic insertion situations:
+    // (1) fully injective, since there are no reductions,
+    // (2) admissible 1-d expansion in innermost dimension.
+    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
+  }
+
+  // Last resort cycle resolution.
+  static LogicalResult resolveCycle(LoopScheduler &scheduler,
+                                    linalg::LinalgOp linalgOp,
+                                    PatternRewriter &rewriter) {
+    // Compute topological sort while leaving out every sparse input tensor in
+    // succession until an acylic iteration graph results.
----------------
yinying-lisa-li wrote:

Do you mean resolves?

https://github.com/llvm/llvm-project/pull/72423


More information about the Mlir-commits mailing list