[Mlir-commits] [mlir] e08865a - [mlir][sparse] Introducing a new sparse_tensor.foreach operator.

Peiming Liu llvmlistbot at llvm.org
Thu Sep 22 16:49:30 PDT 2022


Author: Peiming Liu
Date: 2022-09-22T23:49:22Z
New Revision: e08865a12c16896439920f3366fdb676885502aa

URL: https://github.com/llvm/llvm-project/commit/e08865a12c16896439920f3366fdb676885502aa
DIFF: https://github.com/llvm/llvm-project/commit/e08865a12c16896439920f3366fdb676885502aa.diff

LOG: [mlir][sparse] Introducing a new sparse_tensor.foreach operator.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134484

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a27cd02ae37b0..46f912d42bd51 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -389,7 +389,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 }
 
 //===----------------------------------------------------------------------===//
-// Sparse Tensor Custom Linalg.Generic Operations.
+// Sparse Tensor Syntax Operations.
 //===----------------------------------------------------------------------===//
 
 def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
@@ -694,11 +694,11 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperand
 }
 
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
-    Arguments<(ins AnyType:$result)> {
+    Arguments<(ins Optional<AnyType>:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
-      Yields a value from within a `binary`, `unary`, `reduce`,
-      or `select` block.
+      Yields a value from within a `binary`, `unary`, `reduce`, 
+      `select` or `foreach` block.
 
       Example:
 
@@ -712,10 +712,46 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
       ```
   }];
 
+  let builders = [
+    OpBuilder<(ins),
+    [{
+      build($_builder, $_state, Value());
+    }]>
+  ];
+
   let assemblyFormat = [{
         $result attr-dict `:` type($result)
   }];
   let hasVerifier = 1;
 }
 
+def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
+    [SingleBlockImplicitTerminator<"YieldOp">]>,
+    Arguments<(ins AnySparseTensor:$tensor)>{
+  let summary = "Iterates over non-zero elements in a sparse tensor";
+  let description = [{
+     Iterates over every non-zero element in the given sparse tensor and executes
+     the block.
+
+     For an input sparse tensor with rank n, the block must take n + 1 arguments. The
+     first n arguments must be Index type, together indicating the current coordinates
+     of the element being visited. The last argument must have the same type as the
+     sparse tensor's element type, representing the actual value loaded from the input
+     tensor at the given coordinates.
+
+     Example:
+     
+     ```mlir
+     sparse_tensor.foreach in %0 : tensor<?x?xf64, #DCSR> do {
+      ^bb0(%arg1: index, %arg2: index, %arg3: f64):
+        do something...
+     }
+     ```
+  }];
+  
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor)  `do` $region";
+  let hasVerifier = 1;
+}
+
 #endif // SPARSETENSOR_OPS

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c647b0bd0db7c..2e98eaa7561c7 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -316,7 +316,7 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
   if (!yield)
     return op->emitError() << regionName
                            << " region must end with sparse_tensor.yield";
-  if (yield.getOperand().getType() != outputType)
+  if (!yield.getResult() || yield.getResult().getType() != outputType)
     return op->emitError() << regionName << " region yield type mismatch";
 
   return success();
@@ -410,7 +410,7 @@ LogicalResult ConcatenateOp::verify() {
         "Failed to concatentate tensors with rank={0} on dimension={1}.", rank,
         concatDim));
 
-  for (size_t i = 0; i < getInputs().size(); i++) {
+  for (size_t i = 0, e = getInputs().size(); i < e; i++) {
     Value input = getInputs()[i];
     auto inputRank = input.getType().cast<RankedTensorType>().getRank();
     if (inputRank != rank)
@@ -452,6 +452,28 @@ LogicalResult ConcatenateOp::verify() {
   return success();
 }
 
+LogicalResult ForeachOp::verify() {
+  auto t = getTensor().getType().cast<RankedTensorType>();
+  auto args = getBody()->getArguments();
+
+  if (static_cast<size_t>(t.getRank()) + 1 != args.size())
+    return emitError("Unmatched number of arguments in the block");
+
+  for (int64_t i = 0, e = t.getRank(); i < e; i++)
+    if (args[i].getType() != IndexType::get(getContext()))
+      emitError(
+          llvm::formatv("Expecting Index type for argument at index {0}", i));
+
+  auto elemTp = t.getElementType();
+  auto valueTp = args.back().getType();
+  if (elemTp != valueTp)
+    emitError(llvm::formatv("Unmatched element type between input tensor and "
+                            "block argument, expected:{0}, got: {1}",
+                            elemTp, valueTp));
+
+  return success();
+}
+
 LogicalResult ReduceOp::verify() {
   Type inputType = getX().getType();
   LogicalResult regionResult = success();
@@ -487,11 +509,12 @@ LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
+      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
+      isa<ForeachOp>(parentOp))
     return success();
 
   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
-                     "reduce, or select");
+                     "reduce, select or foreach");
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index c607dd2e77fee..af913204fabba 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -468,3 +468,36 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
          tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
   return %0 : tensor<9x4xf64, #DC>
 }
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+  // expected-error@+1 {{Unmatched number of arguments in the block}}
+  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+    ^bb0(%1: index, %2: index, %3: index, %v: f64) :
+  }
+  return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+  // expected-error@+1 {{Expecting Index type for argument at index 1}}
+  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+    ^bb0(%1: index, %2: f64, %v: f64) :
+  }
+  return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+  // expected-error@+1 {{Unmatched element type between input tensor and block argument}}
+  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+    ^bb0(%1: index, %2: index, %v: f32) :
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 7d32300c61837..fd4b508ad4852 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -347,3 +347,18 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
          tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
   return %0 : tensor<9x4xf64, #SparseMatrix>
 }
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_tensor_foreach(
+//  CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64
+//       CHECK: sparse_tensor.foreach in %[[A0]] : 
+//       CHECK:  ^bb0(%arg1: index, %arg2: index, %arg3: f64):
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+  }
+  return
+}


        


More information about the Mlir-commits mailing list