[Mlir-commits] [mlir] 1c835b5 - [mlir][sparse] Allow the push_back operator to skip capacity check and reallocation.
llvmlistbot at llvm.org
Thu Sep 29 16:38:15 PDT 2022
Author: bixia1
Date: 2022-09-29T16:38:06-07:00
New Revision: 1c835b5a8e5622713c11205154d7bbd59bb5647b
URL: https://github.com/llvm/llvm-project/commit/1c835b5a8e5622713c11205154d7bbd59bb5647b
DIFF: https://github.com/llvm/llvm-project/commit/1c835b5a8e5622713c11205154d7bbd59bb5647b.diff
LOG: [mlir][sparse] Allow the push_back operator to skip capacity check and reallocation.
Add UnitAttr `inbounds` for this purpose.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D134913
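For quick reference, here is a minimal sketch of the two forms of the op after this change, adapted from the examples in the op description and the round-trip test below; the buffer names and the `idx` value are illustrative.

```mlir
// Default form: if the buffer is full, it is reallocated before the store.
%r0 = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
    : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>

// New `inbounds` form: the caller guarantees there is room, so no capacity
// check or reallocation code is generated.
%r1 = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
    {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
```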
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0a670ec1445a3..a9f6c18d9d112 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -240,7 +240,7 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
- AnyType:$value, IndexAttr:$idx)>,
+ AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> {
string summary = "Pushes a value to the back of a given buffer";
string description = [{
@@ -250,6 +250,11 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
current buffer is full, then `inBuffer.realloc` is called before pushing the
data to the buffer. This is similar to std::vector push_back.
+ The `inbounds` attribute tells the compiler that the insertion won't go
+ beyond the current storage buffer. This allows the compiler to not generate
+ the code for the capacity check and reallocation. The typical usage is for
+ "dynamic" sparse tensors for which a capacity can be set beforehand.
+
The operation returns an SSA value for the memref. Referencing the memref
through the old SSA value after this operation is undefined behavior.
@@ -259,9 +264,14 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
%r = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
: memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
```
+
+ ```mlir
+ %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
+ {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ ```
}];
- let assemblyFormat = "$bufferSizes `,` $inBuffer `,` $value"
- " attr-dict `:` type($bufferSizes) `,`"
+ let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
+ " `,` $value attr-dict `:` type($bufferSizes) `,`"
" type($inBuffer) `,` type($value) `to`"
" type($outBuffer)";
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 7a1c6fa138063..8741f5f7d89b3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -350,6 +350,8 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
// buffer = new_buffer
// store(buffer, value)
// size(buffer)++
+ //
+ // The capacity check is skipped when the inbounds attribute is present.
Location loc = op->getLoc();
Value c0 = constantIndex(rewriter, loc, 0);
Value buffer = op.getInBuffer();
@@ -357,28 +359,34 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
Value bufferSizes = op.getBufferSizes();
Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
- Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
- size, capacity);
Value value = op.getValue();
- auto bufferType =
- MemRefType::get({ShapedType::kDynamicSize}, value.getType());
- scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
- /*else=*/true);
- // True branch.
- rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value c2 = constantIndex(rewriter, loc, 2);
- capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
- Value newBuffer =
- rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
- rewriter.create<scf::YieldOp>(loc, newBuffer);
-
- // False branch.
- rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
- rewriter.create<scf::YieldOp>(loc, buffer);
+
+ if (!op.getInbounds()) {
+ Value cond = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::uge, size, capacity);
+
+ auto bufferType =
+ MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+ /*else=*/true);
+ // True branch.
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value c2 = constantIndex(rewriter, loc, 2);
+ capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+ Value newBuffer =
+ rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+ rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+ // False branch.
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ rewriter.create<scf::YieldOp>(loc, buffer);
+
+ // Prepare for adding the value to the end of the buffer.
+ rewriter.setInsertionPointAfter(ifOp);
+ buffer = ifOp.getResult(0);
+ }
// Add the value to the end of the buffer.
- rewriter.setInsertionPointAfter(ifOp);
- buffer = ifOp.getResult(0);
rewriter.create<memref::StoreOp>(loc, value, buffer, size);
// Increment the size of the buffer by 1.
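To make the rewrite easier to follow, below is a rough MLIR sketch of the code PushBackRewriter produces for a push of %val into %buffer; it is reconstructed from the C++ above and the CHECK lines in the tests rather than copied from actual compiler output, and the value names and the use of slot 0 in %bufferSizes are illustrative.

```mlir
// Sketch of the lowering of:
//   sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index}
//     : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%capacity = memref.dim %buffer, %c0 : memref<?xf64>
%size = memref.load %bufferSizes[%c0] : memref<?xindex>

// Emitted only when the op has no `inbounds` attribute: grow the buffer
// (by doubling its capacity) when it is already full.
%full = arith.cmpi uge, %size, %capacity : index
%buf = scf.if %full -> (memref<?xf64>) {
  %newCap = arith.muli %capacity, %c2 : index
  %newBuf = memref.realloc %buffer(%newCap) : memref<?xf64> to memref<?xf64>
  scf.yield %newBuf : memref<?xf64>
} else {
  scf.yield %buffer : memref<?xf64>
}

// Emitted in both cases: store the value and bump the stored size.
// With `inbounds`, %buf above is simply the incoming %buffer.
memref.store %val, %buf[%size] : memref<?xf64>
%newSize = arith.addi %size, %c1 : index
memref.store %newSize, %bufferSizes[%c0] : memref<?xindex>
```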
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 5aef2be365667..ccf9d40f59de2 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -26,6 +26,22 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
return %0 : memref<?xf64>
}
+// CHECK-LABEL: func @sparse_push_back_inbound(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]]
+// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: return %[[B]] : memref<?xf64>
+func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+ %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ return %0 : memref<?xf64>
+}
+
// CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
// CHECK-SAME: %[[I:arg0]]: index,
// CHECK-SAME: %[[J:.*]]: index,
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index e5ffe85284c70..ee69af3d256a2 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -145,6 +145,19 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
// -----
+// CHECK-LABEL: func @sparse_push_back_inbound(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+// CHECK: return %[[D]]
+func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+ %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ return %0 : memref<?xf64>
+}
+
+// -----
+
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// CHECK-LABEL: func @sparse_expansion(