[Mlir-commits] [mlir] c3aeb3e - [mlir][sparse] Introduce sparse_tensor.storage operator to create a sparse tensor storage tuple
Peiming Liu
llvmlistbot at llvm.org
Fri Sep 2 17:08:38 PDT 2022
Author: Peiming Liu
Date: 2022-09-03T00:08:29Z
New Revision: c3aeb3e644e3b6f26265f5609a6e6124ba983d06
URL: https://github.com/llvm/llvm-project/commit/c3aeb3e644e3b6f26265f5609a6e6124ba983d06
DIFF: https://github.com/llvm/llvm-project/commit/c3aeb3e644e3b6f26265f5609a6e6124ba983d06.diff
LOG: [mlir][sparse] Introduce sparse_tensor.storage operator to create a sparse tensor storage tuple
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133231
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3e1564f201cc0..9272e2e90d54a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -629,6 +629,29 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
// sparse tensor codegen to progressively lower sparse tensors.
//===----------------------------------------------------------------------===//
+def SparseTensor_StorageNewOp : SparseTensor_Op<"storage", []>,
+    Arguments<(ins Variadic<AnyType>:$inputs)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Pack a list of values into one sparse tensor storage value";
+  let description = [{
+    Packs a list of values into one sparse tensor storage value, represented
+    as a tuple whose i-th element corresponds to the i-th input.
+
+    The number of inputs must equal the number of elements in the result
+    tuple, and the type of each input must equal the type of the
+    corresponding tuple element; the op fails verification otherwise.
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
+         to tuple<memref<?xf64>, memref<?xf64>>
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
Arguments<(ins AnyTuple:$storage,
IndexAttr:$idx)>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1c76f7efe8d9b..22cf768316f98 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -486,6 +486,23 @@ LogicalResult YieldOp::verify() {
// Sparse Tensor Storage Operation.
//===----------------------------------------------------------------------===//
+// Verifies that the operands line up positionally with the result tuple:
+// there must be exactly one input per tuple element, and the i-th input's
+// type must equal the i-th tuple element type.
+LogicalResult StorageNewOp::verify() {
+  auto retTypes = getResult().getType().getTypes();
+  auto inputs = getInputs();
+  if (retTypes.size() != inputs.size())
+    return emitOpError("The number of inputs is inconsistent with output tuple")
+           << " (" << inputs.size() << " inputs vs. " << retTypes.size()
+           << " tuple elements)";
+
+  for (auto it : llvm::enumerate(llvm::zip(inputs, retTypes))) {
+    auto inputTy = std::get<0>(it.value()).getType();
+    auto retTy = std::get<1>(it.value());
+    // Report the offending position so mismatches in long tuples are easy
+    // to locate.
+    if (inputTy != retTy)
+      return emitOpError("Type mismatch between input #")
+             << it.index() << " (type=" << inputTy
+             << ") and output tuple element (type=" << retTy << ")";
+  }
+  return success();
+}
+
LogicalResult StorageGetOp::verify() {
uint64_t extractIdx = getIdx().getZExtValue();
auto innerTypeArray = getStorage().getType().getTypes();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 805f959692365..b9555e8861a25 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -445,6 +445,26 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
// -----
+// Three inputs cannot be packed into a two-element tuple.
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+    tuple<memref<?xf64>, memref<?xf64>> {
+  // expected-error@+1 {{The number of inputs is inconsistent with output}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>>
+}
+
+// -----
+
+// The first input (memref<?xf64>) does not match the first tuple element
+// (memref<?xi64>).
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+    tuple<memref<?xi64>, memref<?xf64>, f64> {
+  // expected-error@+1 {{Type mismatch between}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xi64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xi64>, memref<?xf64>, f64>
+}
+
+// -----
+
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
// expected-error at +1{{Out-of-bound access}}
%0 = sparse_tensor.storage_get %arg0[3]
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 4b972778aae13..c37b4e7b53ac8 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -317,6 +317,22 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
// -----
+
+// Round-trips a storage op that packs two memrefs and a scalar into a tuple.
+// CHECK-LABEL: func @sparse_storage_new(
+//  CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
+//  CHECK-SAME: %[[A2:.*]]: f64
+//       CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
+//       CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+    tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
// CHECK-LABEL: func @sparse_storage_get(
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
@@ -329,7 +345,7 @@ func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -
return %0 : memref<?xf64>
}
-// ----
+// -----
// CHECK-LABEL: func @sparse_storage_set(
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
More information about the Mlir-commits
mailing list