[Mlir-commits] [mlir] 7ea643c - [mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor

Peiming Liu llvmlistbot at llvm.org
Wed Aug 31 15:15:24 PDT 2022


Author: Peiming Liu
Date: 2022-08-31T22:15:15Z
New Revision: 7ea643c06d8977045d0cf79507f36d828773378c

URL: https://github.com/llvm/llvm-project/commit/7ea643c06d8977045d0cf79507f36d828773378c
DIFF: https://github.com/llvm/llvm-project/commit/7ea643c06d8977045d0cf79507f36d828773378c.diff

LOG: [mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor

Introduce new sparse_tensor.storage_get/set operations to access the memory that stores the handle of a sparse tensor. The sparse tensor storage is represented as a tuple; these operations will later be eliminated and the tuple flattened after sparse tensor codegen.
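
For illustration, a minimal sketch of how the two operations compose (the SSA names %storage and %val are placeholders, and the tuple type mirrors the one used in the tests):

    // Read the element at position 0 of the storage tuple.
    %0 = sparse_tensor.storage_get %storage[0]
         : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
    // Write a new value into position 2, yielding an updated tuple as a new SSA value.
    %1 = sparse_tensor.storage_set %storage[2], %val
         : tuple<memref<?xf64>, memref<?xf64>, f64>, f64 to
           tuple<memref<?xf64>, memref<?xf64>, f64>

Once the tuple is flattened in later codegen stages, these accessors are eliminated.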

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133049

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 39af4a846f247..25bc16fec96c7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -623,4 +623,57 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Storage Operations. These operations are used internally by
+// sparse tensor codegen to progressively lower sparse tensors.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
+    Arguments<(ins AnyTuple:$storage,
+                   IndexAttr:$idx)>,
+    Results<(outs AnyType:$result)> {
+  let summary = "Get the data stored in the sparse tensor storage at the given index";
+  let description = [{
+     Get the data stored in the sparse tensor storage (represented as a tuple)
+     at the given index.
+
+     The result type should match the corresponding element type in the tuple.
+
+     Example:
+
+     ```mlir
+     %0 = sparse_tensor.storage_get %arg0[0] : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+     ```
+   }];
+
+  let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
+  let hasVerifier = 1;
+}
+
+def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
+    Arguments<(ins AnyTuple:$storage,
+                   AnyType:$value,
+                   IndexAttr:$idx)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Set the data stored in the sparse tensor storage at given index";
+  let description = [{
+     Set the data stored in the sparse tensor storage (represented as a tuple)
+     at the given index. Return a new SSA value with the corresponding element
+     updated (others remain unchanged).
+
+     The result type should match the original tuple type with only the updated
+     element type changed accordingly.
+
+     Example:
+
+     ```mlir
+     %0 = sparse_tensor.storage_set %arg0[0], %arg1 : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, f64>
+     ```
+   }];
+
+  let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)";
+  let hasVerifier = 1;
+}
+
+
 #endif // SPARSETENSOR_OPS

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8691b94351f9f..1c76f7efe8d9b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -482,6 +482,48 @@ LogicalResult YieldOp::verify() {
       "expected parent op to be sparse_tensor unary, binary, or reduce");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Storage Operations.
+//===----------------------------------------------------------------------===//
+
+LogicalResult StorageGetOp::verify() {
+  uint64_t extractIdx = getIdx().getZExtValue();
+  auto innerTypeArray = getStorage().getType().getTypes();
+  if (extractIdx >= innerTypeArray.size())
+    return emitError(llvm::formatv(
+        "Out-of-bound access with index={0} on tuple with length={1}",
+        extractIdx, innerTypeArray.size()));
+
+  auto expectedTy = getStorage().getType().getType(extractIdx);
+  auto returnTy = getResult().getType();
+  if (expectedTy != returnTy)
+    return emitError(llvm::formatv(
+        "Type mismatch between the returning type (type={0}) and the "
+        "corresponding element type at index {1} (type={2})",
+        returnTy, extractIdx, expectedTy));
+  return success();
+}
+
+LogicalResult StorageSetOp::verify() {
+  uint64_t setIdx = getIdx().getZExtValue();
+  SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
+  if (setIdx >= expectedElemTy.size())
+    return emitError(llvm::formatv(
+        "Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
+        expectedElemTy.size()));
+
+  // Updates the element type after storage_set.
+  expectedElemTy[setIdx] = getValue().getType();
+  auto expectedTy = TupleType::get(getContext(), expectedElemTy);
+  auto returnTy = getResult().getType();
+  if (expectedTy != returnTy)
+    return emitError(
+        llvm::formatv("Type mismatch between the returning type "
+                      "(type={0}) and the expected type (type={1})",
+                      returnTy, expectedTy));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index d9b48fe2240fa..805f959692365 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -443,3 +443,42 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
   return %0 : tensor<9x4xf64, #DC>
 }
 
+// -----
+
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  // expected-error@+1{{Out-of-bound access}}
+  %0 = sparse_tensor.storage_get %arg0[3]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to
+         memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  // expected-error@+1{{Type mismatch}}
+  %0 = sparse_tensor.storage_get %arg0[2]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to
+         memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Out-of-bound access}}
+  %0 = sparse_tensor.storage_set %arg0[3], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Type mismatch}}
+  %0 = sparse_tensor.storage_set %arg0[2], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 5edc977de7c00..4b972778aae13 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -314,3 +314,34 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
          tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
   return %0 : tensor<9x4xf64, #SparseMatrix>
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_storage_get(
+//  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
+//       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
+//  CHECK-SAME:     tuple<memref<?xf64>, memref<?xf64>, f64>
+//  CHECK-SAME:     to memref<?xf64>
+//       CHECK:   return %[[TMP0]] : memref<?xf64>
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  %0 = sparse_tensor.storage_get %arg0[0]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_storage_set(
+//  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
+//  CHECK-SAME:   %[[A1:.*]]: memref<?xf64>
+//       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
+//  CHECK-SAME:     tuple<memref<?xf64>, memref<?xf64>, f64>,
+//  CHECK-SAME:     memref<?xf64>
+//  CHECK-SAME:     to tuple<memref<?xf64>, memref<?xf64>, f64>
+//       CHECK:   return %[[TMP0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage_set %arg0[0], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}


        

