[Mlir-commits] [mlir] [MLIR][SparseTensor] Dense Outer Loop Ordering Strategy (PR #160168)

Govind Malasani llvmlistbot at llvm.org
Thu Oct 2 16:28:01 PDT 2025


https://github.com/gmalasan updated https://github.com/llvm/llvm-project/pull/160168

From ed2c5d11e53df3fe79611c741ca45c1d4609f7c9 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Wed, 20 Aug 2025 18:07:07 -0700
Subject: [PATCH 1/7] [MLIR][SparseTensor] Loop ordering strategy
 infrastructure (flag)

---
 .../mlir/Dialect/SparseTensor/Transforms/Passes.h  | 14 +++++++++++++-
 .../mlir/Dialect/SparseTensor/Transforms/Passes.td |  5 +++++
 .../Transforms/SparseReinterpretMap.cpp            |  5 ++++-
 .../SparseTensor/Transforms/SparseTensorPasses.cpp | 12 +++++++++++-
 4 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 212f7b6f13c26..24f30353e58b5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -55,6 +55,15 @@ enum class SparseEmitStrategy {
   kDebugInterface, // generate only place-holder for sparse iteration
 };
 
+namespace sparse_tensor {
+
+/// Selects between different loop ordering strategies for sparse tensor 
+enum class LoopOrderingStrategy : unsigned {
+  kDefault, ///< Default strategy (current behavior)
+};
+
+} // namespace sparse_tensor
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
@@ -72,10 +81,13 @@ std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 //===----------------------------------------------------------------------===//
 
 void populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                  ReinterpretMapScope scope);
+                                  ReinterpretMapScope scope,
+                                  sparse_tensor::LoopOrderingStrategy strategy = sparse_tensor::LoopOrderingStrategy::kDefault);
 
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
 std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
+std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope, 
+                                                     sparse_tensor::LoopOrderingStrategy strategy);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2513e106f5b06..1b0c7f7583827 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -81,6 +81,11 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
          clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
                     "except-generic",
                     "Run on operations expect linalg.generic (e.g., foreach)"))}]>,
+    Option<"loopOrderingStrategy", "loop-ordering-strategy", "mlir::sparse_tensor::LoopOrderingStrategy",
+       "mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
+       "Set the loop ordering strategy for sparse tensor dialect", [{llvm::cl::values(
+         clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
+                    "Default strategy (current behavior)"))}]>,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a1e35b87399ca..5dbd3e976e52c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -787,7 +787,10 @@ struct ForeachOpDemapper
 } // namespace
 
 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                        ReinterpretMapScope scope) {
+                                        ReinterpretMapScope scope,
+                                        sparse_tensor::LoopOrderingStrategy strategy) {
+  (void)strategy; // Suppress unused parameter warning
+  
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kGenericOnly) {
     patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 153b9b170e5d3..aa31927e0602c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -67,12 +67,13 @@ struct SparseReinterpretMap
   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
   SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
     scope = options.scope;
+    loopOrderingStrategy = options.loopOrderingStrategy;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseReinterpretMap(patterns, scope);
+    populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -438,6 +439,15 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
   return std::make_unique<SparseReinterpretMap>(options);
 }
 
+std::unique_ptr<Pass>
+mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope, 
+                                     sparse_tensor::LoopOrderingStrategy strategy) {
+  SparseReinterpretMapOptions options;
+  options.scope = scope;
+  options.loopOrderingStrategy = strategy;
+  return std::make_unique<SparseReinterpretMap>(options);
+}
+
 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
   return std::make_unique<PreSparsificationRewritePass>();
 }
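
As a usage sketch, a pipeline could opt into the new overload along these
lines once this patch lands (the PassManager setup around the call is
hypothetical; only the pass-creation call and its argument types come from
this patch):

  // Hypothetical driver; needs "mlir/Pass/PassManager.h" and the
  // SparseTensor Passes.h header. Only the createSparseReinterpretMapPass
  // overload below is introduced by this patch.
  mlir::MLIRContext context;
  mlir::PassManager pm(&context);
  pm.addPass(mlir::createSparseReinterpretMapPass(
      mlir::ReinterpretMapScope::kGenericOnly,
      mlir::sparse_tensor::LoopOrderingStrategy::kDefault));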

From d6c542174ea491f54c08743f7cc16d3b67d22ff4 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Wed, 20 Aug 2025 18:36:51 -0700
Subject: [PATCH 2/7] [MLIR][SparseTensor] Finish the remaining boilerplate;
 the strategy is now fully available in the topoSort method of IterationGraphSorter

---
 .../Transforms/SparseReinterpretMap.cpp       | 16 +++++++++------
 .../Transforms/Utils/IterationGraphSorter.cpp | 20 ++++++++++++++-----
 .../Transforms/Utils/IterationGraphSorter.h   | 12 ++++++++---
 3 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 5dbd3e976e52c..748d26a0de3b6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -407,7 +407,9 @@ struct GenericOpReinterpretMap
 };
 
 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
-  using OpRewritePattern::OpRewritePattern;
+  GenericOpScheduler(MLIRContext *context, sparse_tensor::LoopOrderingStrategy strategy)
+      : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
+  
   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -420,7 +422,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     if (linalgOp->hasAttr(sorted))
       return failure();
 
-    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
+    // Pass strategy to IterationGraphSorter
+    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
     bool isAdmissible = false;
     AffineMap order;
     // A const list of all masks that we used for iteration graph
@@ -582,6 +585,9 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     // TODO: convert more than one?
     return failure();
   }
+
+private:
+  sparse_tensor::LoopOrderingStrategy strategy;
 };
 
 //===----------------------------------------------------------------------===//
@@ -789,12 +795,10 @@ struct ForeachOpDemapper
 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                         ReinterpretMapScope scope,
                                         sparse_tensor::LoopOrderingStrategy strategy) {
-  (void)strategy; // Suppress unused parameter warning
-  
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kGenericOnly) {
-    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
-        patterns.getContext());
+    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+    patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index c7e463a5a5b49..6cebb4bcdc3c3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() {
     // We always prefer a parallel loop over a reduction loop because putting
     // a reduction loop early might make the loop sequence inadmissible.
     auto &it = !parIt.empty() ? parIt : redIt;
-    auto src = it.back();
+    
+    // Select loop based on strategy
+    unsigned src;
+    switch (strategy) {
+    case sparse_tensor::LoopOrderingStrategy::kDefault:
+      src = it.back();
+      break;
+    }
+    
     loopOrder.push_back(src);
     it.pop_back();
     // Update in-degree, and push 0-degree node into worklist.
@@ -123,7 +131,8 @@ AffineMap IterationGraphSorter::topoSort() {
 }
 
 IterationGraphSorter
-IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
+IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp, 
+                                   sparse_tensor::LoopOrderingStrategy strategy) {
   // Must be a demapped sparse kernel.
   assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
          hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +149,15 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
       genericOp.getIteratorTypesArray();
 
   return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
-                              std::move(iterTypes));
+                              std::move(iterTypes), strategy);
 }
 
 IterationGraphSorter::IterationGraphSorter(
     SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
-    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
+    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+    sparse_tensor::LoopOrderingStrategy strategy)
     : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
-      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), strategy(strategy) {
   // One map per tensor.
   assert(loop2InsLvl.size() == ins.size());
   // All the affine maps have the same number of dimensions (loops).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index a6abe9eb76c47..a4902d31e0077 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 
 #include "mlir/IR/AffineMap.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 
 namespace mlir {
 
@@ -42,8 +43,9 @@ enum class SortMask : unsigned {
 class IterationGraphSorter {
 public:
   /// Factory method that construct an iteration graph sorter
-  /// for the given linalg.generic operation.
-  static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
+  /// for the given linalg.generic operation with a specific strategy.
+  static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp,
+                                           sparse_tensor::LoopOrderingStrategy strategy);
 
   /// Returns a permutation that represents the scheduled loop order.
   /// Note that the returned AffineMap could be null if the kernel
@@ -58,7 +60,8 @@ class IterationGraphSorter {
   IterationGraphSorter(SmallVector<Value> &&ins,
                        SmallVector<AffineMap> &&loop2InsLvl, Value out,
                        AffineMap loop2OutLvl,
-                       SmallVector<utils::IteratorType> &&iterTypes);
+                       SmallVector<utils::IteratorType> &&iterTypes,
+                       sparse_tensor::LoopOrderingStrategy strategy = sparse_tensor::LoopOrderingStrategy::kDefault);
 
   // Adds all the constraints in the given loop to level map.
   void addConstraints(Value t, AffineMap loop2LvlMap);
@@ -84,6 +87,9 @@ class IterationGraphSorter {
 
   // InDegree used for topo sort.
   std::vector<unsigned> inDegree;
+
+  // Loop ordering strategy.
+  sparse_tensor::LoopOrderingStrategy strategy;
 };
 
 } // namespace sparse_tensor
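
With this patch the strategy threads end to end: the pass option feeds
GenericOpScheduler, which forwards it to IterationGraphSorter::fromGenericOp,
and topoSort switches on it when picking the next loop from the worklist.
For illustration, a new strategy would slot into that switch roughly as
follows (kSomeNewStrategy is a made-up placeholder, not part of this PR):

  switch (strategy) {
  case sparse_tensor::LoopOrderingStrategy::kDefault:
    src = it.back();
    break;
  case sparse_tensor::LoopOrderingStrategy::kSomeNewStrategy:
    // Hypothetical: prefer the oldest ready loop instead of the newest.
    src = it.front();
    break;
  }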

From bb15061aade4c7b3cfa057cfc4111650c95ac476 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Tue, 26 Aug 2025 16:22:54 -0400
Subject: [PATCH 3/7] [MLIR][SparseTensor] Addressed PR feedback about style

---
 mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h    | 4 ++--
 mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td   | 4 ++--
 .../Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp  | 2 +-
 .../SparseTensor/Transforms/Utils/IterationGraphSorter.cpp    | 2 +-
 .../SparseTensor/Transforms/Utils/IterationGraphSorter.h      | 2 +-
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 24f30353e58b5..93ff5ce606423 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -57,9 +57,9 @@ enum class SparseEmitStrategy {
 
 namespace sparse_tensor {
 
-/// Selects between different loop ordering strategies for sparse tensor 
+/// Defines a strategy for loop ordering during sparse code generation.
 enum class LoopOrderingStrategy : unsigned {
-  kDefault, ///< Default strategy (current behavior)
+  kDefault, ///< Default strategy (selects loops from back of available candidates).
 };
 
 } // namespace sparse_tensor
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 1b0c7f7583827..91bc4ae8a6bd2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -83,9 +83,9 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
                     "Run on operations expect linalg.generic (e.g., foreach)"))}]>,
     Option<"loopOrderingStrategy", "loop-ordering-strategy", "mlir::sparse_tensor::LoopOrderingStrategy",
        "mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
-       "Set the loop ordering strategy for sparse tensor dialect", [{llvm::cl::values(
+       "Set the loop ordering strategy for sparse code generation", [{llvm::cl::values(
          clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
-                    "Default strategy (current behavior)"))}]>,
+                    "Default strategy (selects loops from back of available candidates)"))}]>,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 748d26a0de3b6..e068eee0587a0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -422,7 +422,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     if (linalgOp->hasAttr(sorted))
       return failure();
 
-    // Pass strategy to IterationGraphSorter
+    // Pass strategy to IterationGraphSorter.
     auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
     bool isAdmissible = false;
     AffineMap order;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 6cebb4bcdc3c3..29596e5febe1a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -101,7 +101,7 @@ AffineMap IterationGraphSorter::topoSort() {
     // a reduction loop early might make the loop sequence inadmissible.
     auto &it = !parIt.empty() ? parIt : redIt;
     
-    // Select loop based on strategy
+    // Select loop based on strategy.
     unsigned src;
     switch (strategy) {
     case sparse_tensor::LoopOrderingStrategy::kDefault:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index a4902d31e0077..22c186d563c4a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -43,7 +43,7 @@ enum class SortMask : unsigned {
 class IterationGraphSorter {
 public:
   /// Factory method that construct an iteration graph sorter
-  /// for the given linalg.generic operation with a specific strategy.
+  /// for the given linalg.generic operation with a specific loop ordering strategy.
   static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp,
                                            sparse_tensor::LoopOrderingStrategy strategy);
 

From 89242945f96ac63661a877d28ceacba64a62f698 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Mon, 8 Sep 2025 18:31:08 -0400
Subject: [PATCH 4/7] [MLIR][SparseTensor] Comment style fixes

---
 mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td     | 2 +-
 .../SparseTensor/Transforms/Utils/IterationGraphSorter.h        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 91bc4ae8a6bd2..75e77d67db1b3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -85,7 +85,7 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
        "mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
        "Set the loop ordering strategy for sparse code generation", [{llvm::cl::values(
          clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
-                    "Default strategy (selects loops from back of available candidates)"))}]>,
+                    "Default strategy (eagerly selects last loop in topological sort)"))}]>,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index 22c186d563c4a..868a978e9a462 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -42,7 +42,7 @@ enum class SortMask : unsigned {
 
 class IterationGraphSorter {
 public:
-  /// Factory method that construct an iteration graph sorter
+  /// Factory method that constructs an iteration graph sorter
   /// for the given linalg.generic operation with a specific loop ordering strategy.
   static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp,
                                            sparse_tensor::LoopOrderingStrategy strategy);

From 42c1907549cda1ee9c93cb8b02d5684fc918beb1 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Tue, 9 Sep 2025 17:28:31 -0400
Subject: [PATCH 5/7] [MLIR][SparseTensor] Missed comment style fix in Passes.h

---
 mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 93ff5ce606423..9a966a4ddb8b6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -59,7 +59,7 @@ namespace sparse_tensor {
 
 /// Defines a strategy for loop ordering during sparse code generation.
 enum class LoopOrderingStrategy : unsigned {
-  kDefault, ///< Default strategy (selects loops from back of available candidates).
+  kDefault, ///< Default strategy (eagerly selects last loop in topological sort).
 };
 
 } // namespace sparse_tensor

From 7e9c33fee1e6afe8680ea446f34b6aacb3b61832 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Sun, 21 Sep 2025 21:45:57 -0400
Subject: [PATCH 6/7] [MLIR][SparseTensor] Implemented denseOuter heuristic;
 prefer dense loops over sparse/singleton loops

---
 .../Dialect/SparseTensor/Transforms/Passes.h  |  1 +
 .../Dialect/SparseTensor/Transforms/Passes.td |  4 +-
 .../Transforms/Utils/IterationGraphSorter.cpp | 72 ++++++++++++++++++-
 3 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 9a966a4ddb8b6..667c73e8f3a48 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -60,6 +60,7 @@ namespace sparse_tensor {
 /// Defines a strategy for loop ordering during sparse code generation.
 enum class LoopOrderingStrategy : unsigned {
   kDefault, ///< Default strategy (eagerly selects last loop in topological sort).
+  kDenseOuter, ///< Prefer dense, then sparse, then singleton dimensions outermost.
 };
 
 } // namespace sparse_tensor
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 75e77d67db1b3..0b8562e484f51 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -85,7 +85,9 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
        "mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
        "Set the loop ordering strategy for sparse code generation", [{llvm::cl::values(
          clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
-                    "Default strategy (eagerly selects last loop in topological sort)"))}]>,
+                    "Default strategy (eagerly selects last loop in topological sort)"),
+         clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDenseOuter, "dense-outer",
+                    "Prefer dense, then compressed, then singleton dimensions outermost"))}]>,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 29596e5febe1a..c351cd5792136 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -80,6 +80,53 @@ inline static bool includesDenseOutput(SortMask mask) {
   return includesAny(mask, SortMask::kIncludeDenseOutput);
 }
 
+/// Returns a sparsity rank for loop ordering: lower values indicate
+/// dimensions that should be placed in outer loops.
+/// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown
+static unsigned getLoopSparsityRank(unsigned loop,
+                                   ArrayRef<Value> allTensors,
+                                   ArrayRef<AffineMap> allMaps) {
+  unsigned bestRank = 3; // Default: most sparse (unknown/singleton-like)
+  
+  for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) {
+    // Check if this loop accesses this tensor
+    bool loopAccessesTensor = false;
+    unsigned tensorDim = 0;
+    for (AffineExpr expr : map.getResults()) {
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+        if (dimExpr.getPosition() == loop) {
+          loopAccessesTensor = true;
+          break;
+        }
+      }
+      tensorDim++;
+    }
+    
+    if (loopAccessesTensor) {
+      const auto enc = getSparseTensorEncoding(tensor.getType());
+      if (!enc) {
+        // Dense tensor - highest priority
+        return 0;
+      } else {
+        // Sparse tensor - check the level type for this dimension
+        auto lvlTypes = enc.getLvlTypes();
+        if (tensorDim < lvlTypes.size()) {
+          auto lvlType = lvlTypes[tensorDim];
+          if (isDenseLT(lvlType)) {
+            return 0; // Dense level
+          } else if (isCompressedLT(lvlType)) {
+            bestRank = std::min(bestRank, 1u); // Compressed level
+          } else if (isSingletonLT(lvlType)) {
+            bestRank = std::min(bestRank, 2u); // Singleton level
+          }
+        }
+      }
+    }
+  }
+  
+  return bestRank;
+}
+
 AffineMap IterationGraphSorter::topoSort() {
   // The sorted result will put the first Reduction iterator to the
   // latest possible position.
@@ -107,10 +154,33 @@ AffineMap IterationGraphSorter::topoSort() {
     case sparse_tensor::LoopOrderingStrategy::kDefault:
       src = it.back();
       break;
+    case sparse_tensor::LoopOrderingStrategy::kDenseOuter: {
+      // Prefer dense, then compressed, then singleton dimensions outermost.
+      // Create combined tensor and map lists for analysis.
+      SmallVector<Value> allTensors = ins;
+      allTensors.push_back(out);
+      SmallVector<AffineMap> allMaps = loop2InsLvl;
+      allMaps.push_back(loop2OutLvl);
+      
+      // Find loop with best (lowest) sparsity rank.
+      unsigned bestLoop = it[0];
+      unsigned bestRank = getLoopSparsityRank(bestLoop, allTensors, allMaps);
+      
+      for (auto candidateLoop : it) {
+        unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps);
+        if (rank < bestRank || (rank == bestRank && candidateLoop < bestLoop)) {
+          bestLoop = candidateLoop;
+          bestRank = rank;
+        }
+      }
+      src = bestLoop;
+      break;
+    }
     }
     
     loopOrder.push_back(src);
-    it.pop_back();
+    // Remove the selected loop from the worklist.
+    it.erase(std::find(it.begin(), it.end(), src));
     // Update in-degree, and push 0-degree node into worklist.
     for (unsigned dst = 0; dst < numLoops; dst++) {
       if (itGraph[src][dst] && --inDegree[dst] == 0) {
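
Given the Passes.td entry above, the new heuristic should be selectable per
invocation roughly as follows (the input file name is a placeholder):

  mlir-opt kernel.mlir --sparse-reinterpret-map="loop-ordering-strategy=dense-outer"

As a worked example of getLoopSparsityRank: for a kernel over a single CSR
matrix (dense outer level, compressed inner level), the loop mapped to the
dense level gets rank 0 and the loop mapped to the compressed level gets
rank 1, so dense-outer schedules the dense-level loop outermost; ties are
broken toward the smaller loop index. Note that a loop indexing any
un-encoded (all-dense) tensor returns rank 0 immediately.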

From b5a1a11f4c5aaa0dffd5f0114a08e70bc60fcef3 Mon Sep 17 00:00:00 2001
From: gmalasan <145235389+gmalasan at users.noreply.github.com>
Date: Thu, 2 Oct 2025 18:35:35 -0400
Subject: [PATCH 7/7] [MLIR][SparseTensor] Format and rebase branch after
 merging infrastructure branch

---
 .../Dialect/SparseTensor/Transforms/Passes.h  | 18 ++++++++-----
 .../Transforms/SparseReinterpretMap.cpp       | 15 ++++++-----
 .../Transforms/SparseTensorPasses.cpp         |  5 ++--
 .../Transforms/Utils/IterationGraphSorter.cpp | 27 +++++++++----------
 .../Transforms/Utils/IterationGraphSorter.h   | 13 +++++----
 5 files changed, 42 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 667c73e8f3a48..b401458f5ddc5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -59,8 +59,10 @@ namespace sparse_tensor {
 
 /// Defines a strategy for loop ordering during sparse code generation.
 enum class LoopOrderingStrategy : unsigned {
-  kDefault, ///< Default strategy (eagerly selects last loop in topological sort).
-  kDenseOuter, ///< Prefer dense, then sparse, then singleton dimensions outermost.
+  kDefault,    ///< Default strategy (eagerly selects last loop in topological
+               ///< sort).
+  kDenseOuter, ///< Prefer dense, then sparse, then singleton dimensions
+               ///< outermost.
 };
 
 } // namespace sparse_tensor
@@ -81,14 +83,16 @@ std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                  ReinterpretMapScope scope,
-                                  sparse_tensor::LoopOrderingStrategy strategy = sparse_tensor::LoopOrderingStrategy::kDefault);
+void populateSparseReinterpretMap(
+    RewritePatternSet &patterns, ReinterpretMapScope scope,
+    sparse_tensor::LoopOrderingStrategy strategy =
+        sparse_tensor::LoopOrderingStrategy::kDefault);
 
 std::unique_ptr<Pass> createSparseReinterpretMapPass();
 std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
-std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope, 
-                                                     sparse_tensor::LoopOrderingStrategy strategy);
+std::unique_ptr<Pass>
+createSparseReinterpretMapPass(ReinterpretMapScope scope,
+                               sparse_tensor::LoopOrderingStrategy strategy);
 
 //===----------------------------------------------------------------------===//
 // The PreSparsificationRewriting pass.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index e068eee0587a0..0fc5cc76de39c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -59,7 +59,7 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
 
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
-  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
   void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
   BitVector dims;
 };
@@ -67,7 +67,7 @@ struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
 // Flattens an affine expression into a list of AffineDimExprs.
 struct AffineExprAdmissibleVisitor
     : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
-  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
+  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
 
   // We only allow AffineDimExpr on output.
   void visitAddExpr(AffineBinaryOpExpr expr) {
@@ -407,9 +407,10 @@ struct GenericOpReinterpretMap
 };
 
 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
-  GenericOpScheduler(MLIRContext *context, sparse_tensor::LoopOrderingStrategy strategy)
+  GenericOpScheduler(MLIRContext *context,
+                     sparse_tensor::LoopOrderingStrategy strategy)
       : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
-  
+
   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -792,9 +793,9 @@ struct ForeachOpDemapper
 
 } // namespace
 
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
-                                        ReinterpretMapScope scope,
-                                        sparse_tensor::LoopOrderingStrategy strategy) {
+void mlir::populateSparseReinterpretMap(
+    RewritePatternSet &patterns, ReinterpretMapScope scope,
+    sparse_tensor::LoopOrderingStrategy strategy) {
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kGenericOnly) {
     patterns.add<GenericOpReinterpretMap>(patterns.getContext());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index aa31927e0602c..b660e22154688 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -439,9 +439,8 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
   return std::make_unique<SparseReinterpretMap>(options);
 }
 
-std::unique_ptr<Pass>
-mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope, 
-                                     sparse_tensor::LoopOrderingStrategy strategy) {
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
+    ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
   SparseReinterpretMapOptions options;
   options.scope = scope;
   options.loopOrderingStrategy = strategy;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index c351cd5792136..7ebe20e9674ce 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -83,11 +83,10 @@ inline static bool includesDenseOutput(SortMask mask) {
 /// Returns a sparsity rank for loop ordering: lower values indicate
 /// dimensions that should be placed in outer loops.
 /// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown
-static unsigned getLoopSparsityRank(unsigned loop,
-                                   ArrayRef<Value> allTensors,
-                                   ArrayRef<AffineMap> allMaps) {
+static unsigned getLoopSparsityRank(unsigned loop, ArrayRef<Value> allTensors,
+                                    ArrayRef<AffineMap> allMaps) {
   unsigned bestRank = 3; // Default: most sparse (unknown/singleton-like)
-  
+
   for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) {
     // Check if this loop accesses this tensor
     bool loopAccessesTensor = false;
@@ -101,7 +100,7 @@ static unsigned getLoopSparsityRank(unsigned loop,
       }
       tensorDim++;
     }
-    
+
     if (loopAccessesTensor) {
       const auto enc = getSparseTensorEncoding(tensor.getType());
       if (!enc) {
@@ -123,7 +122,7 @@ static unsigned getLoopSparsityRank(unsigned loop,
       }
     }
   }
-  
+
   return bestRank;
 }
 
@@ -147,7 +146,7 @@ AffineMap IterationGraphSorter::topoSort() {
     // We always prefer a parallel loop over a reduction loop because putting
     // a reduction loop early might make the loop sequence inadmissible.
     auto &it = !parIt.empty() ? parIt : redIt;
-    
+
     // Select loop based on strategy.
     unsigned src;
     switch (strategy) {
@@ -161,11 +160,11 @@ AffineMap IterationGraphSorter::topoSort() {
       allTensors.push_back(out);
       SmallVector<AffineMap> allMaps = loop2InsLvl;
       allMaps.push_back(loop2OutLvl);
-      
+
       // Find loop with best (lowest) sparsity rank.
       unsigned bestLoop = it[0];
       unsigned bestRank = getLoopSparsityRank(bestLoop, allTensors, allMaps);
-      
+
       for (auto candidateLoop : it) {
         unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps);
         if (rank < bestRank || (rank == bestRank && candidateLoop < bestLoop)) {
@@ -177,7 +176,7 @@ AffineMap IterationGraphSorter::topoSort() {
       break;
     }
     }
-    
+
     loopOrder.push_back(src);
     // Remove the selected loop from the worklist.
     it.erase(std::find(it.begin(), it.end(), src));
@@ -200,9 +199,8 @@ AffineMap IterationGraphSorter::topoSort() {
   return AffineMap();
 }
 
-IterationGraphSorter
-IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp, 
-                                   sparse_tensor::LoopOrderingStrategy strategy) {
+IterationGraphSorter IterationGraphSorter::fromGenericOp(
+    linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) {
   // Must be a demapped sparse kernel.
   assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
          hasAnySparseOperandOrResult(genericOp) &&
@@ -227,7 +225,8 @@ IterationGraphSorter::IterationGraphSorter(
     AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
     sparse_tensor::LoopOrderingStrategy strategy)
     : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
-      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), strategy(strategy) {
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+      strategy(strategy) {
   // One map per tensor.
   assert(loop2InsLvl.size() == ins.size());
   // All the affine maps have the same number of dimensions (loops).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index 868a978e9a462..b2a16e9382758 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -13,8 +13,8 @@
 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
 
-#include "mlir/IR/AffineMap.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/IR/AffineMap.h"
 
 namespace mlir {
 
@@ -43,9 +43,11 @@ enum class SortMask : unsigned {
 class IterationGraphSorter {
 public:
   /// Factory method that constructs an iteration graph sorter
-  /// for the given linalg.generic operation with a specific loop ordering strategy.
-  static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp,
-                                           sparse_tensor::LoopOrderingStrategy strategy);
+  /// for the given linalg.generic operation with a specific loop ordering
+  /// strategy.
+  static IterationGraphSorter
+  fromGenericOp(linalg::GenericOp genericOp,
+                sparse_tensor::LoopOrderingStrategy strategy);
 
   /// Returns a permutation that represents the scheduled loop order.
   /// Note that the returned AffineMap could be null if the kernel
@@ -61,7 +63,8 @@ class IterationGraphSorter {
                        SmallVector<AffineMap> &&loop2InsLvl, Value out,
                        AffineMap loop2OutLvl,
                        SmallVector<utils::IteratorType> &&iterTypes,
-                       sparse_tensor::LoopOrderingStrategy strategy = sparse_tensor::LoopOrderingStrategy::kDefault);
+                       sparse_tensor::LoopOrderingStrategy strategy =
+                           sparse_tensor::LoopOrderingStrategy::kDefault);
 
   // Adds all the constraints in the given loop to level map.
   void addConstraints(Value t, AffineMap loop2LvlMap);


