[Mlir-commits] [mlir] f7b8b00 - [mlir][sparse] fix bugs in computing the memory size when lowering pack op.
Peiming Liu
llvmlistbot at llvm.org
Thu May 25 12:19:59 PDT 2023
Author: Peiming Liu
Date: 2023-05-25T19:19:52Z
New Revision: f7b8b005ff12b2f4245aa42684d129358396d5df
URL: https://github.com/llvm/llvm-project/commit/f7b8b005ff12b2f4245aa42684d129358396d5df
DIFF: https://github.com/llvm/llvm-project/commit/f7b8b005ff12b2f4245aa42684d129358396d5df.diff
LOG: [mlir][sparse] fix bugs in computing the memory size when lowering pack op.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D151481
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7d4efa8961eb5..f6405d2a47e4f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1242,10 +1242,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
});
MutSparseTensorDescriptor desc(stt, fields);
+ Value c0 = constantIndex(rewriter, loc, 0);
Value c1 = constantIndex(rewriter, loc, 1);
Value c2 = constantIndex(rewriter, loc, 2);
- Value posBack = c1; // index to the last value in the position array
- Value memSize = c2; // memory size for current array
+ Value posBack = c0; // index to the last value in the position array
+ Value memSize = c1; // memory size for current array
Level trailCOOStart = getCOOStart(stt.getEncoding());
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
@@ -1266,7 +1267,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
DimLevelType dlt = stt.getLvlType(lvl);
// Simply forwards the position index when this is a dense level.
if (isDenseDLT(dlt)) {
- memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
+ memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
continue;
}
@@ -1276,6 +1277,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
if (isCompressedWithHiDLT(dlt)) {
memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
+ } else {
+ assert(isCompressedDLT(dlt));
+ posBack = memSize;
+ memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
}
desc.setPosMemSize(rewriter, loc, lvl, memSize);
// The last value in position array is the memory size for next level.
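
For readers skimming the hunk above, the corrected bookkeeping can be summarized outside the MLIR builder API as a simple recurrence over storage levels. The following is a minimal standalone sketch, assuming plain integer arithmetic in place of the arith ops; the helper itself, the LevelKind enum, the container types, and the positions parameter are hypothetical stand-ins rather than actual SparseTensor APIs, and the trailing-COO special case handled by the real pattern is omitted.

#include <cstdint>
#include <vector>

// Level kinds covered by the sketch (dense, compressed, and the
// "compressed with hi" / loose-compressed variant touched by the patch).
enum class LevelKind { Dense, Compressed, CompressedHi };

// Returns, per level, the number of meaningful entries in that level's
// position array (0 for dense levels, which have no position array).
std::vector<uint64_t>
posMemSizes(const std::vector<LevelKind> &lvlTypes,
            const std::vector<uint64_t> &lvlSizes,
            const std::vector<std::vector<uint64_t>> &positions) {
  std::vector<uint64_t> result(lvlTypes.size(), 0);
  uint64_t posBack = 0; // index of the last valid entry in the position array
  uint64_t memSize = 1; // number of valid entries in the current array
  for (size_t lvl = 0, e = lvlTypes.size(); lvl < e; ++lvl) {
    switch (lvlTypes[lvl]) {
    case LevelKind::Dense:
      // Dense levels have no position array; just scale the running size.
      memSize = lvlSizes[lvl] * memSize;
      posBack = memSize - 1;
      continue;
    case LevelKind::CompressedHi:
      // Loose-compressed levels keep a (lo, hi) pair per parent entry.
      memSize = memSize * 2;
      posBack = memSize - 1;
      break;
    case LevelKind::Compressed:
      // Ordinary compressed levels keep one extra leading entry.
      posBack = memSize;
      memSize = memSize + 1;
      break;
    }
    result[lvl] = memSize;
    // The last value in the position array is the memory size for the
    // next level's arrays.
    memSize = positions[lvl][posBack];
  }
  return result;
}

As a sanity check against the integration test below: for the 2x2 CSR pack with positions [0, 1, 3], the dense level gives memSize = 2 * 1 = 2, the compressed level records a position memory size of 3, and the last position entry (3) is the stored-entry count carried forward for the remaining buffers.
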
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 09ba910fc3cfc..5d7305dac54cc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
#COO = #sparse_tensor.encoding<{
lvlTypes = ["compressed-nu", "singleton"],
@@ -9,25 +9,25 @@
// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xindex>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<6x2xi32>)
-// CHECK-DAG: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
-// CHECK-DAG: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
-// CHECK-DAG: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK-DAG: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
-// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK-DAG: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] lvl_sz at 0 with %[[VAL_13]]
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_12]]
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_12]] : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 100 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
+// CHECK-DAG: %[[VAL_9:.*]] = memref.collapse_shape %[[VAL_8]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<12xi32> to memref<?xi32>
+// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK-DAG: %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] lvl_sz at 0 with %[[VAL_4]]
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_3]]
+// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_17]]
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_13]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_4]]
// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] val_mem_sz with %[[VAL_16]]
-// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]]
+// CHECK: return %[[VAL_7]], %[[VAL_10]], %[[VAL_12]], %[[VAL_20]]
// CHECK: }
func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
-> tensor<100x100xf64, #COO> {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 4c541a6b61a0f..3014407a95a1a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -81,6 +81,10 @@ module {
%s5= sparse_tensor.pack %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
to tensor<10x10xf64, #SortedCOOI32>
+ %csr_data = arith.constant dense<
+ [ 1.0, 2.0, 3.0, 4.0]
+ > : tensor<4xf64>
+
%csr_pos32 = arith.constant dense<
[0, 1, 3]
> : tensor<3xi32>
@@ -88,7 +92,7 @@ module {
%csr_index32 = arith.constant dense<
[1, 0, 1]
> : tensor<3xi32>
- %csr= sparse_tensor.pack %data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
+ %csr= sparse_tensor.pack %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
to tensor<2x2xf64, #CSR>
%bdata = arith.constant dense<
@@ -164,6 +168,16 @@ module {
vector.print %v: f64
}
+ %d_csr = tensor.empty() : tensor<4xf64>
+ %p_csr = tensor.empty() : tensor<3xi32>
+ %i_csr = tensor.empty() : tensor<3xi32>
+ %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
+ outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
+ -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+
+ // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
+ %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
+ vector.print %vd_csr : vector<4xf64>
// CHECK-NEXT:1
// CHECK-NEXT:2