[Mlir-commits] [mlir] e7b4c93 - [mlir][sparse] fix crash when using sparse_tensor::UnaryOp and ReduceOp.

Peiming Liu llvmlistbot at llvm.org
Fri Jun 2 18:19:11 PDT 2023


Author: Peiming Liu
Date: 2023-06-03T01:19:05Z
New Revision: e7b4c93f5e609728bbfc4d7d34b29cd6ac92a0b0

URL: https://github.com/llvm/llvm-project/commit/e7b4c93f5e609728bbfc4d7d34b29cd6ac92a0b0
DIFF: https://github.com/llvm/llvm-project/commit/e7b4c93f5e609728bbfc4d7d34b29cd6ac92a0b0.diff

LOG: [mlir][sparse] fix crash when using sparse_tensor::UnaryOp and ReduceOp.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152048

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 8d04ecff5a75..67f3c30eb4db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -232,7 +232,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
 
-  const unsigned numTensors = ts.size();
+  const unsigned numManifestTensors = ts.size();
+  const unsigned synTensorId = numManifestTensors;
+  const unsigned numTensors = numManifestTensors + 1;
+
   this->tensors.assign(ts.begin(), ts.end());
   this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
   this->lvlSizes.assign(numTensors, std::vector<Value>());
@@ -265,33 +268,43 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
-    const Value t = tensors[tid];
-    // A scalar or 0-dimension tensor.
-    if (isZeroRankedTensorOrScalar(t.getType()))
-      continue;
-
-    auto rtp = getRankedTensorType(t);
-    if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
-        isUniqueCOOType(rtp) && reshape) {
-      // TODO: Support more kinds of sparse tensors.
-      // FIXME: We should instead lower reshape operations on sparse tensors to
-      // view change.
-      collapseReassoc[tid] = reshape.getReassociation();
-      rtp = reshape.getSrcType();
-      // Overwrite the tensor with the source tensor of the reshape operation.
-      tensors[tid] = reshape.getSrc();
-    }
-    const SparseTensorType stt(rtp);
-    const Level lvlRank = stt.getLvlRank();
-    // We always treat sparse output tensor as dense so that we always iterate
-    // it based on lvl size.
-    if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
-      const auto enc = stt.getEncoding();
-      isSparseSlices[tid] = enc.isSlice();
-      for (auto lvlTp : enc.getLvlTypes())
-        lvlTypes[tid].push_back(lvlTp);
-    } else {
+    Level lvlRank;
+    if (tid == synTensorId) {
+      // The synthetic tensor is (conceptually) an all-dense tensor whose rank
+      // equals the total number of loops (each level can potentially be
+      // mapped to one of the loops being generated).
+      lvlRank = numLoops;
       lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+    } else {
+      const Value t = tensors[tid];
+      // A scalar or 0-dimension tensor.
+      if (isZeroRankedTensorOrScalar(t.getType()))
+        continue;
+
+      auto rtp = getRankedTensorType(t);
+      if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
+          isUniqueCOOType(rtp) && reshape) {
+        // TODO: Support more kinds of sparse tensors.
+        // FIXME: We should instead lower reshape operations on sparse tensors
+        // to view change.
+        collapseReassoc[tid] = reshape.getReassociation();
+        rtp = reshape.getSrcType();
+        // Overwrite the tensor with the source tensor of the reshape
+        // operation.
+        tensors[tid] = reshape.getSrc();
+      }
+      const SparseTensorType stt(rtp);
+      lvlRank = stt.getLvlRank();
+
+      // We always treat sparse output tensor as dense so that we always iterate
+      // it based on lvl size.
+      if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+        const auto enc = stt.getEncoding();
+        isSparseSlices[tid] = enc.isSlice();
+        for (auto lvlTp : enc.getLvlTypes())
+          lvlTypes[tid].push_back(lvlTp);
+      } else {
+        lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+      }
     }
 
     // Initialize using empty value.
@@ -314,7 +327,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
                                  std::nullopt, 0);
-    if (dimGetter) {
+    if (dimGetter && !isSynTensor(tid)) {
       auto reassoc = collapseReassoc[tid];
       Level dstRank = reassoc ? reassoc.size() : lvlRank;
       for (Level l = 0; l < dstRank; l++) {
@@ -461,15 +474,28 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
   assert(loopSeqStack.size() == loopStack.size());
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
+
+  bool hasSynTensor = false;
+  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
     } else {
-      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      if (isSynTensor(tid)) {
+        hasSynTensor = true;
+      } else {
+        loopBoundDefLevel = std::make_pair(tid, lvl);
+        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      }
     }
   }
 
+  if (hasSynTensor && loopBoundDefLevel.has_value()) {
+    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
+    highs[getSynTensorId()][getCurrentDepth()] =
+        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
+  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
@@ -1137,6 +1163,9 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
   // output tensor unconditionally, since they may not appear in the lattice,
   // but may be needed for linearized codegen.
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+    if (isSynTensor(tid))
+      continue;
+
     if (isDenseDLT(lvlTypes[tid][lvl])) {
      // Slice-driven dense levels should have been handled already.
       if (!dependentLvlMap[tid][lvl].empty())
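
The crux of the enterNewLoopSeq change above is that the synthetic tensor owns no storage: when it participates in a loop, its upper bound at the current depth must be borrowed from a manifest tensor level that defines the loop bound. Below is a minimal standalone C++ model of that borrowing, with a hypothetical MiniEmitter type and plain integers in place of mlir::Value handles; it is a sketch of the idea, not the actual LoopEmitter API.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::Value; bounds here are plain numbers.
using Bound = uint64_t;

struct MiniEmitter {
  // Indexed as [tid][lvl]; sized numManifest + 1 so that the trailing slot
  // belongs to the synthetic tensor, mirroring numTensors above.
  std::vector<std::vector<Bound>> lvlSizes, highs;
  unsigned numManifest;

  unsigned synTensorId() const { return numManifest; }

  // Mirrors the hasSynTensor && loopBoundDefLevel case in enterNewLoopSeq:
  // the synthetic tensor's bound at `depth` is copied from the manifest
  // (tid, lvl) pair that defines the loop bound.
  void setSynBound(unsigned depth, unsigned tid, unsigned lvl) {
    assert(tid < numManifest && "bound must come from a manifest tensor");
    highs[synTensorId()][depth] = lvlSizes[tid][lvl];
  }
};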

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 03715785d284..ded58f2d4b01 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -194,14 +194,18 @@ class LoopEmitter {
   /// Gets the total number of tensors that loopEmitter is operating on.
   unsigned getNumTensors() const { return tensors.size(); }
 
+  /// Gets the TensorId for the synthetic tensor.
+  TensorId getSynTensorId() const { return tensors.size(); }
+
   /// Compresses a TensorId and Level into a TensorLevel.
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
-    return l * getNumTensors() + t;
+    // TODO: getNumTensors() should include the synthetic tensor.
+    return l * (getNumTensors() + 1) + t;
   }
 
   /// De-compresses a TensorLevel back to a pair of TensorId and Level.
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
-    unsigned nt = getNumTensors();
+    unsigned nt = getNumTensors() + 1;
     return std::make_pair(tidLvl % nt, tidLvl / nt);
   }
 
@@ -319,6 +323,8 @@ class LoopEmitter {
                                                  Location loc, Value crd,
                                                  TensorId tid, Level lvl);
 
+  bool isSynTensor(TensorId tid) const { return tid == getNumTensors(); }
+
   bool isOutputTensor(TensorId tid) const {
     return hasOutput && tid == getNumTensors() - 1;
   }
@@ -408,9 +414,11 @@ class LoopEmitter {
   /// TODO: why not do this computation when we first store the reassoc,
   /// instead of doing it every time we look it up?
   SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
-    assert(tid < getNumTensors() && "Invalid TensorId");
-    assert(collapseReassoc.size() == getNumTensors());
+    assert(tid < getNumTensors() + 1 && "Invalid TensorId");
+    assert(collapseReassoc.size() == getNumTensors() + 1);
     if (const auto reassoc = collapseReassoc[tid]) {
+      assert(!isSynTensor(tid) && !isOutputTensor(tid) &&
+             "Output/Synthetic tensor should not have reassociation");
       // TODO: store the dstLvlRank in the LoopEmitter so that we can
       // check `dstLvl < dstLvlRank` at the top; and only here need to
       // assert that `reassoc.size() == dstLvlRank`.
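
The `+ 1` in the packing arithmetic above is easy to miss but is exactly what makes the synthetic tensor addressable: with the old width of getNumTensors(), the synthetic id (which equals getNumTensors()) would alias another tensor/level pair. A self-contained sketch of the new encoding and its round-trip property, using standalone types rather than the actual header:

#include <cassert>
#include <utility>

using TensorId = unsigned;
using Level = unsigned;
using TensorLevel = unsigned;

// numTensors counts only manifest tensors; the width is numTensors + 1 so
// the synthetic tensor id (== numTensors) gets its own slot.
TensorLevel makeTensorLevel(unsigned numTensors, TensorId t, Level l) {
  return l * (numTensors + 1) + t;
}

std::pair<TensorId, Level> unpackTensorLevel(unsigned numTensors,
                                             TensorLevel tl) {
  const unsigned nt = numTensors + 1;
  return {tl % nt, tl / nt};
}

int main() {
  const unsigned numTensors = 3;   // e.g., two inputs plus one output
  const TensorId syn = numTensors; // synthetic tensor id
  const TensorLevel tl = makeTensorLevel(numTensors, syn, /*l=*/2);
  auto [t, l] = unpackTensorLevel(numTensors, tl);
  assert(t == syn && l == 2);
  // With the old width of numTensors, (syn, 2) would encode to 9 and
  // decode to (0, 3), silently pointing at the wrong tensor and level.
  return 0;
}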

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d9f363adbd0b..4e7e8f767b6c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1490,8 +1490,15 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                                            std::optional<Level> lvl,
                                            DimLevelType dlt, bool isIdxReduc) {
     assert(env.merger().loop(b) == idx);
-    if (isDenseDLT(dlt) || isUndefDLT(dlt))
+    if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+      if (tid == env.merger().getSynTensorID()) {
+        // The loop emitter needs to set up loop bounds for the synthetic
+        // tensor too if there is a loop condition imposed on it.
+        tidLvls.push_back(
+            env.makeTensorLevel(tid, env.emitter().getCurrentDepth()));
+      }
       needsUniv = true;
+    }
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
         isCompressedWithHiDLT(dlt) || isIdxReduc) {
      // Only when this is an index reduction loop can the dlt be undefined.
@@ -1575,13 +1582,24 @@ static bool translateBitsToTidLvlPairs(
             // iterate based on the level of output tensor.  E.g., this
             // could be a synthetic tensor (for invariants and sparse
             // output tensor).
-            // out[i][j] = invariant; or a broadcast
-            // out[i][j] = in[i] (j is undef for input)
-            tid = outTid;
-            lvl = outLvl;
-            // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-            if (!lvl)
-              return;
+            if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+              // Coiterating with an invariant in a reduction loop, e.g.,
+              //   out = prod(in[i][j] op invariant);
+              // In this case, we cannot infer the loop bound from the output
+              // (whose level is reduced). Instead, we use the synthetic tensor
+              // to infer the bound.
+              // The level of the synthetic tensor is the current loop depth;
+              // the rank of the synthetic tensor equals the number of loops.
+              lvl = env.emitter().getCurrentDepth();
+            } else {
+              // Or a broadcast, e.g.,
+              //   out[i][j] = in[i] (j is undef for the input).
+              tid = outTid;
+              lvl = outLvl;
+              // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+              if (!lvl)
+                return;
+            }
           }
           hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
           tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1671,7 +1689,8 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
   auto allTidLvls =
       llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
   for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
-    if (tid != env.merger().getOutTensorID())
+    if (tid != env.merger().getOutTensorID() &&
+        tid != env.merger().getSynTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }
 
@@ -1798,7 +1817,7 @@ static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
   } else {
    // To rematerialize a non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = env.emitter().getValBuffer().back(); // value array
+    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
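
The new branch in translateBitsToTidLvlPairs above encodes a small decision: inside a reduction loop, the output level has been reduced away and cannot drive the loop, so the synthetic tensor at the current depth does instead; otherwise the output tensor still handles the broadcast case. A condensed C++ model of that decision, with hypothetical function and parameter names (the real code operates on Merger lattice bits):

#include <optional>
#include <utility>

using TensorId = unsigned;
using Level = unsigned;

std::optional<std::pair<TensorId, Level>>
pickLoopDriver(bool isReduc, TensorId tid, TensorId synTid, TensorId outTid,
               std::optional<Level> outLvl, Level curDepth) {
  if (isReduc && tid == synTid) {
    // Reduction over an invariant, e.g., out = prod(in[i][j] op invariant):
    // the output level is reduced away, so the bound comes from the
    // synthetic tensor at the current loop depth.
    return std::make_pair(synTid, curDepth);
  }
  // Broadcast case, e.g., out[i][j] = in[i] (j is undef for the input):
  // fall back to the output level, skipping a zero-ranked output.
  if (!outLvl)
    return std::nullopt;
  return std::make_pair(outTid, *outLvl);
}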

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
index c90c2c416cd8..06b8a1ad0f3a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir
@@ -140,7 +140,9 @@ module {
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0
     ]> : tensor<32xf32>
 
-    // Convert constants to annotated tensors.
+    // Convert constants to annotated tensors. Note that this
+    // particular conversion only stores nonzero elements,
+    // so we will have no explicit zeros, only implicit zeros.
     %d0_i32 = sparse_tensor.convert %c_0_i32
       : tensor<32xi32> to tensor<32xi32, #DV>
     %d0_f32 = sparse_tensor.convert %c_0_f32
@@ -158,6 +160,10 @@ module {
     %s1_f32 = sparse_tensor.convert %c_1_f32
       : tensor<32xf32> to tensor<32xf32, #SV>
 
+    // Special case, construct a sparse vector with an explicit zero.
+    %v0 = arith.constant sparse< [ [1] ], [ 0 ] > : tensor<32xi32>
+    %s0 = sparse_tensor.convert %v0: tensor<32xi32> to tensor<32xi32, #SV>
+
     // Call the kernels.
     %0 = call @prod_dreduction_i32(%d0_i32, %ri) : (tensor<32xi32, #DV>, tensor<i32>) -> tensor<i32>
     %1 = call @prod_dreduction_f32(%d0_f32, %rf) : (tensor<32xf32, #DV>, tensor<f32>) -> tensor<f32>
@@ -167,19 +173,23 @@ module {
     %5 = call @prod_dreduction_f32(%d1_f32, %rf) : (tensor<32xf32, #DV>, tensor<f32>) -> tensor<f32>
     %6 = call @prod_sreduction_i32(%s1_i32, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
     %7 = call @prod_sreduction_f32(%s1_f32, %rf) : (tensor<32xf32, #SV>, tensor<f32>) -> tensor<f32>
+    %8 = call @prod_sreduction_i32(%s0,     %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
 
     // Verify results. Note that the custom reduction gave permission
    // to treat an explicit vs implicit zero differently to compute the
-    // full product reduction. A "standard" product reduction would
-    // have to return 0 for any implicit zero occurrence too.
+    // full product reduction over stored elements. A "standard" product
+    // reduction would have to return 0 for any implicit zero occurrence
+    // too. An explicit zero nullifies the product, though, as requested.
     //
     // CHECK: 0
+    // CHECK: 0
     // CHECK: 3087
     // CHECK: 14
     // CHECK: 3087
     // CHECK: 168
     // CHECK: 3087
     // CHECK: 168
+    // CHECK: 0
     //
     call @dump_i32(%0) : (tensor<i32>) -> ()
     call @dump_f32(%1) : (tensor<f32>) -> ()
@@ -189,6 +199,7 @@ module {
     call @dump_f32(%5) : (tensor<f32>) -> ()
     call @dump_i32(%6) : (tensor<i32>) -> ()
     call @dump_f32(%7) : (tensor<f32>) -> ()
+    call @dump_i32(%8) : (tensor<i32>) -> ()
 
     // Release the resources.
     bufferization.dealloc_tensor %d0_i32 : tensor<32xi32, #DV>
@@ -199,6 +210,7 @@ module {
     bufferization.dealloc_tensor %d1_f32 : tensor<32xf32, #DV>
     bufferization.dealloc_tensor %s1_i32 : tensor<32xi32, #SV>
     bufferization.dealloc_tensor %s1_f32 : tensor<32xf32, #SV>
+    bufferization.dealloc_tensor %s0     : tensor<32xi32, #SV>
 
     return
   }
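
The added test case pins down the intended semantics: the custom sparse_tensor.reduce product only ever visits stored entries, so implicit zeros never reach the reduction, while a stored (explicit) zero still nullifies the result. Roughly, in plain C++ (a sketch of the semantics, not of the generated code):

#include <cstdint>
#include <vector>

// Product over the stored values of a sparse vector: implicit zeros are
// simply never visited, but an explicit (stored) zero zeroes the product.
int32_t prodOverStored(const std::vector<int32_t> &storedValues,
                       int32_t init) {
  int32_t acc = init;
  for (int32_t v : storedValues)
    acc *= v;
  return acc;
}

// E.g., the vector %s0 above stores the single value 0, so its product is
// 0, matching the new CHECK line, even though all other entries are
// implicit zeros that a "standard" reduction would also have to honor.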


        

