[Mlir-commits] [mlir] 1ece4d3 - [mlir][sparse] code simplification: always use synthetical tensor for… (#73597)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 27 17:41:48 PST 2023
Author: Peiming Liu
Date: 2023-11-27T17:41:45-08:00
New Revision: 1ece4d3a0dc951aa349616c3c8740d8b3f9926c1
URL: https://github.com/llvm/llvm-project/commit/1ece4d3a0dc951aa349616c3c8740d8b3f9926c1
DIFF: https://github.com/llvm/llvm-project/commit/1ece4d3a0dc951aa349616c3c8740d8b3f9926c1.diff
LOG: [mlir][sparse] code simplification: always use synthetic tensor for loop bound. (#73597)
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 69072b91b2fa523..a245344755f0404 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
const SparseTensorType stt(rtp);
lvlRank = stt.getLvlRank();
- // We always treat sparse output tensor as dense so that we always iterate
- // it based on lvl size.
- if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+ if (stt.hasEncoding()) {
const auto enc = stt.getEncoding();
isSparseSlices[tid] = enc.isSlice();
for (auto lvlTp : enc.getLvlTypes())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3fb90ef379a5778..e0d3ce241e454d0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs(
}
if (isUndefLT(lt)) {
// An undefined lt in the lattices, we probably mean to
- // iterate based on the level of output tensor. E.g., this
- // could be a synthetic tensor (for invariants and sparse
- // output tensor).
- auto itType = env.op().getIteratorTypesArray()[ldx];
- if (linalg::isReductionIterator(itType) &&
- env.merger().getSynTensorID() == tid) {
- // Coiterating with an invariant, and this is a reduction loop
+ // generate a dense loop according to the synthetic tensor (for
+ // invariants and sparse output tensor).
+ if (env.merger().getSynTensorID() == tid) {
+ // Coiterating with an invariant
// e.g., out = prod(in[i][j] op invariant);
- // In this case, we can not infer the loop bound from output
- // (whose level is reduced). Instead we use the synthetic tensor
- // to infer the bound.
+ // or a broadcast
+ // e.g., out[i][j] = in[i] (j is undef for input)
+ //
// The level of the synthetic tensor is the current loop depth;
// the rank of the synthetic tensor equals to number of loops.
lvl = env.emitter().getCurrentDepth();
- } else {
- // or a broadcast
- // out[i][j] = in[i] (j is undef for input)
- tid = outTid;
- lvl = outLvl;
+ } else if (!lvl) {
// Skips invalid lvl (e.g., when this is a zero ranked tensor).
- if (!lvl)
- return;
+ return;
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
More information about the Mlir-commits mailing list