[Mlir-commits] [mlir] f7b8b00 - [mlir][sparse] fix bugs in computing the memory size when lowering the pack op.

Peiming Liu llvmlistbot at llvm.org
Thu May 25 12:19:59 PDT 2023


Author: Peiming Liu
Date: 2023-05-25T19:19:52Z
New Revision: f7b8b005ff12b2f4245aa42684d129358396d5df

URL: https://github.com/llvm/llvm-project/commit/f7b8b005ff12b2f4245aa42684d129358396d5df
DIFF: https://github.com/llvm/llvm-project/commit/f7b8b005ff12b2f4245aa42684d129358396d5df.diff

LOG: [mlir][sparse] fix bugs in computing the memory size when lowering the pack op.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151481

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7d4efa8961eb5..f6405d2a47e4f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1242,10 +1242,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
         });
 
     MutSparseTensorDescriptor desc(stt, fields);
+    Value c0 = constantIndex(rewriter, loc, 0);
     Value c1 = constantIndex(rewriter, loc, 1);
     Value c2 = constantIndex(rewriter, loc, 2);
-    Value posBack = c1; // index to the last value in the postion array
-    Value memSize = c2; // memory size for current array
+    Value posBack = c0; // index to the last value in the postion array
+    Value memSize = c1; // memory size for current array
 
     Level trailCOOStart = getCOOStart(stt.getEncoding());
     Level trailCOORank = stt.getLvlRank() - trailCOOStart;
@@ -1266,7 +1267,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
       DimLevelType dlt = stt.getLvlType(lvl);
       // Simply forwards the position index when this is a dense level.
       if (isDenseDLT(dlt)) {
-        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
+        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
         posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
         continue;
       }
@@ -1276,6 +1277,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
         if (isCompressedWithHiDLT(dlt)) {
           memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
           posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
+        } else {
+          assert(isCompressedDLT(dlt));
+          posBack = memSize;
+          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
         }
         desc.setPosMemSize(rewriter, loc, lvl, memSize);
         // The last value in position array is the memory size for next level.

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 09ba910fc3cfc..5d7305dac54cc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
 
 #COO = #sparse_tensor.encoding<{
   lvlTypes = ["compressed-nu", "singleton"],
@@ -9,25 +9,25 @@
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<6xf64>,
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<2xindex>,
 // CHECK-SAME:      %[[VAL_2:.*]]: tensor<6x2xi32>)
-// CHECK-DAG:       %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
-// CHECK-DAG:       %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
-// CHECK-DAG:       %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK-DAG:       %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
-// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK-DAG:       %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
-// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 100 : index
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  lvl_sz at 0 with %[[VAL_13]]
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_12]]
-// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_12]] : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 100 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
+// CHECK-DAG:       %[[VAL_9:.*]] = memref.collapse_shape %[[VAL_8]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK-DAG:       %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<12xi32> to memref<?xi32>
+// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK-DAG:       %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<6xf64> to memref<?xf64>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]]  lvl_sz at 0 with %[[VAL_4]]
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_3]]
+// CHECK:           %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_17]]
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_13]]
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_4]]
 // CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  val_mem_sz with %[[VAL_16]]
-// CHECK:           return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]]
+// CHECK:           return %[[VAL_7]], %[[VAL_10]], %[[VAL_12]], %[[VAL_20]]
 // CHECK:         }
 func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
                     -> tensor<100x100xf64, #COO> {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 4c541a6b61a0f..3014407a95a1a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -81,6 +81,10 @@ module {
     %s5= sparse_tensor.pack %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
                                            to tensor<10x10xf64, #SortedCOOI32>
 
+    %csr_data = arith.constant dense<
+       [  1.0,  2.0,  3.0,  4.0]
+    > : tensor<4xf64>
+
     %csr_pos32 = arith.constant dense<
        [0, 1, 3]
     > : tensor<3xi32>
@@ -88,7 +92,7 @@ module {
     %csr_index32 = arith.constant dense<
        [1, 0, 1]
     > : tensor<3xi32>
-    %csr= sparse_tensor.pack %data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
+    %csr= sparse_tensor.pack %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
                                            to tensor<2x2xf64, #CSR>
 
     %bdata = arith.constant dense<
@@ -164,6 +168,16 @@ module {
         vector.print %v: f64
      }
 
+    %d_csr = tensor.empty() : tensor<4xf64>
+    %p_csr = tensor.empty() : tensor<3xi32>
+    %i_csr = tensor.empty() : tensor<3xi32>
+    %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
+                 outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
+                 -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+
+    // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
+    %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
+    vector.print %vd_csr : vector<4xf64>
 
     // CHECK-NEXT:1
     // CHECK-NEXT:2


        


More information about the Mlir-commits mailing list