[Mlir-commits] [mlir] 4fa3cc6 - [mlir][sparse] deduplicate non-unique coordinates when coiterating collapsed COO tensors.

Peiming Liu llvmlistbot at llvm.org
Thu Mar 9 10:15:18 PST 2023

Author: Peiming Liu
Date: 2023-03-09T18:15:12Z
New Revision: 4fa3cc6eb402a42bd8c677c285155237276c0f09

URL: https://github.com/llvm/llvm-project/commit/4fa3cc6eb402a42bd8c677c285155237276c0f09
DIFF: https://github.com/llvm/llvm-project/commit/4fa3cc6eb402a42bd8c677c285155237276c0f09.diff

LOG: [mlir][sparse] deduplicate non-unique coordinates when coiterating collapsed COO tensors.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145532




diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index c4a48957ba7bd..a8474a1a65dc2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -407,12 +407,14 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   auto enc = getSparseTensorEncoding(tensors[tid].getType());
   const auto reassoc = getCollapseReassociation(tid, dim);
-  dim = reassoc.front();
   // TODO: support dynamic slices.
+  // Use the first level of the collapsed set to build the loop bound; it
+  // spans the largest range of the set.
+  const auto fdim = reassoc.front();
   Value step = constantIndex(builder, loc, 1);
-  Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
+  Value lo = isSparseInput ? pidxs[tid][fdim]     // current offset
                            : loopSeqStack.back(); // universal index
-  Value hi = highs[tid][dim];
+  Value hi = highs[tid][fdim];
   Operation *loop = nullptr;
   Value iv;
@@ -585,9 +587,17 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     if (isCompressedDLT(dimTypes[tid][dim]) ||
         isSingletonDLT(dimTypes[tid][dim])) {
+      const auto reassoc = getCollapseReassociation(tid, dim);
+      for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+        if (!isUniqueDLT(dimTypes[tid][reassoc[i]])) {
+          // Seed segHi with 0 so the first iteration always recomputes it.
+          types.push_back(indexType);
+          operands.push_back(constantIndex(builder, loc, 0));
+        }
+      }
-      operands.push_back(pidxs[tid][dim]);
+      operands.push_back(pidxs[tid][reassoc.front()]);
   // The position where user-supplied reduction variable starts.
@@ -616,15 +626,22 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
     unsigned tid = t; // Why `t` can not be captured by lambda?
     if (isCompressedDLT(dimTypes[tid][lvl]) ||
         isSingletonDLT(dimTypes[tid][lvl])) {
+      const auto reassoc = getCollapseReassociation(tid, lvl);
+      assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+      for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+        if (!isUniqueDLT(dimTypes[tid][reassoc[i]])) {
+          // Rebind segHi to the loop-carried value of the after region.
+          segHi[tid][reassoc[i]] = after->getArgument(o++);
+        }
+      }
       Value op1 = before->getArgument(o);
-      Value op2 = highs[tid][lvl];
+      // Use the first level's bound as the bound of the whole collapsed set.
+      Value op2 = highs[tid][reassoc.front()];
       Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
       cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
       // Update positions
       Value pos = after->getArgument(o++);
-      const auto reassoc = getCollapseReassociation(tid, lvl);
-      assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
       // For COO, the position is the same across consecutive levels.
                      [this, tid, pos](Level lvl) { pidxs[tid][lvl] = pos; });
@@ -714,9 +731,48 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
   assert(loopStack.size() == loopSeqStack.size());
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
-    if (!isUniqueDLT(dimTypes[tid][dim])) {
-      segHi[tid][dim] = genSegmentHigh(builder, loc, tid, dim, pidxs[tid][dim],
-                                       highs[tid][dim]);
+    const auto reassoc = getCollapseReassociation(tid, dim);
+    assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+    // TODO: Refactor this into smaller functions.
+    // NOTE: This applies to every collapsed level except the last one (hence
+    // the loop bound `reassoc.size() - 1`). Each iteration advances by the
+    // segment size of the last level, which does not always invalidate the
+    // segment highs of the preceding levels, so we carry those segment highs
+    // across loop iterations and only recompute them when needed.
+    //
+    // E.g., for a COO tensor with the following coordinates array:
+    // (0, 0, 1),
+    // (0, 0, 2),
+    // (1, 1, 1),
+    // segHi[lvl=0] = segHi[lvl=1] = 2 and
+    // segHi[lvl=2] = 1,
+    // so the first iteration does not invalidate segHi[0] or segHi[1].
+    for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+      const auto lvl = reassoc[i];
+      if (!isUniqueDLT(dimTypes[tid][lvl])) {
+        Value pos = pidxs[tid][lvl];
+        assert(segHi[tid][lvl]);
+        Value newSegHi = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::uge, pos, segHi[tid][lvl]);
+        auto ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(),
+                                              newSegHi, true);
+        {
+          OpBuilder::InsertionGuard guard(builder);
+          builder.setInsertionPointToStart(ifOp.thenBlock());
+          builder.create<scf::YieldOp>(
+              loc,
+              genSegmentHigh(builder, loc, tid, lvl, pos, highs[tid][lvl]));
+          // Otherwise, reuse the segment high from the previous iteration.
+          builder.setInsertionPointToStart(ifOp.elseBlock());
+          builder.create<scf::YieldOp>(loc, segHi[tid][lvl]);
+        }
+        highs[tid][lvl + 1] = segHi[tid][lvl] = ifOp.getResult(0);
+      }
+    }
+    const auto lvl = reassoc.back();
+    if (!isUniqueDLT(dimTypes[tid][lvl])) {
+      segHi[tid][lvl] = genSegmentHigh(builder, loc, tid, lvl, pidxs[tid][lvl],
+                                       highs[tid][lvl]);
@@ -906,6 +962,15 @@ void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     if (isCompressedDLT(dimTypes[tid][dim]) ||
         isSingletonDLT(dimTypes[tid][dim])) {
+      const auto reassoc = getCollapseReassociation(tid, dim);
+      assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+      for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
+        const auto lvl = reassoc[i];
+        if (!isUniqueDLT(dimTypes[tid][lvl])) {
+          operands.push_back(segHi[tid][lvl]);
+          o++;
+        }
+      }
       Value op1 = coord[tid][dim];
       Value op3 = pidxs[tid][dim];
       Value cmp =
@@ -913,13 +978,18 @@ void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
       // If the loop contains a coiteration with non-unique level, we fast
       // forward all the duplicated coords by setting the position to the
       // segment high.
-      Value add = !isUniqueDLT(dimTypes[tid][dim])
-                      ? segHi[tid][dim]
+      // If this is a collapsed dim, we forward pidx based on the last level in
+      // the collapsed level set.
+      Value add = !isUniqueDLT(dimTypes[tid][reassoc.back()])
+                      ? segHi[tid][reassoc.back()]
                       : builder.create<arith::AddIOp>(loc, op3, one);
       operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
       // Following loops continue iteration from the break point of the
       // current while loop.
-      pidxs[tid][dim] = whileOp->getResult(o++);
+      Value pos = whileOp->getResult(o++);
+      const auto t = tid; // C++17 lambdas can't capture structured bindings.
+      llvm::for_each(reassoc, [this, t, pos](Level l) { pidxs[t][l] = pos; });
       // The coordinates are invalid now.
       coord[tid][dim] = nullptr;
       // The segment high are invalid now

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
index e34707018b942..7ae4d59c3a90c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
@@ -56,6 +56,18 @@ module {
     return %ret1 :  tensor<?x?x?xf32>
+  func.func @test_sparse_all_2(%arg0: tensor<5x6xf32, #COO_2D>, %arg1: tensor<2x3x6xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+    // Collapse the first two levels this time, as those are the levels that require co-iteration.
+    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<2x3x6xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
+    %0 = tensor.empty() : tensor<5x6xf32>
+    %cst = arith.constant 0.000000e+00 : f32
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
+    %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
+    %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+    %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+    return %ret1 : tensor<?x?x?xf32>
+  }
   func.func @entry() {
     // Setup two sparse vectors.
@@ -68,9 +80,12 @@ module {
       [ [0, 0, 0], [1, 1, 1], [2, 1, 1] ],
         [     6.0,       7.0,      8.0]
     > : tensor<6x2x3xf32>
+    %shape = arith.constant dense<[2, 3, 6]> : tensor<3xi32>
+    %d3 = tensor.reshape %d2(%shape): (tensor<6x2x3xf32>, tensor<3xi32>) -> tensor<2x3x6xf32>
     %s1 = sparse_tensor.convert %d1 : tensor<5x6xf32> to tensor<5x6xf32, #COO_2D>
     %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D>
+    %s3 = sparse_tensor.convert %d3 : tensor<2x3x6xf32> to tensor<2x3x6xf32, #COO_3D>
     //      CHECK: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
     // CHECK-NEXT:[
@@ -134,11 +149,34 @@ module {
     %so2 = call @test_sparse_all(%s1, %s2): (tensor<5x6xf32, #COO_2D>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
     call @printMemref3dF32(%so2) : (tensor<?x?x?xf32>) -> ()
+    // Expect the same results as @test_sparse_all.
+    // CHECK-NEXT: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
+    // CHECK-NEXT:[
+    // CHECK-SAME: [
+    // CHECK-SAME:  [6,    0,    0],
+    // CHECK-NEXT:  [0,    0,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    14,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    24,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    0,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    0,    0]]]
+    %so3 = call @test_sparse_all_2(%s1, %s3): (tensor<5x6xf32, #COO_2D>, tensor<2x3x6xf32, #COO_3D>) -> tensor<?x?x?xf32>
+    call @printMemref3dF32(%so3) : (tensor<?x?x?xf32>) -> ()
     bufferization.dealloc_tensor %s1 : tensor<5x6xf32, #COO_2D>
     bufferization.dealloc_tensor %s2 : tensor<6x2x3xf32, #COO_3D>
+    bufferization.dealloc_tensor %s3 : tensor<2x3x6xf32, #COO_3D>
     bufferization.dealloc_tensor %do1 : tensor<?x?x?xf32>
     bufferization.dealloc_tensor %so1 : tensor<?x?x?xf32>
     bufferization.dealloc_tensor %so2 : tensor<?x?x?xf32>
+    bufferization.dealloc_tensor %so3 : tensor<?x?x?xf32>


More information about the Mlir-commits mailing list