[Mlir-commits] [mlir] 53d5d34 - [mlir][sparse] extend foreach operation to accept reduction arguments.

Peiming Liu llvmlistbot at llvm.org
Fri Nov 4 16:34:22 PDT 2022


Author: Peiming Liu
Date: 2022-11-04T23:34:16Z
New Revision: 53d5d3401120f2aa741a73a5a9ba0ce012ca532c

URL: https://github.com/llvm/llvm-project/commit/53d5d3401120f2aa741a73a5a9ba0ce012ca532c
DIFF: https://github.com/llvm/llvm-project/commit/53d5d3401120f2aa741a73a5a9ba0ce012ca532c.diff

LOG: [mlir][sparse] extend foreach operation to accept reduction arguments.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137463

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 8b8dc46297971..a22dcce4298ef 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -857,21 +857,44 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
     [SingleBlockImplicitTerminator<"YieldOp">]>,
-    Arguments<(ins AnyTensor:$tensor)>{
+     Arguments<(ins AnyTensor:$tensor,
+                    Variadic<AnyType>:$initArgs)>,
+     Results<(outs Variadic<AnyType>:$results)> {
   let summary = "Iterates over elements in a tensor";
   let description = [{
      Iterates over stored elements in a tensor (which are typically, but not always,
      non-zero for sparse tensors) and executes the block.
 
-     For an input tensor with rank n, the block must take n + 1 arguments. The
-     first n arguments must be Index type, together indicating the current coordinates
-     of the element being visited. The last argument must have the same type as the
+     For an input tensor with rank n, the block must take n + 1 arguments (plus one
+     additional argument per loop-carried variable, as described below). The first n
+     arguments must be of Index type, together indicating the current coordinates of
+     the element being visited. The (n + 1)-th argument must have the same type as the
      tensor's element type, representing the actual value loaded from the input
      tensor at the given coordinates.
 
-     Note that foreach generated loop iterates over the stored elements in the storage
-     order. However, no matter what storage order is used, the indices passed to the block
-     always obey the original dimension order.
+     `sparse_tensor.foreach` can also operate on loop-carried variables and return
+     their final values after loop termination. The initial values of the variables
+     are passed as additional SSA operands following the input tensor; the current
+     values are passed to the block after the n + 1 arguments (n coordinates, 1 value).
+
+     The region must terminate with a "sparse_tensor.yield" that passes the current
+     values of all loop-carried variables to the next iteration, or to the
+     result, if at the last iteration. The number and static types of loop-carried
+     variables may not change with iterations.
+
+     For example:
+     ```mlir
+     %c0 = arith.constant 0 : i32
+     %ret = sparse_tensor.foreach in %0 init(%c0): tensor<?x?xi32, #DCSR>, i32 -> i32 do {
+      ^bb0(%arg1: index, %arg2: index, %arg3: i32, %iter: i32):
+        %sum = arith.addi %iter, %arg3 : i32
+        sparse_tensor.yield %sum : i32
+     }
+     ```
+
+     It is important to note that the loop generated by foreach iterates over the stored
+     elements in storage order. However, no matter what storage order is used, the indices
+     passed to the block always obey the original dimension order.
 
      For example:
      ```mlir
@@ -879,10 +902,10 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
        dimLevelType = [ "compressed", "compressed" ],
        dimOrdering = affine_map<(i,j) -> (j,i)>
      }>
-     
+
      // foreach on a column-major sparse tensor
      sparse_tensor.foreach in %0 : tensor<2x3xf64, #COL_MAJOR> do {
-      ^bb0(%row: index, %col: index, %arg3: f64): 
+      ^bb0(%row: index, %col: index, %arg3: f64):
          // [%row, %col] -> [0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]
      }
 
@@ -892,30 +915,25 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
 
      // foreach on a row-major sparse tensor
      sparse_tensor.foreach in %0 : tensor<2x3xf64, #ROW_MAJOR> do {
-      ^bb0(%row: index, %col: index, %arg3: f64): 
+      ^bb0(%row: index, %col: index, %arg3: f64):
          // [%row, %col] -> [0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]
      }
 
      ```
-
-     Example:
-
-     ```mlir
-     sparse_tensor.foreach in %0 : tensor<?x?xf64, #DCSR> do {
-      ^bb0(%arg1: index, %arg2: index, %arg3: f64):
-        do something...
-     }
-     ```
   }];
 
   let builders = [
-    OpBuilder<(
-      ins "Value":$tensor,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>")>
+    OpBuilder<(ins "Value":$tensor,
+      "function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>")>,
+    OpBuilder<(ins "Value":$tensor,
+      "ValueRange":$iterArgs,
+      "function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>")>
   ];
 
-  let regions = (region AnyRegion:$region);
-  let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor)  `do` $region";
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = "`in` $tensor (`init``(`$initArgs^`)`)? attr-dict"
+                       "    `:` type($tensor) (`,` type($initArgs)^)?"
+                       "  (`->` type($results)^)?  `do` $region";
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 133879b12b197..4563a054ec160 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -581,11 +581,20 @@ LogicalResult CompressOp::verify() {
 
 void ForeachOp::build(
     OpBuilder &builder, OperationState &result, Value tensor,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
-  build(builder, result, tensor);
+    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
+        bodyBuilder) {
+  build(builder, result, tensor, llvm::None, bodyBuilder);
+}
+
+void ForeachOp::build(
+    OpBuilder &builder, OperationState &result, Value tensor,
+    ValueRange initArgs,
+    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
+        bodyBuilder) {
+  build(builder, result, initArgs.getTypes(), tensor, initArgs);
+  // Builds foreach body.
   if (!bodyBuilder)
     return;
-
   auto rtp = tensor.getType().cast<RankedTensorType>();
   int64_t rank = rtp.getRank();
 
@@ -602,23 +611,38 @@ void ForeachOp::build(
   auto &region = *result.regions.front();
   Block *bodyBlock =
       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
-  bodyBuilder(builder, result.location, bodyBlock->getArguments());
+  bodyBuilder(builder, result.location,
+              bodyBlock->getArguments().slice(0, rank),
+              bodyBlock->getArguments()[rank],
+              bodyBlock->getArguments().drop_front(rank + 1));
 }
 
 LogicalResult ForeachOp::verify() {
   auto t = getTensor().getType().cast<RankedTensorType>();
   auto args = getBody()->getArguments();
 
-  if (static_cast<size_t>(t.getRank()) + 1 != args.size())
+  if (static_cast<size_t>(t.getRank()) + 1 + getInitArgs().size() !=
+      args.size())
     return emitError("Unmatched number of arguments in the block");
 
+  if (getNumResults() != getInitArgs().size())
+    return emitError("Mismatch in number of init arguments and results");
+
+  if (getResultTypes() != getInitArgs().getTypes())
+    return emitError("Mismatch in types of init arguments and results");
+
+  auto yield = cast<YieldOp>(getBody()->getTerminator());
+  if (yield.getNumOperands() != getNumResults() ||
+      yield.getOperands().getTypes() != getResultTypes())
+    return emitError("Mismatch in types of yield values and results");
+
   for (int64_t i = 0, e = t.getRank(); i < e; i++)
     if (args[i].getType() != IndexType::get(getContext()))
       emitError(
           llvm::formatv("Expecting Index type for argument at index {0}", i));
 
   auto elemTp = t.getElementType();
-  auto valueTp = args.back().getType();
+  auto valueTp = args[t.getRank()].getType();
   if (elemTp != valueTp)
     emitError(llvm::formatv("Unmatched element type between input tensor and "
                             "block argument, expected:{0}, got: {1}",

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9c002f1ae0ec8..7747fd73aa9bb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -357,7 +357,9 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     auto cooBuffer =
         rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
     rewriter.create<ForeachOp>(
-        loc, srcTensor, [&](OpBuilder &builder, Location loc, ValueRange args) {
+        loc, srcTensor, llvm::None,
+        [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+            ValueRange reduc) {
           SmallVector<Value, 4> srcIndices;
           SmallVector<Value, 4> dstIndices;
           for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
@@ -366,7 +368,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           }
           translateIndicesArray(builder, loc, op.getReassociationIndices(),
                                 srcIndices, srcSizes, dstSizes, dstIndices);
-          builder.create<InsertOp>(loc, args.back(), cooBuffer, dstIndices);
+          builder.create<InsertOp>(loc, v, cooBuffer, dstIndices);
           builder.create<sparse_tensor::YieldOp>(loc);
         });
 
@@ -446,7 +448,9 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       // Build a for op for each input tensor to append new values into the
       // output tensor.
       rewriter.create<ForeachOp>(
-          loc, input, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          loc, input, llvm::None,
+          [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+              ValueRange reduc) {
             SmallVector<Value, 4> indices;
             for (int64_t i = 0; i < rank; i++) {
               uint64_t dim =
@@ -457,7 +461,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
                 idx = builder.create<arith::AddIOp>(loc, idx, offset);
               indices.push_back(idx);
             }
-            builder.create<InsertOp>(loc, args.back(), cooBuffer, indices);
+            builder.create<InsertOp>(loc, v, cooBuffer, indices);
             builder.create<sparse_tensor::YieldOp>(loc);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
@@ -558,12 +562,13 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
     sizesForTensor(rewriter, sizes, loc, srcTp, src);
     Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
 
-    rewriter.create<ForeachOp>(
-        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
-          builder.create<memref::StoreOp>(loc, args.back(), dst,
-                                          args.drop_back());
-          builder.create<sparse_tensor::YieldOp>(loc);
-        });
+    rewriter.create<ForeachOp>(loc, src, llvm::None,
+                               [&](OpBuilder &builder, Location loc,
+                                   ValueRange args, Value v, ValueRange reduc) {
+                                 builder.create<memref::StoreOp>(loc, v, dst,
+                                                                 args);
+                                 builder.create<sparse_tensor::YieldOp>(loc);
+                               });
 
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
     return success();
@@ -598,13 +603,15 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       tmpCoo =
           rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
       rewriter.create<ForeachOp>(
-          loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          loc, src, llvm::None,
+          [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+              ValueRange reduc) {
             SmallVector<Value, 4> indices;
             for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
               uint64_t dim = toStoredDim(encSrc, i);
               indices.push_back(args[dim]);
             }
-            builder.create<InsertOp>(loc, args.back(), tmpCoo, indices);
+            builder.create<InsertOp>(loc, v, tmpCoo, indices);
             builder.create<sparse_tensor::YieldOp>(loc);
           });
       src = tmpCoo;
@@ -646,16 +653,18 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
     getDynamicSizes(dstTp, srcSizes, dynDstSizes);
     Value dst =
         rewriter.create<AllocTensorOp>(loc, dstTp, dynDstSizes).getResult();
-    rewriter.create<ForeachOp>(
-        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
-          SmallVector<Value, 4> indices;
-          for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
-            uint64_t dim = toStoredDim(encDst, i);
-            indices.push_back(args[dim]);
-          }
-          builder.create<InsertOp>(loc, args.back(), dst, indices);
-          builder.create<sparse_tensor::YieldOp>(loc);
-        });
+    rewriter.create<ForeachOp>(loc, src, llvm::None,
+                               [&](OpBuilder &builder, Location loc,
+                                   ValueRange args, Value v, ValueRange reduc) {
+                                 SmallVector<Value, 4> indices;
+                                 for (int64_t i = 0, e = srcTp.getRank(); i < e;
+                                      i++) {
+                                   uint64_t dim = toStoredDim(encDst, i);
+                                   indices.push_back(args[dim]);
+                                 }
+                                 builder.create<InsertOp>(loc, v, dst, indices);
+                                 builder.create<sparse_tensor::YieldOp>(loc);
+                               });
 
     // Release the temporary COO if it is created.
     if (tmpCoo)
@@ -866,12 +875,14 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     // For each element in the source tensor, output the element.
     rewriter.create<ForeachOp>(
-        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+        loc, src, llvm::None,
+        [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+            ValueRange reduc) {
           for (uint64_t i = 0; i < rank; i++) {
             rewriter.create<memref::StoreOp>(loc, args[i], indices,
                                              constantIndex(builder, loc, i));
           }
-          rewriter.create<memref::StoreOp>(loc, args.back(), value);
+          rewriter.create<memref::StoreOp>(loc, v, value);
           SmallVector<Value, 4> operands{writer, rankValue, indices, value};
           FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                          EmitCInterface::On);

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 1ab4a66665287..407f19401b86b 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -551,6 +551,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
 
 // -----
 
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
+  // expected-error at +1 {{Unmatched element type between input tensor and block argument}}
+  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
+    ^bb0(%1: index, %2: index, %v: f32) :
+  }
+  return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+  // expected-error at +1 {{Mismatch in number of init arguments and results}}
+  sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 do {
+    ^bb0(%1: index, %2: index, %v: f32, %r1 : i32) :
+  }
+  return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+  // expected-error at +1 {{Mismatch in types of init arguments and results}}
+  %1 = sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 -> i32 do {
+    ^bb0(%1: index, %2: index, %v: f32, %r0 : f32) :
+  }
+  return
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+  // expected-error at +1 {{Mismatch in types of yield values and results}}
+  %1 = sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 -> f32 do {
+    ^bb0(%1: index, %2: index, %v: f32, %r0 : f32) :
+      sparse_tensor.yield %1 : index
+  }
+  return
+}
+
+// -----
+
 // TODO: a test case with empty xs doesn't work due to some parser issues.
 
 func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e19a5ee833f83..628ce3b4535a5 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -411,6 +411,26 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
   return
 }
 
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_tensor_foreach(
+//  CHECK-SAME:   %[[A0:.*]]: tensor<2x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, 
+//  CHECK-SAME:   %[[A1:.*]]: f32
+//  CHECK-NEXT:   %[[RET:.*]] = sparse_tensor.foreach in %[[A0]] init(%[[A1]])
+//  CHECK-NEXT:    ^bb0(%[[TMP_1:.*]]: index, %[[TMP_2:.*]]: index, %[[TMP_v:.*]]: f64, %[[TMP_r:.*]]: f32)
+//       CHECK:      sparse_tensor.yield %[[TMP_r]] : f32
+//       CHECK:  }
+func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () {
+  %ret = sparse_tensor.foreach in %arg0 init(%arg1): tensor<2x4xf64, #DCSR>, f32 -> f32
+  do {
+    ^bb0(%1: index, %2: index, %v: f64, %r: f32) : 
+      sparse_tensor.yield %r : f32
+  }
+  return
+}
+
 // ----
 
 // CHECK-LABEL: func @sparse_sort_1d0v(


        


More information about the Mlir-commits mailing list