[Mlir-commits] [mlir] b6d1a31 - [mlir][sparse] refine heuristic for iteration graph topsort

Aart Bik llvmlistbot at llvm.org
Fri Sep 3 08:37:23 PDT 2021


Author: Aart Bik
Date: 2021-09-03T08:37:15-07:00
New Revision: b6d1a31c1b88227b38b3dad7b4430fa42516e82c

URL: https://github.com/llvm/llvm-project/commit/b6d1a31c1b88227b38b3dad7b4430fa42516e82c
DIFF: https://github.com/llvm/llvm-project/commit/b6d1a31c1b88227b38b3dad7b4430fa42516e82c.diff

LOG: [mlir][sparse] refine heuristic for iteration graph topsort

The sparse index order must always be satisfied, but in many
cases this still leaves a choice among valid topological sorts.
Previously, ties were broken in favor of any dense index order,
since that yields good spatial locality. However, breaking ties
in favor of pushing unrelated indices into the sparse iteration
spaces yields better asymptotic complexity. This revision
refines the heuristic accordingly.

Note that in the long run we are really interested in using
ML itself ("ML for ML") to find the best loop ordering, as a
replacement for such hand-written heuristics.
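
To make the relaxation concrete, here is a minimal, self-contained
C++ sketch, not the MLIR implementation: the SortMask values match
the enum introduced below, but Dim, the build helper, and topoSort
are simplified stand-ins. It mirrors how the masks are tried from
strongest to weakest on a toy sampled dense-dense product
S(i,j) * A(i,k) * B(k,j), where the undef constraints push the
dense-only loop k innermost.

// SortMask mirrors the enum added in Sparsification.cpp; everything else
// (Dim, build, topoSort) is a hypothetical stand-in for illustration only.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
enum class Dim { kSparse, kDense, kUndef };

using Adj = std::vector<std::vector<bool>>;

// Standard DFS-based topological sort; returns false on a cycle.
static bool topoSort(const Adj &adj, std::vector<unsigned> &order) {
  unsigned n = adj.size();
  std::vector<int> state(n, 0); // 0 = unvisited, 1 = on stack, 2 = done
  order.clear();
  std::function<bool(unsigned)> dfs = [&](unsigned i) -> bool {
    state[i] = 1;
    for (unsigned j = 0; j < n; j++) {
      if (!adj[i][j])
        continue;
      if (state[j] == 1)
        return false; // back edge: this constraint set is infeasible
      if (state[j] == 0 && !dfs(j))
        return false;
    }
    state[i] = 2;
    order.push_back(i); // reverse post-order yields the loop order
    return true;
  };
  for (unsigned i = 0; i < n; i++)
    if (state[i] == 0 && !dfs(i))
      return false;
  std::reverse(order.begin(), order.end());
  return true;
}

int main() {
  // Loops i = 0, j = 1 index the sparse sampling matrix S; loop k = 2 only
  // indexes the dense operands, so it is "undef" with respect to S.
  unsigned n = 3;
  std::vector<Dim> dims = {Dim::kSparse, Dim::kSparse, Dim::kUndef};

  auto build = [&](unsigned mask) {
    Adj adj(n, std::vector<bool>(n, false));
    adj[0][1] = true; // sparse storage order of S: i before j (always kept)
    if (mask & kIncludeDense) {
      adj[0][2] = true; // dense A(i,k), row-major: i before k
      adj[2][1] = true; // dense B(k,j), row-major: k before j
    }
    if (mask & kIncludeUndef)
      for (unsigned i = 0; i < n; i++)
        if (dims[i] == Dim::kSparse)
          for (unsigned j = 0; j < n; j++)
            if (dims[j] == Dim::kUndef)
              adj[i][j] = true; // push the undef loop k inside sparse loops
    return adj;
  };

  // Try the strongest constraint set first, then progressively relax.
  const unsigned tries[] = {kIncludeUndef | kIncludeDense, kIncludeUndef,
                            kIncludeDense, kSparseOnly};
  std::vector<unsigned> order;
  for (unsigned mask : tries)
    if (topoSort(build(mask), order))
      break;
  for (unsigned l : order)
    printf("%u ", l); // prints "0 1 2 ": i, j outer, k innermost
  printf("\n");
  return 0;
}

The first (strongest) mask is infeasible in this toy case because the
dense and undef constraints on k conflict, so the search falls back to
the undef-only mask, which is exactly the order exercised by the updated
sparse_2d.mlir test below.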

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D109100

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_2d.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f30318649dc07..6fb09cabda80b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -30,6 +30,9 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+// Iteration graph sorting.
+enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
+
 // Code generation.
 struct CodeGen {
   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
@@ -141,7 +144,7 @@ static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
 /// order yields innermost unit-stride access with better spatial locality.
 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort,
-                                  bool sparseOnly) {
+                                  unsigned mask) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
@@ -152,8 +155,8 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     assert(map.getNumDims() == n);
-    // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !enc)
+    // Skip dense tensor constraints when not requested.
+    if (!(mask & SortMask::kIncludeDense) && !enc)
       continue;
     // Each tensor expression and optional dimension ordering (row-major
     // by default) puts an ordering constraint on the loop indices. For
@@ -164,6 +167,16 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
       unsigned t = map.getDimPosition(perm(enc, d));
       adjM[f][t] = true;
     }
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    if (mask & SortMask::kIncludeUndef) {
+      unsigned tensor = t->getOperandNumber();
+      for (unsigned i = 0; i < n; i++)
+        if (merger.isDim(tensor, i, Dim::kSparse))
+          for (unsigned j = 0; j < n; j++)
+            if (merger.isDim(tensor, j, Dim::kUndef))
+              adjM[i][j] = true;
+    }
   }
 
   // Topologically sort the iteration graph to determine loop order.
@@ -1134,8 +1147,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // This assumes that higher-level passes have already put the
     // tensors in each tensor expression in a feasible order.
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
-        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
+    if (!computeIterationGraph(merger, op, topSort,
+                               SortMask::kIncludeUndef |
+                                   SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.

diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 9f4b3ac4d69ce..ffc6e8de3a9ad 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1043,25 +1043,26 @@ func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?x
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : memref<?x?xf32>
-// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK:               %[[VAL_25:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK:               scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref<?xf32>
-// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
-// CHECK:                 %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
-// CHECK:                 %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK:                 memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_23:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK:               %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
+// CHECK:                 %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_34:.*]] = mulf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK:                 %[[VAL_35:.*]] = mulf %[[VAL_27]], %[[VAL_34]] : f32
+// CHECK:                 %[[VAL_36:.*]] = addf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK:                 scf.yield %[[VAL_36]] : f32
 // CHECK:               }
+// CHECK:               memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
-// CHECK:           return %[[VAL_35]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
+// CHECK:           return %[[VAL_38]] : tensor<?x?xf32>
 // CHECK:         }
 func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
                           %arga: tensor<?x?xf32>,

