[Mlir-commits] [mlir] 7ea643c - [mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor
Peiming Liu
llvmlistbot at llvm.org
Wed Aug 31 15:15:24 PDT 2022
Author: Peiming Liu
Date: 2022-08-31T22:15:15Z
New Revision: 7ea643c06d8977045d0cf79507f36d828773378c
URL: https://github.com/llvm/llvm-project/commit/7ea643c06d8977045d0cf79507f36d828773378c
DIFF: https://github.com/llvm/llvm-project/commit/7ea643c06d8977045d0cf79507f36d828773378c.diff
LOG: [mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor
Introduce new sparse_tensor.storage_get/set operations to access the memory that stores the handle of a sparse tensor. The sparse tensor storage is represented as a tuple; these operations will later be eliminated, and the tuple flattened, after sparse tensor codegen.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133049
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 39af4a846f247..25bc16fec96c7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -623,4 +623,57 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Storage Operations. These operations are used internally by
+// sparse tensor codegen to progressively lower sparse tensors.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
+    Arguments<(ins AnyTuple:$storage, IndexAttr:$idx)>,
+    Results<(outs AnyType:$result)> {
+  let summary = "Get the data stored in the sparse tensor storage at the given index";
+  let description = [{
+    Reads one element out of the tuple that represents the sparse tensor
+    storage. The element is selected by the `idx` attribute.
+
+    The result type must be identical to the tuple's element type at
+    position `idx` (checked by the verifier).
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.storage_get %arg0[0]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+    ```
+  }];
+
+  let hasVerifier = 1;
+  let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
+}
+
+def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
+    Arguments<(ins AnyTuple:$storage,
+                   AnyType:$value,
+                   IndexAttr:$idx)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Set the data stored in the sparse tensor storage at the given index";
+  let description = [{
+    Set the data stored in the sparse tensor storage (represented as a tuple)
+    at the given index. Returns a new SSA value with the corresponding element
+    updated; all other elements are unchanged.
+
+    The result type must equal the original tuple type with the element type
+    at position `idx` replaced by the type of `value` (checked by the
+    verifier).
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.storage_set %arg0[0], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64>
+      to tuple<memref<?xf64>, memref<?xf64>, f64>
+    ```
+  }];
+
+  // Printed/parsed form: `%storage[idx], %value : storage-type, value-type to result-type`.
+  let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)";
+  let hasVerifier = 1;
+}
+
+
#endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8691b94351f9f..1c76f7efe8d9b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -482,6 +482,48 @@ LogicalResult YieldOp::verify() {
"expected parent op to be sparse_tensor unary, binary, or reduce");
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Storage Operations.
+//===----------------------------------------------------------------------===//
+
+// Checks that `idx` is within the bounds of the storage tuple and that the
+// op's result type equals the tuple element type at `idx`.
+LogicalResult StorageGetOp::verify() {
+  uint64_t extractIdx = getIdx().getZExtValue();
+  auto innerTypeArray = getStorage().getType().getTypes();
+  if (extractIdx >= innerTypeArray.size())
+    return emitError(llvm::formatv(
+        "Out-of-bound access with index={0} on tuple with length={1}",
+        extractIdx, innerTypeArray.size()));
+
+  auto expectedTy = getStorage().getType().getType(extractIdx);
+  auto returnTy = getResult().getType();
+  // Argument order must follow the placeholders: {0} is the op's result type,
+  // {2} is the tuple's element type at `idx`.
+  if (expectedTy != returnTy)
+    return emitError(llvm::formatv(
+        "Type mismatch between the returning type (type={0}) and the "
+        "corresponding element type at index {1} (type={2})",
+        returnTy, extractIdx, expectedTy));
+  return success();
+}
+
+// Checks that `idx` addresses an existing tuple element and that the result
+// type is the input tuple type with that one element type replaced by the
+// type of `value`.
+LogicalResult StorageSetOp::verify() {
+  uint64_t pos = getIdx().getZExtValue();
+  auto srcElemTypes = getStorage().getType().getTypes();
+  if (pos >= srcElemTypes.size())
+    return emitError(llvm::formatv(
+        "Out-of-bound access with index = {0} on tuple with length={1}", pos,
+        srcElemTypes.size()));
+
+  // Build the tuple type expected after the update: identical to the input
+  // except at `pos`, which takes the type of the stored value.
+  SmallVector<Type, 8> updatedElemTypes(srcElemTypes.begin(),
+                                        srcElemTypes.end());
+  updatedElemTypes[pos] = getValue().getType();
+  Type expectedTy = TupleType::get(getContext(), updatedElemTypes);
+  Type returnTy = getResult().getType();
+  if (expectedTy == returnTy)
+    return success();
+
+  return emitError(
+      llvm::formatv("Type mismatch between the returning type "
+                    "(type={0}) and the expected type (type={1})",
+                    returnTy, expectedTy));
+}
+
//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index d9b48fe2240fa..805f959692365 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -443,3 +443,42 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
return %0 : tensor<9x4xf64, #DC>
}
+// -----
+
+// Index 3 is past the end of a 3-element tuple.
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  // expected-error@+1{{Out-of-bound access}}
+  %0 = sparse_tensor.storage_get %arg0[3]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to
+         memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+// Element 2 is f64, but the op claims to return a memref.
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  // expected-error@+1{{Type mismatch}}
+  %0 = sparse_tensor.storage_get %arg0[2]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to
+         memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+// Index 3 is past the end of a 3-element tuple.
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Out-of-bound access}}
+  %0 = sparse_tensor.storage_set %arg0[3], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
+// Storing a memref into the f64 slot changes the tuple type, so the declared
+// result type (unchanged tuple) does not match.
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Type mismatch}}
+  %0 = sparse_tensor.storage_set %arg0[2], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 5edc977de7c00..4b972778aae13 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -314,3 +314,34 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
return %0 : tensor<9x4xf64, #SparseMatrix>
}
+
+// -----
+
+// CHECK-LABEL: func @sparse_storage_get(
+//  CHECK-SAME:   %[[STORAGE:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
+//       CHECK:   %[[FIELD:.*]] = sparse_tensor.storage_get %[[STORAGE]][0] :
+//  CHECK-SAME:     tuple<memref<?xf64>, memref<?xf64>, f64>
+//  CHECK-SAME:     to memref<?xf64>
+//       CHECK:   return %[[FIELD]] : memref<?xf64>
+func.func @sparse_storage_get(%storage: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  %field = sparse_tensor.storage_get %storage[0]
+         : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+  return %field : memref<?xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_storage_set(
+//  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
+//  CHECK-SAME:   %[[A1:.*]]: memref<?xf64>
+//       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
+//  CHECK-SAME:     tuple<memref<?xf64>, memref<?xf64>, f64>,
+//  CHECK-SAME:     memref<?xf64>
+//  CHECK-SAME:     to tuple<memref<?xf64>, memref<?xf64>, f64>
+//       CHECK:   return %[[TMP0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage_set %arg0[0], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
More information about the Mlir-commits
mailing list