[Mlir-commits] [mlir] 5b333d3 - [mlir][sparse] do not ignore ordering for "dense" tensor linked with sparse type

Aart Bik llvmlistbot at llvm.org
Tue Mar 2 15:22:14 PST 2021


Author: Aart Bik
Date: 2021-03-02T15:21:51-08:00
New Revision: 5b333d3449fa93cd5b554e1e8b16892031bd8bdf

URL: https://github.com/llvm/llvm-project/commit/5b333d3449fa93cd5b554e1e8b16892031bd8bdf
DIFF: https://github.com/llvm/llvm-project/commit/5b333d3449fa93cd5b554e1e8b16892031bd8bdf.diff

LOG: [mlir][sparse] do not ignore ordering for "dense" tensor linked with sparse type

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97795

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 75fb7f716755..7110695576c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -357,6 +357,14 @@ static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   }
 }
 
+/// Returns true if tensor was set up with sparse storage scheme.
+static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
+  if (tensor < op.getNumInputs())
+    return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
+        op.getInput(tensor).getDefiningOp());
+  return false;
+}
+
 /// A DFS helper to compute a topological sort. Note that recursion is
 /// bounded by the number of implicit loops, which is always small.
 /// Returns false when a cycle is detected.
@@ -394,7 +402,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     auto map = op.getIndexingMap(t);
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !merger.isSparseTensor(t))
+    if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
       continue;
     // At the moment, we take the index variables in the tensor access
     // expression in the order in which they appear (conceptually a
@@ -513,14 +521,6 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
   llvm_unreachable("unexpected SparseIntType");
 }
 
-/// Returns true if tensor was set up with sparse storage scheme.
-static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
-  if (tensor < op.getNumInputs())
-    return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
-        op.getInput(tensor).getDefiningOp());
-  return false;
-}
-
 /// Generates buffer for the output tensor.
 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, MemRefType denseTp,
@@ -1004,7 +1004,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
   if (needsUniv) {
     types.push_back(indexType);
     assert(codegen.loops[idx].getType().isa<IndexType>() &&
-           "type_mismatch for universal index");
+           "type mismatch for universal index");
     operands.push_back(codegen.loops[idx]);
   }
   Location loc = op.getLoc();


        


More information about the Mlir-commits mailing list