[Mlir-commits] [mlir] 5b333d3 - [mlir][sparse] do not ignore ordering for "dense" tensor linked with sparse type
Aart Bik
llvmlistbot at llvm.org
Tue Mar 2 15:22:14 PST 2021
Author: Aart Bik
Date: 2021-03-02T15:21:51-08:00
New Revision: 5b333d3449fa93cd5b554e1e8b16892031bd8bdf
URL: https://github.com/llvm/llvm-project/commit/5b333d3449fa93cd5b554e1e8b16892031bd8bdf
DIFF: https://github.com/llvm/llvm-project/commit/5b333d3449fa93cd5b554e1e8b16892031bd8bdf.diff
LOG: [mlir][sparse] do not ignore ordering for "dense" tensor linked with sparse type
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D97795
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 75fb7f716755..7110695576c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -357,6 +357,14 @@ static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
}
}
+/// Returns true if tensor was set up with sparse storage scheme.
+static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
+ if (tensor < op.getNumInputs())
+ return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
+ op.getInput(tensor).getDefiningOp());
+ return false;
+}
+
/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
@@ -394,7 +402,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
auto map = op.getIndexingMap(t);
assert(map.getNumDims() == n);
// Skip dense tensor constraints when sparse only is requested.
- if (sparseOnly && !merger.isSparseTensor(t))
+ if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
continue;
// At the moment, we take the index variables in the tensor access
// expression in the order in which they appear (conceptually a
@@ -513,14 +521,6 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
llvm_unreachable("unexpected SparseIntType");
}
-/// Returns true if tensor was set up with sparse storage scheme.
-static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
- if (tensor < op.getNumInputs())
- return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
- op.getInput(tensor).getDefiningOp());
- return false;
-}
-
/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, MemRefType denseTp,
@@ -1004,7 +1004,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
if (needsUniv) {
types.push_back(indexType);
assert(codegen.loops[idx].getType().isa<IndexType>() &&
- "type_mismatch for universal index");
+ "type mismatch for universal index");
operands.push_back(codegen.loops[idx]);
}
Location loc = op.getLoc();
More information about the Mlir-commits mailing list