[Mlir-commits] [mlir] d45be88 - [mlir][sparse] Implement the rewrite for sparse_tensor.push_back a value n times.
llvmlistbot at llvm.org
Mon Oct 31 08:19:18 PDT 2022
Author: bixia1
Date: 2022-10-31T08:19:12-07:00
New Revision: d45be8873628ce39e76dba6f4533bf96aa9f1985
URL: https://github.com/llvm/llvm-project/commit/d45be8873628ce39e76dba6f4533bf96aa9f1985
DIFF: https://github.com/llvm/llvm-project/commit/d45be8873628ce39e76dba6f4533bf96aa9f1985.diff
LOG: [mlir][sparse] Implement the rewrite for sparse_tensor.push_back a value n times.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136654
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
Removed:
################################################################################
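For context, this change extends sparse_tensor.push_back with an optional count operand so a single op can append the same value n times. A minimal illustration of the two forms exercised by the tests below (operand names are illustrative; types follow the integration test):

    // Append one copy of %d2 (existing single-value form).
    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx = 0 : index}
        : memref<?xindex>, memref<?xf32>, f32
    // Append %n copies of %d1 (new form handled by this rewrite).
    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %n {idx = 0 : index}
        : memref<?xindex>, memref<?xf32>, f32, index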
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 09593c2dafe52..421706e171cbe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -193,6 +193,7 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
let constructor = "mlir::createSparseBufferRewritePass()";
let dependentDialects = [
"arith::ArithDialect",
+ "linalg::LinalgDialect",
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index bc09772ca6f43..929d4a4ddf1f3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -498,13 +499,17 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
using OpRewritePattern<PushBackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PushBackOp op,
PatternRewriter &rewriter) const override {
- // Rewrite push_back(buffer, value) to:
- // if (size(buffer) >= capacity(buffer))
- // new_capacity = capacity(buffer)*2
+ // Rewrite push_back(buffer, value, n) to:
+ // new_size = size(buffer) + n
+ // if (new_size > capacity(buffer))
+ // while new_size > new_capacity
+ // new_capacity = new_capacity*2
// new_buffer = realloc(buffer, new_capacity)
// buffer = new_buffer
- // store(buffer, value)
- // size(buffer)++
+ // subBuffer = subviewof(buffer)
+ // linalg.fill subBuffer value
+ //
+ // size(buffer) += n
//
// The capacity check is skipped when the attribute inbounds is present.
Location loc = op->getLoc();
@@ -516,18 +521,50 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
Value value = op.getValue();
+ Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
+ Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
+ auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+ bool nIsOne = (nValue && nValue.value() == 1);
+
if (!op.getInbounds()) {
Value cond = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::uge, size, capacity);
+ loc, arith::CmpIPredicate::ugt, newSize, capacity);
+ Value c2 = constantIndex(rewriter, loc, 2);
auto bufferType =
MemRefType::get({ShapedType::kDynamicSize}, value.getType());
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
/*else=*/true);
// True branch.
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value c2 = constantIndex(rewriter, loc, 2);
- capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+ if (nIsOne) {
+ capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+ } else {
+ // Use a do-while loop to calculate the new capacity as follows:
+ // do { new_capacity *= 2 } while (new_size > new_capacity)
+ scf::WhileOp whileOp =
+ rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
+
+ // The before-region of the WhileOp.
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
+ {capacity.getType()}, {loc});
+ rewriter.setInsertionPointToEnd(before);
+
+ capacity =
+ rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
+ cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+ newSize, capacity);
+ rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
+ // The after-region of the WhileOp.
+ Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
+ {capacity.getType()}, {loc});
+ rewriter.setInsertionPointToEnd(after);
+ rewriter.create<scf::YieldOp>(loc, after->getArguments());
+
+ rewriter.setInsertionPointAfter(whileOp);
+ capacity = whileOp.getResult(0);
+ }
+
Value newBuffer =
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
rewriter.create<scf::YieldOp>(loc, newBuffer);
@@ -542,13 +579,17 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
}
// Add the value to the end of the buffer.
- rewriter.create<memref::StoreOp>(loc, value, buffer, size);
-
- // Increment the size of the buffer by 1.
- Value c1 = constantIndex(rewriter, loc, 1);
- size = rewriter.create<arith::AddIOp>(loc, size, c1);
- rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+ if (nIsOne) {
+ rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+ } else {
+ Value subBuffer = rewriter.create<memref::SubViewOp>(
+ loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
+ /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+ rewriter.create<linalg::FillOp>(loc, value, subBuffer);
+ }
+ // Update the buffer size.
+ rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
rewriter.replaceOp(op, buffer);
return success();
}
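To make the FileCheck expectations below easier to follow, here is a rough sketch of the IR this pattern emits for the n-ary case, reconstructed from the test with illustrative value names (the subview result type is abbreviated to its strided form):

    %capacity = memref.dim %buffer, %c0 : memref<?xf64>
    %size     = memref.load %sizes[%c2] : memref<?xindex>
    %new_size = arith.addi %size, %n : index
    %cond     = arith.cmpi ugt, %new_size, %capacity : index
    %buf = scf.if %cond -> (memref<?xf64>) {
      // Double the capacity until it can hold %new_size (do-while shape).
      %new_cap = scf.while (%cap = %capacity) : (index) -> index {
        %doubled = arith.muli %cap, %c2 : index
        %again   = arith.cmpi ugt, %new_size, %doubled : index
        scf.condition(%again) %doubled : index
      } do {
      ^bb0(%c: index):
        scf.yield %c : index
      }
      %grown = memref.realloc %buffer(%new_cap) : memref<?xf64> to memref<?xf64>
      scf.yield %grown : memref<?xf64>
    } else {
      scf.yield %buffer : memref<?xf64>
    }
    // Fill the n new slots with the pushed value, then bump the stored size.
    %tail = memref.subview %buf[%size] [%n] [1]
        : memref<?xf64> to memref<?xf64, strided<[1], offset: ?>>
    linalg.fill ins(%value : f64) outs(%tail : memref<?xf64, strided<[1], offset: ?>>)
    memref.store %new_size, %sizes[%c2] : memref<?xindex>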
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 31c6ad5a0262e..114bfd874609f 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -7,19 +7,19 @@
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
+// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index
+// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
-// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
-// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+// CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]]
+// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P2]])
// CHECK: scf.yield %[[M2]] : memref<?xf64>
// CHECK: } else {
// CHECK: scf.yield %[[B]] : memref<?xf64>
// CHECK: }
-// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
-// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]]
+// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
// CHECK: return %[[M]] : memref<?xf64>
func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
%0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
@@ -28,16 +28,52 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
// -----
+// CHECK-LABEL: func @sparse_push_back_n(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64,
+// CHECK-SAME: %[[D:.*]]: index) -> memref<?xf64> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
+// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index
+// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
+// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+// CHECK: %[[P2:.*]] = scf.while (%[[I:.*]] = %[[P1]]) : (index) -> index {
+// CHECK: %[[P3:.*]] = arith.muli %[[I]], %[[C2]] : index
+// CHECK: %[[T2:.*]] = arith.cmpi ugt, %[[S2]], %[[P3]] : index
+// CHECK: scf.condition(%[[T2]]) %[[P3]] : index
+// CHECK: } do {
+// CHECK: ^bb0(%[[I2:.*]]: index):
+// CHECK: scf.yield %[[I2]] : index
+// CHECK: }
+// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P2]])
+// CHECK: scf.yield %[[M2]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.yield %[[B]] : memref<?xf64>
+// CHECK: }
+// CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1]
+// CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]]
+// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> memref<?xf64> {
+ %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+ return %0 : memref<?xf64>
+}
+
+// -----
+
// CHECK-LABEL: func @sparse_push_back_inbound(
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]]
-// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]]
+// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]]
+// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
// CHECK: return %[[B]] : memref<?xf64>
func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
%0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
index 90d1b37606956..6a88e29cafe27 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
@@ -8,6 +8,7 @@ module {
func.func @entry() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
%d0 = arith.constant 0.0 : f32
%d1 = arith.constant 1.0 : f32
%d2 = arith.constant 2.0 : f32
@@ -17,15 +18,19 @@ module {
memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
%buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
- %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
+ %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32, index
- // CHECK: ( 2 )
- %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
- vector.print %sizeValue : vector<1xindex>
+ // CHECK: 16
+ %capacity = memref.dim %buffer3, %c0 : memref<?xf32>
+ vector.print %capacity : index
- // CHECK ( 2, 1 )
- %bufferValue = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<2xf32>
- vector.print %bufferValue : vector<2xf32>
+ // CHECK: ( 11 )
+ %size = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
+ vector.print %size : vector<1xindex>
+
+ // CHECK: ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+ %values = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<11xf32>
+ vector.print %values : vector<11xf32>
// Release the buffers.
memref.dealloc %bufferSizes : memref<?xindex>
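A quick sanity check of the new expectations: the first push_back appends one element and the second appends the value ten more times, so the stored size becomes 1 + 10 = 11, which is what the ( 11 ) vector prints. The rewrite then doubles the buffer capacity until it can hold the new size; assuming the initial allocation (elided from this hunk) has a small power-of-two capacity, repeated doubling stops at 16, the first value in the sequence that is >= 11, which matches the printed capacity:

    new_size = 1 + 10 = 11
    capacity: ... -> 8 -> 16   (16 >= 11, so the doubling loop stops)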