[Mlir-commits] [mlir] e08865a - [mlir][sparse] Introducing a new sparse_tensor.foreach operator.
Peiming Liu
llvmlistbot at llvm.org
Thu Sep 22 16:49:30 PDT 2022
Author: Peiming Liu
Date: 2022-09-22T23:49:22Z
New Revision: e08865a12c16896439920f3366fdb676885502aa
URL: https://github.com/llvm/llvm-project/commit/e08865a12c16896439920f3366fdb676885502aa
DIFF: https://github.com/llvm/llvm-project/commit/e08865a12c16896439920f3366fdb676885502aa.diff
LOG: [mlir][sparse] Introducing a new sparse_tensor.foreach operator.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D134484
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a27cd02ae37b0..46f912d42bd51 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -389,7 +389,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
}
//===----------------------------------------------------------------------===//
-// Sparse Tensor Custom Linalg.Generic Operations.
+// Sparse Tensor Syntax Operations.
//===----------------------------------------------------------------------===//
def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
@@ -694,11 +694,11 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperand
}
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
- Arguments<(ins AnyType:$result)> {
+ Arguments<(ins Optional<AnyType>:$result)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
- Yields a value from within a `binary`, `unary`, `reduce`,
- or `select` block.
+ Yields a value from within a `binary`, `unary`, `reduce`,
+ `select` or `foreach` block.
Example:
@@ -712,10 +712,46 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
```
}];
+ let builders = [
+ OpBuilder<(ins),
+ [{
+ build($_builder, $_state, Value());
+ }]>
+ ];
+
let assemblyFormat = [{
$result attr-dict `:` type($result)
}];
let hasVerifier = 1;
}
+def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
+    [SingleBlockImplicitTerminator<"YieldOp">]>,
+     Arguments<(ins AnySparseTensor:$tensor)> {
+  let summary = "Iterates over non-zero elements in a sparse tensor";
+  let description = [{
+     Iterates over every non-zero element in the given sparse tensor and executes
+     the block.
+
+     For an input sparse tensor with rank n, the block must take n + 1 arguments.
+     The first n arguments must be of `index` type, together indicating the
+     current coordinates of the element being visited. The last argument must
+     have the same type as the sparse tensor's element type, representing the
+     actual value loaded from the input tensor at the given coordinates.
+
+     The block is terminated by an implicit `sparse_tensor.yield` that carries
+     no operand, so the terminator may be omitted in the custom assembly form.
+
+     Example:
+
+     ```mlir
+     sparse_tensor.foreach in %0 : tensor<?x?xf64, #DCSR> do {
+      ^bb0(%arg1: index, %arg2: index, %arg3: f64):
+        // do something...
+     }
+     ```
+  }];
+
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor) `do` $region";
+  let hasVerifier = 1;
+}
+
#endif // SPARSETENSOR_OPS
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c647b0bd0db7c..2e98eaa7561c7 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -316,7 +316,7 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
if (!yield)
return op->emitError() << regionName
<< " region must end with sparse_tensor.yield";
- if (yield.getOperand().getType() != outputType)
+ if (!yield.getResult() || yield.getResult().getType() != outputType)
return op->emitError() << regionName << " region yield type mismatch";
return success();
@@ -410,7 +410,7 @@ LogicalResult ConcatenateOp::verify() {
"Failed to concatentate tensors with rank={0} on dimension={1}.", rank,
concatDim));
- for (size_t i = 0; i < getInputs().size(); i++) {
+ for (size_t i = 0, e = getInputs().size(); i < e; i++) {
Value input = getInputs()[i];
auto inputRank = input.getType().cast<RankedTensorType>().getRank();
if (inputRank != rank)
@@ -452,6 +452,28 @@ LogicalResult ConcatenateOp::verify() {
return success();
}
+// Verifies the body block of a `sparse_tensor.foreach`: for an input tensor of
+// rank n it must take exactly n + 1 arguments, the first n of `index` type
+// (the coordinates) and the last matching the tensor's element type (the
+// value). Each check returns the emitted diagnostic so that a violation makes
+// the verifier fail instead of falling through to `success()`.
+LogicalResult ForeachOp::verify() {
+  auto t = getTensor().getType().cast<RankedTensorType>();
+  auto args = getBody()->getArguments();
+
+  // Guarantees args is non-empty, so args.back() below is safe.
+  if (static_cast<size_t>(t.getRank()) + 1 != args.size())
+    return emitError("Unmatched number of arguments in the block");
+
+  // The leading `rank` arguments carry the coordinates of the visited element.
+  for (int64_t i = 0, e = t.getRank(); i < e; i++)
+    if (!args[i].getType().isIndex())
+      return emitError(
+          llvm::formatv("Expecting Index type for argument at index {0}", i));
+
+  // The trailing argument carries the value stored at those coordinates.
+  auto elemTp = t.getElementType();
+  auto valueTp = args.back().getType();
+  if (elemTp != valueTp)
+    return emitError(
+        llvm::formatv("Unmatched element type between input tensor and "
+                      "block argument, expected:{0}, got: {1}",
+                      elemTp, valueTp));
+
+  return success();
+}
+
LogicalResult ReduceOp::verify() {
Type inputType = getX().getType();
LogicalResult regionResult = success();
@@ -487,11 +509,12 @@ LogicalResult YieldOp::verify() {
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
- isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
+ isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
+ isa<ForeachOp>(parentOp))
return success();
return emitOpError("expected parent op to be sparse_tensor unary, binary, "
- "reduce, or select");
+ "reduce, select or foreach");
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index c607dd2e77fee..af913204fabba 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -468,3 +468,36 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
return %0 : tensor<9x4xf64, #DC>
}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+// A rank-2 tensor needs 2 + 1 block arguments; declaring one extra index
+// argument must be rejected by the argument-count check.
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+ // expected-error@+1 {{Unmatched number of arguments in the block}}
+ sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+ ^bb0(%1: index, %2: index, %3: index, %v: f64) :
+ }
+ return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+// The second coordinate argument is declared as f64 instead of index, so the
+// per-coordinate type check must report argument index 1.
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+ // expected-error@+1 {{Expecting Index type for argument at index 1}}
+ sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+ ^bb0(%1: index, %2: f64, %v: f64) :
+ }
+ return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+// The value argument is f32 while the tensor's element type is f64, so the
+// element-type check must fire.
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+ // expected-error@+1 {{Unmatched element type between input tensor and block argument}}
+ sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+ ^bb0(%1: index, %2: index, %v: f32) :
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 7d32300c61837..fd4b508ad4852 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -347,3 +347,18 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
return %0 : tensor<9x4xf64, #SparseMatrix>
}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// Round-trip test: a foreach over a rank-2 tensor with two index arguments and
+// one f64 value argument must parse and re-print in its custom
+// `in ... : ... do { ... }` form. The body has no explicit terminator.
+// CHECK-LABEL: func @sparse_tensor_foreach(
+// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64
+// CHECK: sparse_tensor.foreach in %[[A0]] :
+// CHECK: ^bb0(%arg1: index, %arg2: index, %arg3: f64):
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+ sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ }
+ return
+}
More information about the Mlir-commits
mailing list