[Mlir-commits] [mlir] c3aeb3e - [mlir][sparse] Introduce sparse_tensor.storage operator to create a sparse tensor storage tuple

Peiming Liu llvmlistbot at llvm.org
Fri Sep 2 17:08:38 PDT 2022


Author: Peiming Liu
Date: 2022-09-03T00:08:29Z
New Revision: c3aeb3e644e3b6f26265f5609a6e6124ba983d06

URL: https://github.com/llvm/llvm-project/commit/c3aeb3e644e3b6f26265f5609a6e6124ba983d06
DIFF: https://github.com/llvm/llvm-project/commit/c3aeb3e644e3b6f26265f5609a6e6124ba983d06.diff

LOG: [mlir][sparse] Introduce sparse_tensor.storage operator to create a sparse tensor storage tuple

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133231

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3e1564f201cc0..9272e2e90d54a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -629,6 +629,29 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
 // sparse tensor codegen to progressively lower sparse tensors.
 //===----------------------------------------------------------------------===//
 
+def SparseTensor_StorageNewOp : SparseTensor_Op<"storage", []>,
+    Arguments<(ins Variadic<AnyType>:$inputs)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Pack a list of values into one sparse tensor storage value";
+  let description = [{
+     Pack a list of values into one sparse tensor storage value (represented as
+     a tuple) at the given index.
+
+     The result tuple elements' type should match the corresponding type in the
+     input array.
+
+     Example:
+
+     ```mlir
+     %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
+                                to tuple<memref<?xf64>, memref<?xf64>>
+     ```
+   }];
+
+  let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
     Arguments<(ins AnyTuple:$storage,
                    IndexAttr:$idx)>,

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1c76f7efe8d9b..22cf768316f98 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -486,6 +486,23 @@ LogicalResult YieldOp::verify() {
 // Sparse Tensor Storage Operation.
 //===----------------------------------------------------------------------===//
 
+LogicalResult StorageNewOp::verify() {
+  auto retTypes = getResult().getType().getTypes();
+  if (retTypes.size() != getInputs().size())
+    return emitError("The number of inputs is inconsistent with output tuple");
+
+  for (auto pair : llvm::zip(getInputs(), retTypes)) {
+    auto input = std::get<0>(pair);
+    auto retTy = std::get<1>(pair);
+
+    if (input.getType() != retTy)
+      return emitError(llvm::formatv("Type mismatch between input (type={0}) "
+                                     "and output tuple element (type={1})",
+                                     input.getType(), retTy));
+  }
+  return success();
+}
+
 LogicalResult StorageGetOp::verify() {
   uint64_t extractIdx = getIdx().getZExtValue();
   auto innerTypeArray = getStorage().getType().getTypes();

diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 805f959692365..b9555e8861a25 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -445,6 +445,26 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
 
 // -----
 
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                               tuple<memref<?xf64>, memref<?xf64>> {
+  // expected-error at +1{{The number of inputs is inconsistent with output}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>>
+}
+
+// -----
+
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                               tuple<memref<?xi64>, memref<?xf64>, f64> {
+  // expected-error at +1{{Type mismatch between}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xi64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xi64>, memref<?xf64>, f64>
+}
+
+// -----
+
 func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
   // expected-error at +1{{Out-of-bound access}}
   %0 = sparse_tensor.storage_get %arg0[3]

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 4b972778aae13..c37b4e7b53ac8 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -317,6 +317,22 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
 
 // -----
 
+
+// CHECK: func @sparse_storage_new(
+//  CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
+//  CHECK-SAME: %[[A2:.*]]: f64
+//       CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
+//       CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                               tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
 // CHECK-LABEL: func @sparse_storage_get(
 //  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
 //       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
@@ -329,7 +345,7 @@ func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -
   return %0 : memref<?xf64>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @sparse_storage_set(
 //  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,


        


More information about the Mlir-commits mailing list