[Mlir-commits] [mlir] [mlir][sparse] refactoring: using util functions to query the index to load from position array for slice-driven loop. (PR #73986)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 13:03:34 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---

Patch is 60.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73986.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+97-70) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+210-203) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index a245344755f0404..50ac86f1c6165bb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -148,23 +148,60 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
 // Helper functions that load/store into the position buffer for slice-driven
 // loops.
 // The sliced pointer buffer is orgnized as:
-// [size, curPtr] (two metadata) + [[pLo, pHi, pNext], ...] (list of tuples)
+// [tupleNum, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
+//                                      [pHi0, pHi1, pHi2, ...],
+//                                      [pNx0, pNx1, pNx2, ...]]
+static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
+                              Value tupleCnt) {
+  Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
+  // Additional two metadata {tupleNum, curPtr} at head.
+  bufSz = ADDI(bufSz, C_IDX(2));
+  return genAlloca(builder, loc, bufSz, builder.getIndexType());
+}
+// Gets and sets metadata.
+// TODO: We should use SSA values for the metadata.
 static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
-  // Load curPtr.
-  // TODO: We should use SSA value for it.
   return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
 }
 static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
                               Value pPtr) {
-  // Set curPtr.
-  // TODO: We should use SSA value for it.
   builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
 }
-static Value loadSliceNextPosPtrStart(OpBuilder &builder, Location loc,
-                                      Value sPosBuf, Value tupleIdx) {
-  // load the pNext in the current tuple specified by `tupleIdx`.
-  // 4 = 2 (two metadata) + 2 (pNext == tuple[2])
-  return genIndexLoad(builder, loc, sPosBuf, ADDI(tupleIdx, C_IDX(4)));
+static Value loadSlicePosTupleNum(OpBuilder &builder, Location loc,
+                                  Value sPosBuf) {
+  return genIndexLoad(builder, loc, sPosBuf, C_IDX(0));
+}
+static void updateSlicePosTupleNum(OpBuilder &builder, Location loc, Value num,
+                                   Value sPosBuf) {
+  builder.create<memref::StoreOp>(loc, num, sPosBuf, C_IDX(0));
+}
+
+// Gets and sets position values for slice-driven loops.
+enum class SlicePosKind { kLo, kHi, kNext };
+static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
+                            Value tupleIdx, SlicePosKind posKind) {
+  Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
+  Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
+  switch (posKind) {
+  case SlicePosKind::kLo:
+    return ADDI(tupleIdx, C_IDX(2));
+  case SlicePosKind::kHi:
+    return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
+  case SlicePosKind::kNext:
+    return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
+  }
+  llvm_unreachable("unexpected kind");
+}
+static Value loadSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
+                          Value tupleIdx, SlicePosKind posKind) {
+  return genIndexLoad(builder, loc, sPosBuf,
+                      getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind));
+}
+static void updateSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
+                           Value pos, Value tupleIdx, SlicePosKind posKind) {
+  builder.create<memref::StoreOp>(
+      loc, pos, sPosBuf,
+      getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind));
 }
 
 std::pair<Value, Value>
@@ -1446,13 +1483,13 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
     // The first compressed level, setting up the position pointer for it.
     Value sPosBuf = slicePosBuffer[tid][curLvl].back();
     // One step forwards in the parent level result in forwarding one `segment`
-    // (kSliceIterWidth) in the child sparse level.
-    Value fPosPtr = MULI(fcnt, C_IDX(kSliceIterWidth));     // forward ptr
+    // in the child sparse level.
     Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
-    Value cPosPtr = ADDI(fPosPtr, pPosPtr);                 // current ptr
+    Value cPosPtr = ADDI(fcnt, pPosPtr);                    // current ptr
     updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
     // Loads the position pointer start for next level.
-    nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, cPosPtr);
+    nxPosPtr =
+        loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
     curLvl++;
   }
 
@@ -1464,10 +1501,10 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
   for (; curLvl < leafLvl; curLvl++) {
     assert(nxPosPtr);
     if (!isDenseLT(lvlTypes[tid][curLvl])) {
-      nxPosPtr = MULI(nxPosPtr, C_IDX(kSliceIterWidth));
       Value sPosBuf = slicePosBuffer[tid][curLvl].back();
       updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
-      nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, nxPosPtr);
+      nxPosPtr =
+          loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
     }
   }
 }
@@ -1737,7 +1774,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
     std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
     LoopBodyBuilder bodyBuilder) {
 
-  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
+  Value c0 = C_IDX(0), c1 = C_IDX(1);
   Value pos = c0;
   OpBuilder::InsertPoint ip;
   SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
@@ -1770,20 +1807,22 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
         unsigned depth = frontSlice.depth - 1;
         Value offset = frontSlice.offset;
         Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
-        Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
+        Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf);
         outerMost = builder.create<scf::ForOp>(
-            loc, c2, mSz, C_IDX(kSliceIterWidth), innerArgs,
-            [this, c1, c2, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
+            loc, c0, mSz, c1, innerArgs,
+            [this, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
              &innerArgs](OpBuilder &builder, Location loc, Value iv,
                          ValueRange iterArgs) {
               // generate traversal for each level.
-              Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
-              Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
+              Value loopLo =
+                  loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kLo);
+              Value loopHi =
+                  loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kHi);
               // We need to remember the starting index for next level's
               // position, because slice-driven loop breaks the level into
               // non-consecutive segments.
-              builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
-                                              ADDI(iv, c2).getResult());
+              updateSlicePos(builder, loc, sPtrBuf, iterArgs.back(), iv,
+                             SlicePosKind::kNext);
 
               auto [size, stride] = sliceMeta[tid][firstLvl].back();
               assert(stride == 1 && "Not yet implemented");
@@ -1874,8 +1913,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
 
 void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
                                         TensorId tid, Level lvl) {
-  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
-        c4 = C_IDX(4);
+  Value c0 = C_IDX(0), c1 = C_IDX(1);
   if (isDenseLT(lvlTypes[tid][lvl])) {
     // Dense slice begin is trivial.
     sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0,
@@ -1897,10 +1935,10 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
                        ADDI(posits[tid][lvl - 1], c1));
   }
   // Fills out pIdxBuffer[tid][lvl][0] with [/*memSize =*/4, 0, pLo, pHi]
-  builder.create<memref::StoreOp>(loc, c4, sPtrBuf, c0);  // memSize = 4
-  builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);  // index = 0
-  builder.create<memref::StoreOp>(loc, pLo, sPtrBuf, c2); // pLo
-  builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // pHi
+  updateSlicePosTupleNum(builder, loc, c1, sPtrBuf);
+  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
+  updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
+  updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
 
   // This is an non empty tensor if pLo < pHi.
   Value isNonEmpty = CMPI(ult, pLo, pHi);
@@ -1939,7 +1977,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
 // }
 void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
                                           TensorId tid, Level lvl) {
-  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
+  Value c0 = C_IDX(0), c1 = C_IDX(1);
   unsigned depth = levelReducedDep[tid][lvl];
   // The remaining slice size after reduction.
   Value remSz = sliceMeta[tid][lvl][depth + 1].first;
@@ -1984,7 +2022,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   SmallVector<Value, 3> reduc = {
       constantI1(builder, loc, false), // isNonEmpty
       lvlSizes[tid][lvl],              // minCoord
-      c2,                              // memSize
+      c0,                              // memSize
   };
 
   ValueRange result = genUnResolvedSliceTreeTraverse(
@@ -1993,7 +2031,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
                                     MutableArrayRef<Value> reduc) {
         Value &nonEmpty = reduc[0];
         Value &minCrd = reduc[1];
-        Value &curMemSz = reduc[2];
+        Value &curTupleCnt = reduc[2];
 
         Value pHi = ADDI(iv, c1);
         Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
@@ -2025,19 +2063,19 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
           YIELD(minCrd);
         }
         minCrd = ifNonEmpty.getResult(0);
-        builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
-        Value nxtMemSize = ADDI(curMemSz, c1);
-        builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
-        // curMemSize += kSliceIterWidth
-        curMemSz = ADDI(curMemSz, C_IDX(kSliceIterWidth));
+        updateSlicePos(builder, loc, sPtrBuf, sPLo, curTupleCnt,
+                       SlicePosKind::kLo);
+        updateSlicePos(builder, loc, sPtrBuf, sPHi, curTupleCnt,
+                       SlicePosKind::kHi);
+        curTupleCnt = ADDI(curTupleCnt, C_IDX(1));
       });
 
   Value isNonEmpty = result[0];
   Value minCrd = result[1];
   // Two metadata [memSize, idx].
   // TODO: Can use an SSA value for these two metadata
-  builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
-  builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
+  updateSlicePosTupleNum(builder, loc, result[2], sPtrBuf);
+  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
   // FIXME: we need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
@@ -2045,8 +2083,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
 
 bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
-  Value c1 = C_IDX(1), c2 = C_IDX(2);
-
   if (depFullyReduced(tid, lvl)) {
     // Do not need to prepare for slice driven loop on dense level after it is
     // fully reduced.
@@ -2055,14 +2091,12 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
-    Value pLoPtr =
-        loadSlicePosPtr(builder, loc, slicePosBuffer[tid][lvl].back());
-    pLoPtr = ADDI(pLoPtr, c2);
-    Value pHiPtr = ADDI(pLoPtr, c1);
+    Value sPosBuf = slicePosBuffer[tid][lvl].back();
+    Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
     posits[tid][lvl] =
-        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pLoPtr);
+        loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
     highs[tid][lvl] =
-        genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pHiPtr);
+        loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kHi);
     return true;
   }
 
@@ -2091,8 +2125,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
     // The buffer can be reused, and the size is loop invariant: it only
     // depends on the iteration graph's toposort.
     builder.setInsertionPointAfter(localInsertPos);
-    Value bufSize = C_IDX(1);
-    Value c2 = C_IDX(2);
+    Value tupleCnt = C_IDX(1);
     // Accumlates the size required to cache the pLo for the slice.
     // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
     // level. We at most need to a memref<d0xindex>.
@@ -2109,16 +2142,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
       assert(!sliceMeta[tid][curLevel - 1].empty());
       auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
       assert(stride == 1 && "Not yet implemented");
-      bufSize = MULI(bufSize, sz);
+      tupleCnt = MULI(tupleCnt, sz);
     }
-    // For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
-    // because slice creates segments in the index buffer so that the pHi for
-    // the current level is no longer the pLo for the next level.
-    bufSize = MULI(bufSize, C_IDX(kSliceIterWidth));
-    // Additional two metadata {memSize, idx} at head.
-    bufSize = ADDI(bufSize, c2);
     for (Value &cache : slicePosBuffer[tid][lvl])
-      cache = genAlloca(builder, loc, bufSize, builder.getIndexType());
+      cache = allocSlicePosBuf(builder, loc, tupleCnt);
   }
 
   if (sliceInfo.isInitialTensor() ||
@@ -2148,7 +2175,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
     llvm_unreachable("TODO");
 
   // else generate code to compute next non empty slice.
-  Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
+  Value c0 = C_IDX(0), c1 = C_IDX(1);
 
   SliceInfo &info = sliceStack[tid].back();
   assert(info.slicedOnLvl == lvl);
@@ -2195,14 +2222,13 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
     //    offset = minCrd - size + 1;
     // }
     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    reduc[2] = absOffset; // restore value.
-    Value pSt = c2;       // pointer starting index
-    Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
-    reduc[0] = lvlSizes[tid][lvl];                       // next min coord
-    reduc[1] = constantI1(builder, loc, false);          // isNonEmpty
+    reduc[2] = absOffset;                                    // restore value.
+    Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf); // memSize
+    reduc[0] = lvlSizes[tid][lvl];                           // next min coord
+    reduc[1] = constantI1(builder, loc, false);              // isNonEmpty
     auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
     auto forOp = scf::buildLoopNest(
-        builder, loc, pSt, mSz, C_IDX(kSliceIterWidth), loopArgs,
+        builder, loc, c0, mSz, c1, loopArgs,
         [this, tid, lvl, c1, sPtrBuf,
          &info](OpBuilder &builder, Location loc, ValueRange ivs,
                 ValueRange iterArgs) -> scf::ValueVector {
@@ -2210,9 +2236,10 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
           Value isNonEmpty = iterArgs[1];
 
           Type idxTp = builder.getIndexType();
-          Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front());
-          Value pHi =
-              genIndexLoad(builder, loc, sPtrBuf, ADDI(ivs.front(), c1));
+          Value pLo = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
+                                   SlicePosKind::kLo);
+          Value pHi = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
+                                   SlicePosKind::kHi);
           //
           // if (pLo < pHi) // Only loads when inbound.
           //   coord = load[pLo]
@@ -2236,8 +2263,8 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
                   &ifEqual.getThenRegion().front());
               Value newPlo = ADDI(pLo, c1);
               // Updates the cache.
-              builder.create<memref::StoreOp>(loc, newPlo, sPtrBuf,
-                                              ivs.front());
+              updateSlicePos(builder, loc, sPtrBuf, newPlo, ivs.front(),
+                             SlicePosKind::kLo);
               YIELD(newPlo);
             }
             /* else coord != minCrd */ {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 94b25a358e804a7..6266c63064ffbd8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -150,8 +150,8 @@ class SparsificationAndBufferizationPass
       pm.addPass(
           createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
       pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
+      pm.addPass(mlir::createLoopInvariantCodeMotionPass());
       if (vectorLength > 0) {
-        pm.addPass(mlir::createLoopInvariantCodeMotionPass());
         pm.addPass(createSparseVectorizationPass(
             vectorLength, enableVLAVectorization, enableSIMDIndex32));
       }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 0cd57ac64b50012..0f99a0206e4cb85 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -6,258 +6,265 @@
 
 #DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
+
 // CHECK-LABEL:   func.func @conv2d_all_sparse_CSR(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xi32, #sparse{{[0-9]*}}>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse{{[0-9]*}}> {
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse{{[0-9]*}}>
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant -2 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 3 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 4 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant false
-// CHECK-DAG:       %[[VAL_12:.*]] = tensor.empty() : tensor<6x6xi32, #sparse{{[0-9]*}}>
-// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/73986


More information about the Mlir-commits mailing list