[Mlir-commits] [mlir] 1ce77b5 - [mlir][sparse] refine lexicographic insertion to any tensor
Aart Bik
llvmlistbot at llvm.org
Wed Nov 17 18:08:49 PST 2021
Author: Aart Bik
Date: 2021-11-17T18:08:42-08:00
New Revision: 1ce77b562de4d9a6fc703d39c9242f9786082ef1
URL: https://github.com/llvm/llvm-project/commit/1ce77b562de4d9a6fc703d39c9242f9786082ef1
DIFF: https://github.com/llvm/llvm-project/commit/1ce77b562de4d9a6fc703d39c9242f9786082ef1.diff
LOG: [mlir][sparse] refine lexicographic insertion to any tensor
The first version supported vectors only. With some clever "path" insertion,
we now support any d-dimensional tensor. Up next: reductions too.
Reviewed By: bixia, wrengr
Differential Revision: https://reviews.llvm.org/D114024
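To make the "path" idea in the log concrete, here is a minimal hand-written illustration (the firstDiff helper and the sample cursors are mine, not part of this commit): consecutive lexicographic insertions share an index prefix, and only the dimensions at and past the first differing index need new storage entries.

// Hand-written illustration only, not library code from this commit.
#include <cstdint>
#include <iostream>

// Returns the first dimension where the new cursor differs from the
// previous one (assumes strictly increasing lexicographic order).
static uint64_t firstDiff(const uint64_t *prev, const uint64_t *next,
                          uint64_t rank) {
  for (uint64_t d = 0; d < rank; d++)
    if (prev[d] != next[d])
      return d;
  return rank; // duplicate cursor; not expected under lexicographic insertion
}

int main() {
  uint64_t a[] = {2, 3}; // first insertion at (2, 3)
  uint64_t b[] = {2, 7}; // next insertion at (2, 7)
  // Prints 1: dimension 0 is shared, so the "path" for row 2 is reused and
  // only dimension 1 needs a new index entry.
  std::cout << firstDiff(a, b, 2) << "\n";
  return 0;
}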
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
mlir/test/Dialect/SparseTensor/sparse_out.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cdcbfc2a54adc..31d3ee520fbd7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -325,9 +325,6 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
for (auto attr : op.iterator_types())
if (isReductionIterator(attr))
return false;
- // TODO: generalize support lib beyond vectors
- if (op.iterator_types().size() != 1)
- return false;
*sparseOut = lhs;
return true;
}
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 69e678ce3657b..75664ecabd076 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -247,12 +247,9 @@ class SparseTensorStorage : public SparseTensorStorageBase {
if (tensor) {
uint64_t nnz = tensor->getElements().size();
values.reserve(nnz);
- fromCOO(tensor, sparsity, 0, nnz, 0);
- } else {
- if (allDense)
- values.resize(sz, 0);
- for (uint64_t r = 0; r < rank; r++)
- idx[r] = -1u;
+ fromCOO(tensor, 0, nnz, 0);
+ } else if (allDense) {
+ values.resize(sz, 0);
}
}
@@ -279,16 +276,26 @@ class SparseTensorStorage : public SparseTensorStorageBase {
void getValues(std::vector<V> **out) override { *out = &values; }
/// Partially specialize lexicographic insertions based on template types.
- // TODO: 1-dim tensors only for now, generalize soon
void lexInsert(uint64_t *cursor, V val) override {
- assert((idx[0] == -1u || idx[0] < cursor[0]) && "not lexicographic");
- indices[0].push_back(cursor[0]);
- values.push_back(val);
- idx[0] = cursor[0];
+ // First, wrap up pending insertion path.
+ uint64_t diff = 0;
+ uint64_t top = 0;
+ if (!values.empty()) {
+ diff = lexDiff(cursor);
+ endPath(diff + 1);
+ top = idx[diff] + 1;
+ }
+ // Then continue with insertion path.
+ insPath(cursor, diff, top, val);
}
/// Finalizes lexicographic insertions.
- void endInsert() override { pointers[0].push_back(indices[0].size()); }
+ void endInsert() override {
+ if (values.empty())
+ endDim(0);
+ else
+ endPath(0);
+ }
/// Returns this sparse tensor storage scheme as a new memory-resident
/// sparse tensor in coordinate scheme with the given dimension order.
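A note on the calling protocol implied by lexInsert above: cursors must arrive in strictly increasing lexicographic order (exactly what lexDiff asserts), and endInsert must be called once at the end to finalize the pointer arrays. The standalone snippet below (hypothetical cursor data, not part of the runtime) simply checks that precondition on a sample cursor sequence.

// Standalone sanity check of the precondition lexInsert relies on:
// insertion cursors are strictly increasing in lexicographic order.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<std::array<uint64_t, 2>> cursors = {{0, 3}, {0, 5}, {2, 1}};
  for (size_t k = 1; k < cursors.size(); k++)
    assert(cursors[k - 1] < cursors[k] && "not lexicographic");
  return 0;
}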
@@ -342,14 +349,14 @@ class SparseTensorStorage : public SparseTensorStorageBase {
/// Initializes sparse tensor storage scheme from a memory-resident sparse
/// tensor in coordinate scheme. This method prepares the pointers and
/// indices arrays under the given per-dimension dense/sparse annotations.
- void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
- uint64_t lo, uint64_t hi, uint64_t d) {
+ void fromCOO(SparseTensorCOO<V> *tensor, uint64_t lo, uint64_t hi,
+ uint64_t d) {
const std::vector<Element<V>> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
assert(d <= getRank());
if (d == getRank()) {
- assert(lo >= hi || lo < elements.size());
- values.push_back(lo < hi ? elements[lo].value : 0);
+ assert(lo < hi && hi <= elements.size());
+ values.push_back(elements[lo].value);
return;
}
// Visit all elements in this interval.
@@ -362,28 +369,28 @@ class SparseTensorStorage : public SparseTensorStorageBase {
while (seg < hi && elements[seg].indices[d] == i)
seg++;
// Handle segment in interval for sparse or dense dimension.
- if (sparsity[d] == DimLevelType::kCompressed) {
+ if (isCompressedDim(d)) {
indices[d].push_back(i);
} else {
// For dense storage we must fill in all the zero values between
// the previous element (when last we ran this for-loop) and the
// current element.
for (; full < i; full++)
- fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
+ endDim(d + 1);
full++;
}
- fromCOO(tensor, sparsity, lo, seg, d + 1);
+ fromCOO(tensor, lo, seg, d + 1);
// And move on to next segment in interval.
lo = seg;
}
// Finalize the sparse pointer structure at this dimension.
- if (sparsity[d] == DimLevelType::kCompressed) {
+ if (isCompressedDim(d)) {
pointers[d].push_back(indices[d].size());
} else {
// For dense storage we must fill in all the zero values after
// the last element.
for (uint64_t sz = sizes[d]; full < sz; full++)
- fromCOO(tensor, sparsity, 0, 0, d + 1); // pass empty
+ endDim(d + 1);
}
}
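To make the segment scan in fromCOO easier to follow, here is a small standalone sketch (hypothetical data layout; the library stores sorted Element<V> records, not plain index vectors) that partitions a sorted coordinate list at one dimension the same way; the real code then recurses into each segment at d + 1 and pads dense dimensions via endDim.

// Standalone illustration of the segment scan in fromCOO (simplified).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Sorted 2-d coordinates: rows 0, 0, 2.
  std::vector<std::vector<uint64_t>> elements = {{0, 0}, {0, 3}, {2, 1}};
  uint64_t lo = 0, hi = elements.size(), d = 0;
  while (lo < hi) {
    // Find the segment [lo, seg) sharing the same index at dimension d.
    uint64_t i = elements[lo][d];
    uint64_t seg = lo + 1;
    while (seg < hi && elements[seg][d] == i)
      seg++;
    std::cout << "index " << i << ": elements [" << lo << "," << seg << ")\n";
    lo = seg; // the real code recurses into [lo, seg) at d + 1 here
  }
  return 0;
}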
@@ -395,21 +402,83 @@ class SparseTensorStorage : public SparseTensorStorageBase {
if (d == getRank()) {
assert(pos < values.size());
tensor->add(idx, values[pos]);
- } else if (pointers[d].empty()) {
+ } else if (isCompressedDim(d)) {
+ // Sparse dimension.
+ for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
+ idx[reord[d]] = indices[d][ii];
+ toCOO(tensor, reord, ii, d + 1);
+ }
+ } else {
// Dense dimension.
for (uint64_t i = 0, sz = sizes[d], off = pos * sz; i < sz; i++) {
idx[reord[d]] = i;
toCOO(tensor, reord, off + i, d + 1);
}
+ }
+ }
+
+ /// Ends a deeper, never-before-seen dimension.
+ void endDim(uint64_t d) {
+ assert(d <= getRank());
+ if (d == getRank()) {
+ values.push_back(0);
+ } else if (isCompressedDim(d)) {
+ pointers[d].push_back(indices[d].size());
} else {
- // Sparse dimension.
- for (uint64_t ii = pointers[d][pos]; ii < pointers[d][pos + 1]; ii++) {
- idx[reord[d]] = indices[d][ii];
- toCOO(tensor, reord, ii, d + 1);
+ for (uint64_t full = 0, sz = sizes[d]; full < sz; full++)
+ endDim(d + 1);
+ }
+ }
+
+ /// Wraps up a single insertion path, inner to outer.
+ void endPath(uint64_t diff) {
+ uint64_t rank = getRank();
+ assert(diff <= rank);
+ for (uint64_t i = 0; i < rank - diff; i++) {
+ uint64_t d = rank - i - 1;
+ if (isCompressedDim(d)) {
+ pointers[d].push_back(indices[d].size());
+ } else {
+ for (uint64_t full = idx[d] + 1, sz = sizes[d]; full < sz; full++)
+ endDim(d + 1);
}
}
}
+ /// Continues a single insertion path, outer to inner.
+ void insPath(uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
+ uint64_t rank = getRank();
+ assert(diff < rank);
+ for (uint64_t d = diff; d < rank; d++) {
+ uint64_t i = cursor[d];
+ if (isCompressedDim(d)) {
+ indices[d].push_back(i);
+ } else {
+ for (uint64_t full = top; full < i; full++)
+ endDim(d + 1);
+ }
+ top = 0;
+ idx[d] = i;
+ }
+ values.push_back(val);
+ }
+
+ /// Finds the lexicographically differing dimension.
+ uint64_t lexDiff(uint64_t *cursor) {
+ for (uint64_t r = 0, rank = getRank(); r < rank; r++)
+ if (cursor[r] > idx[r])
+ return r;
+ else
+ assert(cursor[r] == idx[r] && "non-lexicographic insertion");
+ assert(0 && "duplication insertion");
+ return -1u;
+ }
+
+ /// Returns true if dimension is compressed.
+ inline bool isCompressedDim(uint64_t d) const {
+ return (!pointers[d].empty());
+ }
+
private:
std::vector<uint64_t> sizes; // per-dimension sizes
std::vector<uint64_t> rev; // "reverse" permutation
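For readers following the new helpers above, the hand-written toy below mimics the lexDiff/endPath/insPath interplay for a rank-2, all-compressed (DCSR-like) tensor and prints the resulting pointers, indices, and values. All names and the simplifications are mine, not the library's: there are no dense dimensions, so endDim and the `top` parameter of insPath are omitted, and the storage is not parameterized over pointer/index/value types.

// Hand-written toy, not the MLIR runtime: rank-2, all-compressed storage
// driven by lexicographic insertion paths.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct ToyDCSR {
  static constexpr uint64_t rank = 2;
  std::vector<uint64_t> pointers[rank] = {{0}, {0}}; // compressed dims start at 0
  std::vector<uint64_t> indices[rank];
  std::vector<double> values;
  uint64_t idx[rank] = {0, 0}; // last inserted cursor

  // First dimension where the new cursor differs from the previous one.
  uint64_t lexDiff(const uint64_t *cursor) const {
    for (uint64_t d = 0; d < rank; d++)
      if (cursor[d] > idx[d])
        return d;
      else
        assert(cursor[d] == idx[d] && "non-lexicographic insertion");
    assert(0 && "duplicate insertion");
    return -1u;
  }
  // Closes finished subtrees, inner to outer.
  void endPath(uint64_t diff) {
    for (uint64_t i = 0; i < rank - diff; i++) {
      uint64_t d = rank - i - 1;
      pointers[d].push_back(indices[d].size());
    }
  }
  // Appends the new suffix of the insertion path, outer to inner.
  void insPath(const uint64_t *cursor, uint64_t diff, double val) {
    for (uint64_t d = diff; d < rank; d++) {
      indices[d].push_back(cursor[d]);
      idx[d] = cursor[d];
    }
    values.push_back(val);
  }
  void lexInsert(const uint64_t *cursor, double val) {
    uint64_t diff = 0;
    if (!values.empty()) {
      diff = lexDiff(cursor);
      endPath(diff + 1);
    }
    insPath(cursor, diff, val);
  }
  void endInsert() { endPath(0); } // assumes at least one insertion happened
};

int main() {
  ToyDCSR t;
  uint64_t c0[] = {0, 3}, c1[] = {0, 5}, c2[] = {2, 1};
  t.lexInsert(c0, 1.0);
  t.lexInsert(c1, 2.0);
  t.lexInsert(c2, 3.0);
  t.endInsert();
  // Expected DCSR contents:
  //   pointers[0] = 0 2     indices[0] = 0 2
  //   pointers[1] = 0 2 3   indices[1] = 3 5 1
  //   values      = 1 2 3
  for (uint64_t d = 0; d < ToyDCSR::rank; d++) {
    std::cout << "pointers[" << d << "]:";
    for (uint64_t p : t.pointers[d]) std::cout << " " << p;
    std::cout << "  indices[" << d << "]:";
    for (uint64_t i : t.indices[d]) std::cout << " " << i;
    std::cout << "\n";
  }
  std::cout << "values:";
  for (double v : t.values) std::cout << " " << v;
  std::cout << "\n";
  return 0;
}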
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index e17e3e89bef10..90ba2ff4d6df0 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -11,7 +11,7 @@
dimOrdering = affine_map<(i,j) -> (i,j)>
}>
-#trait_scale = {
+#trait_scale_inpl = {
indexing_maps = [
affine_map<(i,j) -> (i,j)> // X (out)
],
@@ -44,7 +44,7 @@
// CHECK: }
func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
%c = arith.constant 2.0 : f32
- %0 = linalg.generic #trait_scale
+ %0 = linalg.generic #trait_scale_inpl
outs(%argx: tensor<32x16xf32, #DCSR>) {
^bb(%x: f32):
%1 = arith.mulf %x, %c : f32
@@ -129,3 +129,56 @@ func @sparse_simply_dynamic2(%arga: tensor<32x16xf32, #CSR>,
} -> tensor<32x16xf32, #DCSR>
return %0 : tensor<32x16xf32, #DCSR>
}
+
+#trait_scale = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) * 2.0"
+}
+
+// CHECK-LABEL: func @sparse_truly_dynamic(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.init{{\[}}%[[VAL_2]], %[[VAL_3]]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_11:.*]] = memref.alloca(%[[VAL_5]]) : memref<?xindex>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_2]] step %[[VAL_4]] {
+// CHECK: memref.store %[[VAL_12]], %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_4]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: memref.store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf32>
+// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_1]] : f32
+// CHECK: sparse_tensor.lex_insert %[[VAL_7]], %[[VAL_11]], %[[VAL_19]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_7]] hasInserts : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_20]] : tensor<10x20xf32, #sparse_tensor.encoding<{
+// CHECK: }
+func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #DCSR> {
+ %s = arith.constant 2.0 : f32
+ %d10 = arith.constant 10 : index
+ %d20 = arith.constant 20 : index
+ %xm = sparse_tensor.init [%d10, %d20] : tensor<10x20xf32, #DCSR>
+ %0 = linalg.generic #trait_scale
+ ins(%arga: tensor<10x20xf32, #CSR>)
+ outs(%xm: tensor<10x20xf32, #DCSR>) {
+ ^bb(%a: f32, %x: f32):
+ %1 = arith.mulf %a, %s : f32
+ linalg.yield %1 : f32
+ } -> tensor<10x20xf32, #DCSR>
+ return %0 : tensor<10x20xf32, #DCSR>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
new file mode 100644
index 0000000000000..318f99aa20e5e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matrix_ops.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt %s \
+// RUN: --sparsification --sparse-tensor-conversion \
+// RUN: --linalg-bufferize --convert-linalg-to-loops \
+// RUN: --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+//
+// Traits for 2-d tensor (aka matrix) operations.
+//
+#trait_scale = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A (in)
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) * 2.0"
+}
+#trait_scale_inpl = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) *= 2.0"
+}
+#trait_op = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A (in)
+ affine_map<(i,j) -> (i,j)>, // B (in)
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) OP B(i,j)"
+}
+
+module {
+ // Scales a sparse matrix into a new sparse matrix.
+ func @matrix_scale(%arga: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+ %s = arith.constant 2.0 : f64
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
+ %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
+ %xm = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+ %0 = linalg.generic #trait_scale
+ ins(%arga: tensor<?x?xf64, #DCSR>)
+ outs(%xm: tensor<?x?xf64, #DCSR>) {
+ ^bb(%a: f64, %x: f64):
+ %1 = arith.mulf %a, %s : f64
+ linalg.yield %1 : f64
+ } -> tensor<?x?xf64, #DCSR>
+ return %0 : tensor<?x?xf64, #DCSR>
+ }
+
+ // Scales a sparse matrix in place.
+ func @matrix_scale_inplace(%argx: tensor<?x?xf64, #DCSR>
+ {linalg.inplaceable = true}) -> tensor<?x?xf64, #DCSR> {
+ %s = arith.constant 2.0 : f64
+ %0 = linalg.generic #trait_scale_inpl
+ outs(%argx: tensor<?x?xf64, #DCSR>) {
+ ^bb(%x: f64):
+ %1 = arith.mulf %x, %s : f64
+ linalg.yield %1 : f64
+ } -> tensor<?x?xf64, #DCSR>
+ return %0 : tensor<?x?xf64, #DCSR>
+ }
+
+ // Adds two sparse matrices element-wise into a new sparse matrix.
+ func @matrix_add(%arga: tensor<?x?xf64, #DCSR>,
+ %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
+ %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
+ %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+ %0 = linalg.generic #trait_op
+ ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
+ outs(%xv: tensor<?x?xf64, #DCSR>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %1 = arith.addf %a, %b : f64
+ linalg.yield %1 : f64
+ } -> tensor<?x?xf64, #DCSR>
+ return %0 : tensor<?x?xf64, #DCSR>
+ }
+
+ // Multiplies two sparse matrices element-wise into a new sparse matrix.
+ func @matrix_mul(%arga: tensor<?x?xf64, #DCSR>,
+ %argb: tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSR>
+ %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #DCSR>
+ %xv = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DCSR>
+ %0 = linalg.generic #trait_op
+ ins(%arga, %argb: tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>)
+ outs(%xv: tensor<?x?xf64, #DCSR>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %1 = arith.mulf %a, %b : f64
+ linalg.yield %1 : f64
+ } -> tensor<?x?xf64, #DCSR>
+ return %0 : tensor<?x?xf64, #DCSR>
+ }
+
+ // Dump a sparse matrix.
+ func @dump(%arg0: tensor<?x?xf64, #DCSR>) {
+ %d0 = arith.constant 0.0 : f64
+ %c0 = arith.constant 0 : index
+ %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64>
+ %0 = memref.buffer_cast %dm : memref<?x?xf64>
+ %1 = vector.transfer_read %0[%c0, %c0], %d0: memref<?x?xf64>, vector<4x8xf64>
+ vector.print %1 : vector<4x8xf64>
+ memref.dealloc %0 : memref<?x?xf64>
+ return
+ }
+
+ // Driver method to call and verify matrix kernels.
+ func @entry() {
+ %c0 = arith.constant 0 : index
+ %d1 = arith.constant 1.1 : f64
+
+ // Setup sparse matrices.
+ %m1 = arith.constant sparse<
+ [ [0,0], [0,1], [1,7], [2,2], [2,4], [2,7], [3,0], [3,2], [3,3] ],
+ [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+ > : tensor<4x8xf64>
+ %m2 = arith.constant sparse<
+ [ [0,0], [0,7], [1,0], [1,6], [2,1], [2,7] ],
+ [6.0, 5.0, 4.0, 3.0, 2.0, 1.0 ]
+ > : tensor<4x8xf64>
+ %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
+ %sm2 = sparse_tensor.convert %m2 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
+
+ // Call sparse matrix kernels.
+ %0 = call @matrix_scale(%sm1)
+ : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
+ %1 = call @matrix_scale_inplace(%sm1)
+ : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
+ %2 = call @matrix_add(%sm1, %sm2)
+ : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
+ %3 = call @matrix_mul(%sm1, %sm2)
+ : (tensor<?x?xf64, #DCSR>, tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
+
+ //
+ // Verify the results.
+ //
+ // CHECK: ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
+ // CHECK-NEXT: ( ( 6, 0, 0, 0, 0, 0, 0, 5 ), ( 4, 0, 0, 0, 0, 0, 3, 0 ), ( 0, 2, 0, 0, 0, 0, 0, 1 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
+ // CHECK-NEXT: ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
+ // CHECK-NEXT: ( ( 2, 4, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 6 ), ( 0, 0, 8, 0, 10, 0, 0, 12 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
+ // CHECK-NEXT: ( ( 8, 4, 0, 0, 0, 0, 0, 5 ), ( 4, 0, 0, 0, 0, 0, 3, 6 ), ( 0, 2, 8, 0, 10, 0, 0, 13 ), ( 14, 0, 16, 18, 0, 0, 0, 0 ) )
+ // CHECK-NEXT: ( ( 12, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 12 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
+ //
+ call @dump(%sm1) : (tensor<?x?xf64, #DCSR>) -> ()
+ call @dump(%sm2) : (tensor<?x?xf64, #DCSR>) -> ()
+ call @dump(%0) : (tensor<?x?xf64, #DCSR>) -> ()
+ call @dump(%1) : (tensor<?x?xf64, #DCSR>) -> ()
+ call @dump(%2) : (tensor<?x?xf64, #DCSR>) -> ()
+ call @dump(%3) : (tensor<?x?xf64, #DCSR>) -> ()
+
+ // Release the resources.
+ sparse_tensor.release %sm1 : tensor<?x?xf64, #DCSR>
+ sparse_tensor.release %sm2 : tensor<?x?xf64, #DCSR>
+ sparse_tensor.release %0 : tensor<?x?xf64, #DCSR>
+ sparse_tensor.release %2 : tensor<?x?xf64, #DCSR>
+ sparse_tensor.release %3 : tensor<?x?xf64, #DCSR>
+ return
+ }
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
new file mode 100644
index 0000000000000..2156ada00ba31
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_tensor_ops.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s \
+// RUN: --sparsification --sparse-tensor-conversion \
+// RUN: --linalg-bufferize --convert-linalg-to-loops \
+// RUN: --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#ST1 = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed", "compressed"]}>
+#ST2 = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed", "dense"]}>
+
+//
+// Trait for 3-d tensor operation.
+//
+#trait_scale = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,j,k)>, // A (in)
+ affine_map<(i,j,k) -> (i,j,k)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel", "parallel"],
+ doc = "X(i,j,k) = A(i,j,k) * 2.0"
+}
+
+module {
+ // Scales a sparse tensor into a new sparse tensor.
+ func @tensor_scale(%arga: tensor<?x?x?xf64, #ST1>) -> tensor<?x?x?xf64, #ST2> {
+ %s = arith.constant 2.0 : f64
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xf64, #ST1>
+ %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xf64, #ST1>
+ %d2 = tensor.dim %arga, %c2 : tensor<?x?x?xf64, #ST1>
+ %xm = sparse_tensor.init [%d0, %d1, %d2] : tensor<?x?x?xf64, #ST2>
+ %0 = linalg.generic #trait_scale
+ ins(%arga: tensor<?x?x?xf64, #ST1>)
+ outs(%xm: tensor<?x?x?xf64, #ST2>) {
+ ^bb(%a: f64, %x: f64):
+ %1 = arith.mulf %a, %s : f64
+ linalg.yield %1 : f64
+ } -> tensor<?x?x?xf64, #ST2>
+ return %0 : tensor<?x?x?xf64, #ST2>
+ }
+
+ // Driver method to call and verify tensor kernel.
+ func @entry() {
+ %c0 = arith.constant 0 : index
+ %d1 = arith.constant -1.0 : f64
+
+ // Setup sparse tensor.
+ %t = arith.constant dense<
+ [ [ [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0 ] ],
+ [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ] ],
+ [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 5.0 ],
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ] ] ]> : tensor<3x4x8xf64>
+ %st = sparse_tensor.convert %t : tensor<3x4x8xf64> to tensor<?x?x?xf64, #ST1>
+
+ // Call sparse tensor kernel.
+ %0 = call @tensor_scale(%st) : (tensor<?x?x?xf64, #ST1>) -> tensor<?x?x?xf64, #ST2>
+
+ // Sanity check on stored values.
+ //
+ // CHECK: ( 1, 2, 3, 4, 5, -1, -1, -1 )
+ // CHECK-NEXT: ( 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 6, 8, 0, 0, 0, 0, 10, -1, -1, -1, -1, -1, -1, -1, -1 )
+ %m1 = sparse_tensor.values %st : tensor<?x?x?xf64, #ST1> to memref<?xf64>
+ %m2 = sparse_tensor.values %0 : tensor<?x?x?xf64, #ST2> to memref<?xf64>
+ %v1 = vector.transfer_read %m1[%c0], %d1: memref<?xf64>, vector<8xf64>
+ %v2 = vector.transfer_read %m2[%c0], %d1: memref<?xf64>, vector<32xf64>
+ vector.print %v1 : vector<8xf64>
+ vector.print %v2 : vector<32xf64>
+
+ // Release the resources.
+ sparse_tensor.release %st : tensor<?x?x?xf64, #ST1>
+ sparse_tensor.release %0 : tensor<?x?x?xf64, #ST2>
+ return
+ }
+}