[Mlir-commits] [mlir] 14504ca - [mlir][sparse] Extend sparse_tensor.push_back to allow push_back a value n times.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 26 09:21:08 PDT 2022


Author: bixia1
Date: 2022-10-26T09:21:03-07:00
New Revision: 14504cae9a203b56f863031e4a6a6593cdee3bd8

URL: https://github.com/llvm/llvm-project/commit/14504cae9a203b56f863031e4a6a6593cdee3bd8
DIFF: https://github.com/llvm/llvm-project/commit/14504cae9a203b56f863031e4a6a6593cdee3bd8.diff

LOG: [mlir][sparse] Extend sparse_tensor.push_back to allow push_back a value n times.

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D136653

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index de1ce485bebb8..8b8dc46297971 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -270,7 +270,8 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
      AllTypesMatch<["inBuffer", "outBuffer"]>]>,
     Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
                StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
-               AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
+               AnyType:$value, IndexAttr:$idx, Optional<Index>:$n,
+               UnitAttr:$inbounds)>,
     Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)>  {
   string summary = "Pushes a value to the back of a given buffer";
   string description = [{
@@ -280,6 +281,14 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
     current buffer is full, then `inBuffer.realloc` is called before pushing the
     data to the buffer. This is similar to std::vector push_back.
 
+    The optional input `n` specifies the number of times to repeatedly push
+    the value to the back of the buffer. When `n` is a compile-time constant,
+    its value can't be less than 1. If `n` is a runtime value that is less
+    than 1, the behavior is undefined. Although using input `n` is semantically
+    equivalent to calling push_back `n` times, it gives the compiler more
+    opportunities to optimize the memory reallocation and the filling of the
+    memory with the same value.
+
     The `inbounds` attribute tells the compiler that the insertion won't go
     beyond the current storage buffer. This allows the compiler to not generate
     the code for capacity check and reallocation. The typical usage will be for
@@ -300,10 +309,24 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
     %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
       {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
     ```
+
+    ```mlir
+    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val, %n
+      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
+    ```
   }];
   let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
-                       " `,` $value attr-dict `:` type($bufferSizes) `,`"
-                       " type($inBuffer) `,` type($value)";
+                       " `,` $value (`,` $n^ )?  attr-dict `:`"
+                       " type($bufferSizes) `,` type($inBuffer) `,`"
+                       " type($value)  (`,` type($n)^ )?";
+
+  let builders = [
+    //Build an op without input `n`.
+    OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer,
+                   "Value":$value, "APInt":$idx)>
+  ];
+
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 14313caf87cc8..f17080c247103 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -538,6 +538,22 @@ LogicalResult InsertOp::verify() {
   return success();
 }
 
+void PushBackOp::build(OpBuilder &builder, OperationState &result,
+                       Type outBuffer, Value bufferSizes, Value inBuffer,
+                       Value value, APInt idx) { // Leave the count `n` unset.
+  build(builder, result, outBuffer, bufferSizes, inBuffer, value, idx, Value{});
+}
+
+LogicalResult PushBackOp::verify() {
+  // The repeat count `n` is optional. Only an `arith.constant` index `n` can
+  // be checked statically; a runtime `n` is left unchecked.
+  Operation *nDef = getN() ? getN().getDefiningOp() : nullptr;
+  auto constN = dyn_cast_or_null<arith::ConstantIndexOp>(nDef);
+  if (!constN || constN.value() >= 1)
+    return success();
+  return emitOpError("n must be not less than 1");
+}
+
 LogicalResult CompressOp::verify() {
   RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
   if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 21aa1603a45d7..1591cb464b14c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -132,6 +132,15 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
 
 // -----
 
+func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf32>, %arg2: f32) -> memref<?xf32> {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'sparse_tensor.push_back' op n must be not less than 1}}
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 {idx = 2 : index} : memref<?xindex>, memref<?xf32>, f32, index
+  return %0 : memref<?xf32>
+}
+
+// -----
+
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
   // expected-error at +1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 3fed1566da141..e19a5ee833f83 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -171,6 +171,20 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
 
 // -----
 
+// CHECK-LABEL: func @sparse_push_back_n(
+//  CHECK-SAME: %[[SIZES:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[BUF:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[VAL:.*]]: f64,
+//  CHECK-SAME: %[[N:.*]]: index) -> memref<?xf64> {
+//       CHECK: %[[RES:.*]] = sparse_tensor.push_back %[[SIZES]], %[[BUF]], %[[VAL]], %[[N]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+//       CHECK: return %[[RES]]
+func.func @sparse_push_back_n(%sizes: memref<?xindex>, %buffer: memref<?xf64>, %val: f64, %n: index) -> memref<?xf64> {
+  %r = sparse_tensor.push_back %sizes, %buffer, %val, %n {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+  return %r : memref<?xf64>
+}
+
+// -----
+
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
 // CHECK-LABEL: func @sparse_expansion(


        


More information about the Mlir-commits mailing list