[Mlir-commits] [mlir] b1d1964 - [mlir][sparse] Only try to compute a better iteration graph when needed

Peiming Liu llvmlistbot at llvm.org
Fri Sep 16 15:53:40 PDT 2022


Author: Peiming Liu
Date: 2022-09-16T22:53:32Z
New Revision: b1d1964771d95e2409a1a94a83091919033b39b7

URL: https://github.com/llvm/llvm-project/commit/b1d1964771d95e2409a1a94a83091919033b39b7
DIFF: https://github.com/llvm/llvm-project/commit/b1d1964771d95e2409a1a94a83091919033b39b7.diff

LOG: [mlir][sparse] Only try to compute a better iteration graph when needed

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134059

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e283748cb79c8..bce09130e58bb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1832,26 +1832,30 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // code generation can proceed. As a last resort, an attempt is made
     // to resolve cycles by inserting a conversion.
     std::vector<unsigned> topSort;
-    // Whether the current GenericOp is admissible
+    // Whether the current GenericOp is admissible.
     bool isAdmissible = false;
+    bool hasCycle = true;
     // An const list of all masks that we used for interation graph
     // computation. Must be ordered from strict -> loose.
     const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
                           SortMask::kIncludeDense, SortMask::kSparseOnly};
-    for (auto mask : allMask) {
-      if (computeIterationGraph(merger, op, topSort, mask) &&
-          isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
-                                outerParNest)) {
-        // This is an admissible GenericOp.
-        isAdmissible = true;
-        break;
+    for (auto mask : allMask)
+      if (computeIterationGraph(merger, op, topSort, mask)) {
+        hasCycle = false;
+        if (isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+                                  outerParNest)) {
+          isAdmissible = true;
+          break;
+        }
+        // else try a set of less strict constraints.
       }
-      // else try a less strict constraints.
-    }
 
-    if (!isAdmissible)
+    if (hasCycle)
       // Give it one last shot to resolve the cycle.
       return resolveCycle(merger, rewriter, op);
+    if (!isAdmissible)
+      // Inadmissible expression, reject.
+      return failure();
 
     // Recursively generates code if admissible.
     merger.setHasSparseOut(sparseOut != nullptr);
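
The behavioral change in this hunk can be illustrated with a minimal, self-contained sketch. It is not the MLIR code. `classifyOld`, `classifyNew`, `Outcome`, `MaskPred` and `kAllMasks` are hypothetical names. The two predicates stand in for `computeIterationGraph` and `isAdmissableTensorExp`, whose real signatures take the merger, the op, `topSort` and more. Before the commit, every non-admissible op was routed to `resolveCycle`. After it, `resolveCycle` runs only when no mask yields a topological order. An op whose graph is acyclic but still inadmissible is now rejected with `failure()`.

```cpp
#include <functional>

// Hypothetical mirror of the SortMask constants, ordered strict -> loose.
enum class SortMask { kIncludeAll, kIncludeUndef, kIncludeDense, kSparseOnly };
constexpr SortMask kAllMasks[] = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
                                  SortMask::kIncludeDense, SortMask::kSparseOnly};

// Hypothetical outcome: generate code, try resolveCycle, or return failure().
enum class Outcome { kGenerate, kResolveCycle, kReject };

// Hypothetical predicates standing in for computeIterationGraph
// (true = acyclic, topological order found) and isAdmissableTensorExp.
using MaskPred = std::function<bool(SortMask)>;

// Pre-commit logic: any failure, cyclic or not, went to resolveCycle.
Outcome classifyOld(const MaskPred &hasTopSort, const MaskPred &admissible) {
  for (SortMask mask : kAllMasks)
    if (hasTopSort(mask) && admissible(mask))
      return Outcome::kGenerate;
  return Outcome::kResolveCycle;
}

// Post-commit logic: resolveCycle only when every mask leaves a cycle.
Outcome classifyNew(const MaskPred &hasTopSort, const MaskPred &admissible) {
  bool isAdmissible = false;
  bool hasCycle = true;
  for (SortMask mask : kAllMasks)
    if (hasTopSort(mask)) {
      hasCycle = false; // At least one mask gives an acyclic graph.
      if (admissible(mask)) {
        isAdmissible = true;
        break;
      }
      // Otherwise try a looser mask.
    }
  if (hasCycle)
    return Outcome::kResolveCycle; // Inserting a conversion may help here.
  if (!isAdmissible)
    return Outcome::kReject; // Acyclic but inadmissible: a conversion won't help.
  return Outcome::kGenerate;
}
```

Consider a case where every mask produces a topological order but none is admissible. `classifyOld` returns `kResolveCycle`, which attempts a conversion even though there is no cycle to break. `classifyNew` returns `kReject` instead.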