[Mlir-commits] [mlir] 1c835b5 - [mlir][sparse] Allow the push_back operator to skip capacity check and reallocation.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 29 16:38:15 PDT 2022


Author: bixia1
Date: 2022-09-29T16:38:06-07:00
New Revision: 1c835b5a8e5622713c11205154d7bbd59bb5647b

URL: https://github.com/llvm/llvm-project/commit/1c835b5a8e5622713c11205154d7bbd59bb5647b
DIFF: https://github.com/llvm/llvm-project/commit/1c835b5a8e5622713c11205154d7bbd59bb5647b.diff

LOG: [mlir][sparse] Allow the push_back operator to skip capacity check and reallocation.

Add UnitAttr `inbounds` for this purpose.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134913

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0a670ec1445a3..a9f6c18d9d112 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -240,7 +240,7 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
 def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
     Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
                StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
-               AnyType:$value, IndexAttr:$idx)>,
+               AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
     Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)>  {
   string summary = "Pushes a value to the back of a given buffer";
   string description = [{
@@ -250,6 +250,11 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
     current buffer is full, then `inBuffer.realloc` is called before pushing the
     data to the buffer. This is similar to std::vector push_back.
 
+    The `inbounds` attribute tells the compiler that the insertion won't go
+    beyond the current storage buffer. This allows the compiler to omit the
+    code for the capacity check and reallocation. The typical usage will be for
+    "dynamic" sparse tensors for which a capacity can be set beforehand.
+
     The operation returns an SSA value for the memref. Referencing the memref
     through the old SSA value after this operation is undefined behavior.
 
@@ -259,9 +264,14 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
     %r = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
       : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
     ```
+
+    ```mlir
+    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
+      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+    ```
   }];
-  let assemblyFormat = "$bufferSizes `,` $inBuffer `,` $value"
-                       " attr-dict `:` type($bufferSizes) `,`"
+  let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
+                       " `,` $value attr-dict `:` type($bufferSizes) `,`"
                        " type($inBuffer) `,` type($value) `to`"
                        " type($outBuffer)";
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 7a1c6fa138063..8741f5f7d89b3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -350,6 +350,8 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
     // buffer = new_buffer
     // store(buffer, value)
     // size(buffer)++
+    //
+    // The capacity check is skipped when the attribute inbounds is present.
     Location loc = op->getLoc();
     Value c0 = constantIndex(rewriter, loc, 0);
     Value buffer = op.getInBuffer();
@@ -357,28 +359,34 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
     Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
     Value bufferSizes = op.getBufferSizes();
     Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
-    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
-                                                size, capacity);
     Value value = op.getValue();
-    auto bufferType =
-        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
-    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
-                                                /*else=*/true);
-    // True branch.
-    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    Value c2 = constantIndex(rewriter, loc, 2);
-    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
-    Value newBuffer =
-        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
-    rewriter.create<scf::YieldOp>(loc, newBuffer);
-
-    // False branch.
-    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    rewriter.create<scf::YieldOp>(loc, buffer);
+
+    if (!op.getInbounds()) {
+      Value cond = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::uge, size, capacity);
+
+      auto bufferType =
+          MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+                                                  /*else=*/true);
+      // True branch.
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      Value c2 = constantIndex(rewriter, loc, 2);
+      capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+      Value newBuffer =
+          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+      rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+      // False branch.
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      rewriter.create<scf::YieldOp>(loc, buffer);
+
+      // Prepare for adding the value to the end of the buffer.
+      rewriter.setInsertionPointAfter(ifOp);
+      buffer = ifOp.getResult(0);
+    }
 
     // Add the value to the end of the buffer.
-    rewriter.setInsertionPointAfter(ifOp);
-    buffer = ifOp.getResult(0);
     rewriter.create<memref::StoreOp>(loc, value, buffer, size);
 
     // Increment the size of the buffer by 1.

diff  --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 5aef2be365667..ccf9d40f59de2 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -26,6 +26,22 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_push_back_inbound(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+//       CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]]
+//       CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+//       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+//       CHECK: return %[[B]] : memref<?xf64>
+func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
 // CHECK-LABEL:   func.func private @_sparse_less_than_1_i8(
 // CHECK-SAME:                                              %[[I:arg0]]: index,
 // CHECK-SAME:                                              %[[J:.*]]: index,

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e5ffe85284c70..ee69af3d256a2 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -145,6 +145,19 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
 
 // -----
 
+// CHECK-LABEL: func @sparse_push_back_inbound(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+//       CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+//       CHECK: return %[[D]]
+func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
 // CHECK-LABEL: func @sparse_expansion(


        


More information about the Mlir-commits mailing list