[Mlir-commits] [mlir] [mlir][sparse] optimize memory loads to SSA values when generating sp… (PR #74787)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 7 15:44:00 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

Optimize memory loads to SSA values when generating sparse conv.

---

Patch is 51.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74787.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+41-99) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h (+3-10) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+179-194) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 75121b5e3ce2e..26d6ea908cf38 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -148,39 +148,29 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
 // Helper functions that load/store into the position buffer for slice-driven
 // loops.
 // The sliced pointer buffer is orgnized as:
-// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
-//                                  [pHi0, pHi1, pHi2, ...],
-//                                  [pNx0, pNx1, pNx2, ...]]
+// [[pLo0, pLo1, pLo2, ...],
+//  [pHi0, pHi1, pHi2, ...],
+//  [pNx0, pNx1, pNx2, ...]]
 static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
                               Value tupleCnt) {
   Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
   // Additional two metadata {memSize, idx} at head.
-  bufSz = ADDI(bufSz, C_IDX(2));
   return genAlloca(builder, loc, bufSz, builder.getIndexType());
 }
-// TODO: We should use SSA value for it.
-// Gets and sets metadata.
-static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
-  return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
-}
-static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
-                              Value pPtr) {
-  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
-}
 
 // Gets and sets position values for slice-driven loops.
 enum class SlicePosKind { kLo, kHi, kNext };
 static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
                             Value tupleIdx, SlicePosKind posKind) {
   Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
-  Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
+  Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
   switch (posKind) {
   case SlicePosKind::kLo:
-    return ADDI(tupleIdx, C_IDX(2));
+    return tupleIdx;
   case SlicePosKind::kHi:
-    return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
+    return ADDI(tupleIdx, tupleCnt);
   case SlicePosKind::kNext:
-    return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
+    return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
   }
   llvm_unreachable("unexpected kind");
 }
@@ -344,6 +334,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->dependentLvlMap.assign(
       numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
   this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
+  this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
+  this->trivialSlice.assign(numTensors, std::vector<bool>());
   this->sliceMeta.assign(
       numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
   this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
@@ -394,6 +387,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     dependentLvlMap[tid].assign(
         lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
     slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
+    sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
+    sliceTupleFwdCnt[tid].assign(lvlRank, Value());
+    trivialSlice[tid].assign(lvlRank, false);
     sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
@@ -806,6 +802,7 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
     assert(ivs.size() == 1);
     // Coord is the relative offset related to its parents.
     assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
+    sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
     // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
     Value posit = ivs[0];
     Value crdBuf = coordinatesBuffers[tid][lvl];
@@ -1324,6 +1321,12 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       } else {
         posits[tid][lvl] =
             genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+        Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
+                           ? C_IDX(0)
+                           : sliceTupleFwdCnt[tid][lvl - 1];
+        Value sz = sliceMeta[tid][lvl].back().first;
+        Value mul = MULI(fwdCnt, sz);
+        sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
       }
       levelReducedDep[tid][lvl]++;
     } else {
@@ -1357,13 +1360,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
       assert(isDenseLT(lvlTypes[tid][lvl]));
       assert(*info.slicedOnLvl == lvl);
       (void)reduced;
-      // Resets slices pointers as the resolved slices are invalidated after we
-      // moves forward to the next slice.
-      invalidateSliceIterIdx(rewriter, loc, tid, lvl);
       info.minCrd = info.offset = info.isNonEmpty = Value();
-    } else {
-      forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
-                                      constantIndex(rewriter, loc, 1));
     }
     levelReducedDep[tid][lvl]--;
   }
@@ -1443,54 +1440,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
-void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
-                                                  Location loc, TensorId tid,
-                                                  Level rootLvl, Value fcnt) {
-
-  auto stt = getSparseTensorType(tensors[tid]);
-
-  // Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
-  // sparse levels (but not resolved). Since we forward an iterator at higher
-  // level of the tree, the subtree need to be pruned.
-  Level leafLvl = rootLvl + 1;
-  while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
-         !stt.isDenseLvl(leafLvl)) {
-    leafLvl++;
-  }
-
-  Level curLvl = rootLvl + 1;
-  Value nxPosPtr = nullptr;
-  if (curLvl < leafLvl) {
-    assert(!isDenseLT(lvlTypes[tid][curLvl]));
-    // The first compressed level, setting up the position pointer for it.
-    Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-    // One step forwards in the parent level result in forwarding one `segment`
-    // in the child sparse level.
-    Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
-    Value cPosPtr = ADDI(fcnt, pPosPtr);                    // current ptr
-    updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
-    // Loads the position pointer start for next level.
-    nxPosPtr =
-        loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
-    curLvl++;
-  }
-
-  // TODO: This is not always needed, but we did it unconditionally for now for
-  // simplicity.
-  // It is only needed when `curLvl` is forwarded without traversing its child
-  // level (e.g., the level is in a conjunctive lattices and got pruned), such
-  // that the position pointer is not forwarded inside the loop.
-  for (; curLvl < leafLvl; curLvl++) {
-    assert(nxPosPtr);
-    if (!isDenseLT(lvlTypes[tid][curLvl])) {
-      Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-      updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
-      nxPosPtr =
-          loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
-    }
-  }
-}
-
 void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                 MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
@@ -1540,13 +1489,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       forwarded = CMPI(eq, coords[tid][lvl], iv);
       operands.push_back(SELECT(forwarded, nxPos, pos));
     }
-    {
-      OpBuilder::InsertionGuard guard(builder);
-      auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
-                                            /*else=*/false);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
-    }
     // The coordinate is invalid now.
     coords[tid][lvl] = nullptr;
 
@@ -1916,8 +1858,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
     pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
                        ADDI(posits[tid][lvl - 1], c1));
   }
-  // Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
+  // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
   updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
   updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
   // Slice over a resolved parent, we only need one pair of pos hi and lo to
@@ -2056,8 +1997,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   Value isNonEmpty = result[0];
   Value minCrd = result[1];
   // Two metadata [memSize, idx].
-  // TODO: Can use an SSA value for these two metadata
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
   // FIXME: we need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
@@ -2066,16 +2005,30 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
 
 bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
+  Value curLvlIdx = C_IDX(0);
   if (depFullyReduced(tid, lvl)) {
-    // Do not need to prepare for slice driven loop on dense level after it is
-    // fully reduced.
+    if (lvl == 0 || trivialSlice[tid][lvl]) {
+      sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
+    } else {
+      if (isDenseLT(lvlTypes[tid][lvl])) {
+        sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
+      } else {
+        assert(isCompressedLT(lvlTypes[tid][lvl]));
+        curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
+                         sliceTupleFwdCnt[0][lvl - 1]);
+        sliceTupleNxStartIdx[tid][lvl] =
+            loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
+                         curLvlIdx, SlicePosKind::kNext);
+      }
+    }
     if (isDenseLT(lvlTypes[tid][lvl]))
       return true;
+
+    Value sPosBuf = slicePosBuffer[tid][lvl].back();
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
-    Value sPosBuf = slicePosBuffer[tid][lvl].back();
-    Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
+    Value tupleIdx = curLvlIdx;
     posits[tid][lvl] =
         loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
     highs[tid][lvl] =
@@ -2134,23 +2087,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   if (sliceInfo.isInitialTensor() ||
       (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
     // First level or previous level has been full resolved.
+    trivialSlice[tid][lvl] = true;
     genResolvedSliceBegin(builder, loc, tid, lvl);
   } else {
     // The previous level has not been full resolved.
+    trivialSlice[tid][lvl] = false;
     genUnResolvedSliceBegin(builder, loc, tid, lvl);
   }
   return false;
 }
 
-void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
-                                         TensorId tid, Level lvl) {
-  for (unsigned i = 0; i <= lvl; i++) {
-    if (!isDenseLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
-      updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
-    }
-  }
-}
-
 std::tuple<Value, Value, Value>
 LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
                                    TensorId tid, Level lvl) {
@@ -2175,10 +2121,6 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   //   isNonEmpty = false;
   //
   Value absOffset = info.offset;
-  // Resets slices pointers as the resolved slices are invalidated after we
-  // moves forward to the next slice.
-  invalidateSliceIterIdx(builder, loc, tid, lvl);
-
   SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
   Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
   Value fastPathP = CMPI(ugt, info.minCrd, absOffset);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 5e51cb2110fa1..fa8b0076f733b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -453,11 +453,6 @@ class LoopEmitter {
     return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
   }
 
-  /// Forwards the (conceptual) "tree iterator" when iterating over a fully
-  /// reduced slice created by index-reduction.
-  void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
-                                       TensorId tid, Level lvl, Value fcnt);
-
   /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
   /// that `tensor[0...lvl-1]` loops have already been set up.
   void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
@@ -610,11 +605,6 @@ class LoopEmitter {
   void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                Level lvl);
 
-  /// Invalidates the index kept in slice postion buffers (by setting it to
-  /// zero).
-  /// TODO: We should instead use an SSA value for the index.
-  void invalidateSliceIterIdx(OpBuilder &builder, Location loc, TensorId tid,
-                              Level lvl);
   /// Generates code to get the first non-empty slice of tid on lvl.
   /// return true if has already been resolved.
   bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
@@ -683,6 +673,9 @@ class LoopEmitter {
   // But they always starts with the first pidx pointing to coord > slice.offset
   // to avoid iteration from the beginning.
   std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
+  std::vector<std::vector<Value>> sliceTupleNxStartIdx;
+  std::vector<std::vector<Value>> sliceTupleFwdCnt;
+  std::vector<std::vector<bool>> trivialSlice;
 
   // The (size, stride) for each conceptual slice used for index reduction
   // loops.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 02cc5d1e2ef34..a3c1e76a3d09a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -12,241 +12,226 @@
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant -2 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 3 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 5 : index
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant false
-// CHECK-DAG:       %[[VAL_13:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
-// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
-// CHECK:           %[[VAL_19:.*]] = memref.alloca() : memref<11xindex>
-// CHECK:           %[[VAL_20:.*]] = memref.alloca() : memref<5xindex>
-// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_20]]{{\[}}%[[VAL_7]]] : memref<5xindex>
-// CHECK:           memref.store %[[VAL_10]], %[[VAL_20]]{{\[}}%[[VAL_9]]] : memref<5xindex>
-// CHECK:           memref.store %[[VAL_21]], %[[VAL_20]]{{\[}}%[[VAL_6]]] : memref<5xindex>
-// CHECK:           %[[VAL_22:.*]] = arith.cmpi ugt, %[[VAL_21]], %[[VAL_10]] : index
-// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_10]]] : memref<?xindex>
-// CHECK:           %[[VAL_24:.*]] = arith.cmpi uge, %[[VAL_23]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_25:.*]] = arith.andi %[[VAL_22]], %[[VAL_24]] : i1
-// CHECK:           %[[VAL_26:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : index
-// CHECK:           %[[VAL_27:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_10]] : index
-// CHECK:           %[[VAL_28:.*]]:3 = scf.while (%[[VAL_29:.*]] = %[[VAL_22]], %[[VAL_30:.*]] = %[[VAL_23]], %[[VAL_31:.*]] = %[[VAL_27]], %[[VAL_32:.*]] = %[[VAL_13]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// CHECK:             scf.condition(%[[VAL_29]]) %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : index, index, tensor<6x6xi32, #sparse>
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
+// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// CHECK:           memref.store %[[VAL_19]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.cmpi ugt, %[[VAL_19]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
+// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
+// CHECK:           %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
+// CHECK:             scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index, %[[VAL_35:.*]]: tensor<6x6xi32, #sparse>):
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74787


More information about the Mlir-commits mailing list