[Mlir-commits] [mlir] 95e0ae9 - [MLIR][SparseTensor] Loop ordering strategy infrastructure (flag) (#154656)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 6 10:29:53 PDT 2025


Author: Govind Malasani
Date: 2025-10-06T17:29:48Z
New Revision: 95e0ae9fa7f3bfbfe3dc587428a064cbb8deb3ca

URL: https://github.com/llvm/llvm-project/commit/95e0ae9fa7f3bfbfe3dc587428a064cbb8deb3ca
DIFF: https://github.com/llvm/llvm-project/commit/95e0ae9fa7f3bfbfe3dc587428a064cbb8deb3ca.diff

LOG: [MLIR][SparseTensor] Loop ordering strategy infrastructure (flag) (#154656)

As discussed before, this PR adds the basic infrastructure/boilerplate
for loop ordering strategies to be implemented.

If this looks ok, I wanted to also mention some of the heuristics that I
would implement next, if they sound reasonable to you guys:
- Parallel-first: prioritize parallel loops over reduction loops
- Dense-outer: prioritize the densest loops first
- Sparse-outer: the opposite, potentially useful in some cases?

There is another heuristic that I am considering, stride/memory-aware,
which would prioritize loops with better stride patterns (like sequential
or linear). I am not sure how well this carries over to sparse tensors,
though. Are there any ideas/heuristics that I should definitely try to
implement?

As we discussed, I will try to incrementally add heuristics. Sorry for
the delay on my end, and thank you so much for the feedback!

---------

Co-authored-by: Aart Bik <ajcbik at google.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 212f7b6f13c26..af64370a62dd7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -55,6 +55,16 @@ enum class SparseEmitStrategy {
   kDebugInterface, // generate only place-holder for sparse iteration
 };
 
+namespace sparse_tensor {
+
+/// Defines a strategy for loop ordering during sparse code generation.
+enum class LoopOrderingStrategy : unsigned {
+  kDefault, ///< Default strategy (eagerly selects last loop in topological
+            ///< sort).
+};
+
+} // namespace sparse_tensor
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
@@ -71,11 +81,16 @@ std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                  ReinterpretMapScope scope);
+void populateSparseReinterpretMap(
+    RewritePatternSet &patterns, ReinterpretMapScope scope,
+    sparse_tensor::LoopOrderingStrategy strategy =
+        sparse_tensor::LoopOrderingStrategy::kDefault);
 
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
 std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
+std::unique_ptr<Pass>
+createSparseReinterpretMapPass(ReinterpretMapScope scope,
+                               sparse_tensor::LoopOrderingStrategy strategy);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2513e106f5b06..75e77d67db1b3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -81,6 +81,11 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
          clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
                     "except-generic",
                     "Run on operations expect linalg.generic (e.g., foreach)"))}]>,
+    Option<"loopOrderingStrategy", "loop-ordering-strategy", "mlir::sparse_tensor::LoopOrderingStrategy",
+       "mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
+       "Set the loop ordering strategy for sparse code generation", [{llvm::cl::values(
+         clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
+                    "Default strategy (eagerly selects last loop in topological sort)"))}]>,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a1e35b87399ca..0fc5cc76de39c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -59,7 +59,7 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
 
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
-  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
   void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
   BitVector dims;
 };
@@ -67,7 +67,7 @@ struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineExprAdmissibleVisitor
     : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
-  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
+  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
 
   // We only allow AffineDimExpr on output.
   void visitAddExpr(AffineBinaryOpExpr expr) {
@@ -407,7 +407,10 @@ struct GenericOpReinterpretMap
 };
 
 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
-  using OpRewritePattern::OpRewritePattern;
+  GenericOpScheduler(MLIRContext *context,
+                     sparse_tensor::LoopOrderingStrategy strategy)
+      : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
+
   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -420,7 +423,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     if (linalgOp->hasAttr(sorted))
       return failure();
 
-    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
+    // Pass strategy to IterationGraphSorter.
+    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
     bool isAdmissible = false;
     AffineMap order;
     // A const list of all masks that we used for iteration graph
@@ -582,6 +586,9 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     // TODO: convert more than one?
     return failure();
   }
+
+private:
+  sparse_tensor::LoopOrderingStrategy strategy;
 };
 
 //===----------------------------------------------------------------------===//
@@ -786,12 +793,13 @@ struct ForeachOpDemapper
 
 } // namespace
 
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                        ReinterpretMapScope scope) {
+void mlir::populateSparseReinterpretMap(
+    RewritePatternSet &patterns, ReinterpretMapScope scope,
+    sparse_tensor::LoopOrderingStrategy strategy) {
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kGenericOnly) {
-    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
-        patterns.getContext());
+    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+    patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 153b9b170e5d3..b660e22154688 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -67,12 +67,13 @@ struct SparseReinterpretMap
   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
   SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
     scope = options.scope;
+    loopOrderingStrategy = options.loopOrderingStrategy;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseReinterpretMap(patterns, scope);
+    populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -438,6 +439,14 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
   return std::make_unique<SparseReinterpretMap>(options);
 }
 
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
+    ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
+  SparseReinterpretMapOptions options;
+  options.scope = scope;
+  options.loopOrderingStrategy = strategy;
+  return std::make_unique<SparseReinterpretMap>(options);
+}
+
 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
   return std::make_unique<PreSparsificationRewritePass>();
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index c7e463a5a5b49..73e0f3d2891d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() {
     // We always prefer a parallel loop over a reduction loop because putting
     // a reduction loop early might make the loop sequence inadmissible.
     auto &it = !parIt.empty() ? parIt : redIt;
-    auto src = it.back();
+
+    // Select loop based on strategy.
+    unsigned src;
+    switch (strategy) {
+    case sparse_tensor::LoopOrderingStrategy::kDefault:
+      src = it.back();
+      break;
+    }
+
     loopOrder.push_back(src);
     it.pop_back();
     // Update in-degree, and push 0-degree node into worklist.
@@ -122,8 +130,8 @@ AffineMap IterationGraphSorter::topoSort() {
   return AffineMap();
 }
 
-IterationGraphSorter
-IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
+IterationGraphSorter IterationGraphSorter::fromGenericOp(
+    linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) {
   // Must be a demapped sparse kernel.
   assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
          hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +148,16 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
       genericOp.getIteratorTypesArray();
 
   return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
-                              std::move(iterTypes));
+                              std::move(iterTypes), strategy);
 }
 
 IterationGraphSorter::IterationGraphSorter(
     SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
-    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
+    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+    sparse_tensor::LoopOrderingStrategy strategy)
     : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
-      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+      strategy(strategy) {
   // One map per tensor.
   assert(loop2InsLvl.size() == ins.size());
   // All the affine maps have the same number of dimensions (loops).

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index a6abe9eb76c47..b2a16e9382758 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/IR/AffineMap.h"
 
 namespace mlir {
@@ -41,9 +42,12 @@ enum class SortMask : unsigned {
 
 class IterationGraphSorter {
 public:
-  /// Factory method that construct an iteration graph sorter
-  /// for the given linalg.generic operation.
-  static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
+  /// Factory method that constructs an iteration graph sorter
+  /// for the given linalg.generic operation with a specific loop ordering
+  /// strategy.
+  static IterationGraphSorter
+  fromGenericOp(linalg::GenericOp genericOp,
+                sparse_tensor::LoopOrderingStrategy strategy);
 
   /// Returns a permutation that represents the scheduled loop order.
   /// Note that the returned AffineMap could be null if the kernel
@@ -58,7 +62,9 @@ class IterationGraphSorter {
   IterationGraphSorter(SmallVector<Value> &&ins,
                        SmallVector<AffineMap> &&loop2InsLvl, Value out,
                        AffineMap loop2OutLvl,
-                       SmallVector<utils::IteratorType> &&iterTypes);
+                       SmallVector<utils::IteratorType> &&iterTypes,
+                       sparse_tensor::LoopOrderingStrategy strategy =
+                           sparse_tensor::LoopOrderingStrategy::kDefault);
 
   // Adds all the constraints in the given loop to level map.
   void addConstraints(Value t, AffineMap loop2LvlMap);
@@ -84,6 +90,9 @@ class IterationGraphSorter {
 
   // InDegree used for topo sort.
   std::vector<unsigned> inDegree;
+
+  // Loop ordering strategy.
+  sparse_tensor::LoopOrderingStrategy strategy;
 };
 
 } // namespace sparse_tensor


        


More information about the Mlir-commits mailing list