[Mlir-commits] [mlir] b60cf8c - [mlir][sparse] support coiteration with fused reshape tensor

Peiming Liu llvmlistbot at llvm.org
Wed Mar 1 12:55:51 PST 2023


Author: Peiming Liu
Date: 2023-03-01T20:55:46Z
New Revision: b60cf8c972c6986ff80547e49f2943266e6b2615

URL: https://github.com/llvm/llvm-project/commit/b60cf8c972c6986ff80547e49f2943266e6b2615
DIFF: https://github.com/llvm/llvm-project/commit/b60cf8c972c6986ff80547e49f2943266e6b2615.diff

LOG: [mlir][sparse] support coiteration with fused reshape tensor

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145091
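
What this change does: when a `tensor.collapse_shape` of a COO tensor is fused
into the kernel, the loop emitter now computes the coordinate of the collapsed
dimension by linearizing the per-level coordinates (scaling by the sizes of the
trailing levels in the reassociation group) and shares the same position index
across the consecutive COO levels. That coordinate computation is factored into
a new helper, genSparseCoord, used both in enterLoopOverTensorAtDim and on the
coiteration (while-loop) path in enterCoIterationOverTensorsAtDims; a small
illustrative sketch of the linearization follows the LoopEmitter.cpp diff below.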

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb2468580a57a..28f68622e3e95 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -127,6 +127,25 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
   return add;
 }
 
+Value LoopEmitter::genSparseCoord(OpBuilder &builder, Location loc, size_t tid,
+                                  size_t l) {
+  Value c = constantIndex(builder, loc, 0);
+  auto reass = getCollapseReassociation(tid, l);
+  for (unsigned i = 0; i < reass.size(); i++) {
+    auto lvl = reass[i];
+    // A load on the indices array yields the coordinate.
+    Value ptr = idxBuffer[tid][lvl];
+    Value off = genIndexLoad(builder, loc, ptr, pidxs[tid][l]);
+    // Linearize the coordinates within the same collapse reassociation.
+    c = builder.create<arith::AddIOp>(loc, c, off);
+    if (i != reass.size() - 1) {
+      c = builder.create<arith::MulIOp>(loc, c,
+                                        this->lvlSizes[tid][reass[i + 1]]);
+    }
+  }
+  return c;
+}
+
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                          bool isSparseOut, ArrayRef<unsigned> topSort) {
   initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
@@ -383,20 +402,9 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   Value c;
   if (isSparseInput) {
     assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
-    c = constantIndex(builder, loc, 0);
-    for (unsigned i = 0; i < reass.size(); i++) {
-      auto lvl = reass[i];
-      // For COO, the pidxs are always the same across consecutive levels.
-      pidxs[tid][lvl] = iv;
-      // Generating a load on the indices array yields the coordinate.
-      Value ptr = idxBuffer[tid][lvl];
-      Value off = genIndexLoad(builder, loc, ptr, iv);
-      c = builder.create<arith::AddIOp>(loc, c, off);
-      if (i != reass.size() - 1) {
-        c = builder.create<arith::MulIOp>(loc, c,
-                                          this->lvlSizes[tid][reass[i + 1]]);
-      }
-    }
+    // For COO, the position is the same across consecutive levels.
+    llvm::for_each(reass, [this, tid, iv](int lvl) { pidxs[tid][lvl] = iv; });
+    c = genSparseCoord(builder, loc, tid, dim);
   } else {
     // Dense tensor, the coordinate is the induction variable.
     c = iv;
@@ -555,16 +563,22 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
   builder.setInsertionPointToStart(&whileOp.getBefore().front());
   Value cond;
   unsigned o = 0;
-  for (auto [tid, dim] : llvm::zip(tids, dims)) {
-    if (isCompressedDLT(dimTypes[tid][dim]) ||
-        isSingletonDLT(dimTypes[tid][dim])) {
+  for (auto [t, lvl] : llvm::zip(tids, dims)) {
+    unsigned tid = t; // Structured bindings can't be captured pre-C++20.
+    if (isCompressedDLT(dimTypes[tid][lvl]) ||
+        isSingletonDLT(dimTypes[tid][lvl])) {
       Value op1 = before->getArgument(o);
-      Value op2 = highs[tid][dim];
+      Value op2 = highs[tid][lvl];
       Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
       cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
-      // Update
-      pidxs[tid][dim] = after->getArgument(o++);
+      // Update positions
+      Value pos = after->getArgument(o++);
+      auto reass = getCollapseReassociation(tid, lvl);
+      assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+      // For COO, the position is the same across consecutive levels.
+      llvm::for_each(reass,
+                     [this, tid, pos](int lvl) { pidxs[tid][lvl] = pos; });
     }
   }
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
@@ -578,11 +592,10 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
     // Prepares for next level.
     if (isCompressedDLT(dimTypes[tid][dim]) ||
         isSingletonDLT(dimTypes[tid][dim])) {
-      Value ptr = idxBuffer[tid][dim];
-      Value s = pidxs[tid][dim];
-      Value load = genIndexLoad(builder, loc, ptr, s);
-      coord[tid][dim] = load;
+      coord[tid][dim] = genSparseCoord(builder, loc, tid, dim);
       if (isSparseSlices[tid]) {
+        Value load =
+            genIndexLoad(builder, loc, idxBuffer[tid][dim], pidxs[tid][dim]);
         auto enc = getSparseTensorEncoding(tensors[tid].getType());
         auto [trans, pred] =
             genSliceLegitPredicate(builder, loc, load, enc, dim);
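
For reference, a minimal standalone sketch of the linearization that
genSparseCoord emits (folded at compile time here rather than built as arith
ops); the names linearizeCoord, coords, and lvlSizes are hypothetical and only
illustrate the arithmetic, not the LoopEmitter API:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Linearize the per-level coordinates of one collapse-shape reassociation
  // group into a single coordinate, mirroring the loop in genSparseCoord:
  //   c = 0; for each level: c += coord; if not last: c *= size(next level).
  // For coordinates (i, j, k) with level sizes (s0, s1, s2) this yields
  // ((i * s1) + j) * s2 + k.
  static uint64_t linearizeCoord(const std::vector<uint64_t> &coords,
                                 const std::vector<uint64_t> &lvlSizes) {
    assert(coords.size() == lvlSizes.size() && !coords.empty());
    uint64_t c = 0;
    for (size_t i = 0, e = coords.size(); i < e; ++i) {
      c = c + coords[i];
      if (i != e - 1)
        c = c * lvlSizes[i + 1];
    }
    return c;
  }

  int main() {
    // Collapsing the 6x2x3 COO tensor from the test with reassociation
    // [[0], [1, 2]]: the fused coordinate for (j, k) = (1, 2) is 1*3 + 2 = 5.
    assert(linearizeCoord({1, 2}, {2, 3}) == 5);
    return 0;
  }

With the position shared across the collapsed COO levels (see the while-loop
changes above), the coiteration path compares and advances a single position
per tensor, and each level's coordinate load feeds into this linearization.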

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index b8c287140b4c6..09d4afbf5dc96 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -190,6 +190,10 @@ class LoopEmitter {
   Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim,
                    Value iv);
 
+  /// Generates instructions to compute the coordinate of tensors[tid] at
+  /// level `l` under the current loop context.
+  Value genSparseCoord(OpBuilder &builder, Location loc, size_t tid, size_t l);
+
   bool isOutputTensor(size_t tid) {
     return hasOutput && tid == tensors.size() - 1;
   }

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
index 33606b01c89b7..d163e3e980ac8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
@@ -23,7 +23,7 @@ module {
   func.func private @printMemref3dF32(%ptr : tensor<?x?x?xf32>) attributes { llvm.emit_c_interface }
   func.func private @printMemref2dF32(%ptr : tensor<?x?xf32>) attributes { llvm.emit_c_interface }
 
-  func.func @test_sparse(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+  func.func @test_sparse_rhs(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
     %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
     %0 = tensor.empty() : tensor<5x6xf32>
     %cst = arith.constant 0.000000e+00 : f32
@@ -34,6 +34,17 @@ module {
     return %ret1 : tensor<?x?x?xf32>
   }
 
+  func.func @test_sparse_all(%arg0: tensor<5x6xf32, #COO_2D>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+    %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
+    %0 = tensor.empty() : tensor<5x6xf32>
+    %cst = arith.constant 0.000000e+00 : f32
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
+    %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32, #COO_2D>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
+    %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+    %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+    return %ret1 : tensor<?x?x?xf32>
+  }
+
   func.func @test_dense(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32>) -> tensor<?x?x?xf32> {
     %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32> into tensor<6x6xf32>
     %0 = tensor.empty() : tensor<5x6xf32>
@@ -58,6 +69,9 @@ module {
         [     6.0,       7.0,      8.0]
     > : tensor<6x2x3xf32>
 
+    %s1 = sparse_tensor.convert %d1 : tensor<5x6xf32> to tensor<5x6xf32, #COO_2D>
+    %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D>
+
     //      CHECK: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
     // CHECK-NEXT:[
     // CHECK-SAME: [
@@ -96,13 +110,35 @@ module {
     // CHECK-NEXT: [
     // CHECK-SAME:  [0,    0,    0],
     // CHECK-NEXT:  [0,    0,    0]]]
-    %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D>
-    %so1 = call @test_sparse(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
+    %so1 = call @test_sparse_rhs(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
     call @printMemref3dF32(%so1) : (tensor<?x?x?xf32>) -> ()
 
+    // Same results.
+    // CHECK-NEXT: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
+    // CHECK-NEXT:[
+    // CHECK-SAME: [
+    // CHECK-SAME:  [6,    0,    0],
+    // CHECK-NEXT:  [0,    0,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    14,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    24,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    0,    0]],
+    // CHECK-NEXT: [
+    // CHECK-SAME:  [0,    0,    0],
+    // CHECK-NEXT:  [0,    0,    0]]]
+    %so2 = call @test_sparse_all(%s1, %s2): (tensor<5x6xf32, #COO_2D>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
+    call @printMemref3dF32(%so2) : (tensor<?x?x?xf32>) -> ()
+
+    bufferization.dealloc_tensor %s1 : tensor<5x6xf32, #COO_2D>
     bufferization.dealloc_tensor %s2 : tensor<6x2x3xf32, #COO_3D>
     bufferization.dealloc_tensor %do1 : tensor<?x?x?xf32>
     bufferization.dealloc_tensor %so1 : tensor<?x?x?xf32>
+    bufferization.dealloc_tensor %so2 : tensor<?x?x?xf32>
     return
   }
 }


        

