[Mlir-commits] [mlir] 1ece4d3 - [mlir][sparse] code simplification: always use synthetical tensor for… (#73597)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 17:41:48 PST 2023


Author: Peiming Liu
Date: 2023-11-27T17:41:45-08:00
New Revision: 1ece4d3a0dc951aa349616c3c8740d8b3f9926c1

URL: https://github.com/llvm/llvm-project/commit/1ece4d3a0dc951aa349616c3c8740d8b3f9926c1
DIFF: https://github.com/llvm/llvm-project/commit/1ece4d3a0dc951aa349616c3c8740d8b3f9926c1.diff

LOG: [mlir][sparse] code simplification: always use synthetic tensor for loop bound (#73597)

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 69072b91b2fa523..a245344755f0404 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
       const SparseTensorType stt(rtp);
       lvlRank = stt.getLvlRank();
 
-      // We always treat sparse output tensor as dense so that we always iterate
-      // it based on lvl size.
-      if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+      if (stt.hasEncoding()) {
         const auto enc = stt.getEncoding();
         isSparseSlices[tid] = enc.isSlice();
         for (auto lvlTp : enc.getLvlTypes())

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3fb90ef379a5778..e0d3ce241e454d0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs(
       }
       if (isUndefLT(lt)) {
         // An undefined lt in the lattices, we probably mean to
-        // iterate based on the level of output tensor.  E.g., this
-        // could be a synthetic tensor (for invariants and sparse
-        // output tensor).
-        auto itType = env.op().getIteratorTypesArray()[ldx];
-        if (linalg::isReductionIterator(itType) &&
-            env.merger().getSynTensorID() == tid) {
-          // Coiterating with an invariant, and this is a reduction loop
+        // generate a dense loop according to the synthetic tensor (for
+        // invariants and sparse output tensor).
+        if (env.merger().getSynTensorID() == tid) {
+          // Coiterating with an invariant
           // e.g., out = prod(in[i][j] op invariant);
-          // In this case, we can not infer the loop bound from output
-          // (whose level is reduced). Instead we use the synthetic tensor
-          // to infer the bound.
+          // or a broadcast
+          // e.g., out[i][j] = in[i] (j is undef for input)
+          //
           // The level of the synthetic tensor is the current loop depth;
           // the rank of the synthetic tensor equals to number of loops.
           lvl = env.emitter().getCurrentDepth();
-        } else {
-          // or a broadcast
-          // out[i][j] = in[i] (j is undef for input)
-          tid = outTid;
-          lvl = outLvl;
+        } else if (!lvl) {
           // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-          if (!lvl)
-            return;
+          return;
         }
       }
       hasNonUnique = !isUniqueLT(lt) || hasNonUnique;


        


More information about the Mlir-commits mailing list