[Mlir-commits] [mlir] 654bbbd - [mlir][sparse] Move the implementation of sparse_tensor.push_back to the buffer rewriter.
llvmlistbot at llvm.org
Thu Sep 29 15:06:09 PDT 2022
Author: bixia1
Date: 2022-09-29T15:06:00-07:00
New Revision: 654bbbde55151a9617d535215bb366b4979354c7
URL: https://github.com/llvm/llvm-project/commit/654bbbde55151a9617d535215bb366b4979354c7
DIFF: https://github.com/llvm/llvm-project/commit/654bbbde55151a9617d535215bb366b4979354c7.diff
LOG: [mlir][sparse] Move the implementation of sparse_tensor.push_back to the buffer rewriter.
Reviewed By: aartbik, Peiming
Differential Revision: https://reviews.llvm.org/D134777
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 137f82f0abc08..7a1c6fa138063 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -337,6 +337,60 @@ static void createSortFunc(OpBuilder &builder, ModuleOp module,
namespace {
+/// Sparse rewriting rule for the push_back operator.
+///
+/// Expands a PushBackOp into memref/arith/scf operations:
+/// - It reads the logical size from `bufferSizes[idx]`.
+/// - It compares that size against the runtime extent of dim 0 of the buffer.
+/// - If the buffer is full, it reallocates it with twice the capacity.
+/// - It stores the value at the old size and writes `size + 1` back to
+///   `bufferSizes[idx]`.
+/// The possibly reallocated buffer replaces the op's result.
+struct PushBackRewriter : OpRewritePattern<PushBackOp> {
+public:
+ using OpRewritePattern<PushBackOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(PushBackOp op,
+ PatternRewriter &rewriter) const override {
+ // Rewrite push_back(buffer, value) to:
+ // if (size(buffer) >= capacity(buffer))
+ // new_capacity = capacity(buffer)*2
+ // new_buffer = realloc(buffer, new_capacity)
+ // buffer = new_buffer
+ // store(buffer, value)
+ // size(buffer)++
+ // Here size(buffer) is bufferSizes[idx] and capacity(buffer) is the runtime
+ // dim 0 of the buffer.
+ //
+ // NOTE(review): if capacity(buffer) is 0, doubling yields 0, so the realloc
+ // does not grow the buffer and the store below is out of bounds.
+ // TODO: confirm callers never pass an empty buffer, or grow to at least 1.
+ Location loc = op->getLoc();
+ // Capacity: the current runtime extent of dim 0 of the input buffer.
+ Value c0 = constantIndex(rewriter, loc, 0);
+ Value buffer = op.getInBuffer();
+ Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
+ // Logical size: loaded from bufferSizes at the position given by the op's
+ // `idx` attribute.
+ Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
+ Value bufferSizes = op.getBufferSizes();
+ Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+ Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+ size, capacity);
+ Value value = op.getValue();
+ // Result type of the scf.if: a dynamically sized rank-1 memref of the
+ // value's type. NOTE(review): the else branch yields `buffer` unchanged, so
+ // this assumes the input buffer already has exactly this type. Confirm
+ // that the op verifier enforces it.
+ auto bufferType =
+ MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+ /*else=*/true);
+ // True branch: the buffer is full; reallocate with twice the capacity.
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value c2 = constantIndex(rewriter, loc, 2);
+ capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+ Value newBuffer =
+ rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+ rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+ // False branch: there is room; yield the original buffer.
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ rewriter.create<scf::YieldOp>(loc, buffer);
+
+ // Add the value to the end of the buffer, i.e. at index `size` (the old
+ // logical size).
+ rewriter.setInsertionPointAfter(ifOp);
+ buffer = ifOp.getResult(0);
+ rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+
+ // Increment the size of the buffer by 1 and write it back to bufferSizes.
+ Value c1 = constantIndex(rewriter, loc, 1);
+ size = rewriter.create<arith::AddIOp>(loc, size, c1);
+ rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+
+ // Uses of the op's result now see the (possibly reallocated) buffer.
+ rewriter.replaceOp(op, buffer);
+ return success();
+ }
+};
+
/// Sparse rewriting rule for the sort operator.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
@@ -378,5 +432,5 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
//===---------------------------------------------------------------------===//
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
- patterns.add<SortRewriter>(patterns.getContext());
+ // Register the buffer rewriters for the PushBackOp and SortOp operations.
+ patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d8e5d20baaaf3..e40c8060556e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -564,61 +564,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
}
};
-/// Sparse codegen rule for the push_back operator.
-class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Lower push_back(buffer, value) to:
- // if (size(buffer) >= capacity(buffer))
- // new_capacity = capacity(buffer)*2
- // new_buffer = realloc(buffer, new_capacity)
- // buffer = new_buffer
- // store(buffer, value)
- // size(buffer)++
- Location loc = op->getLoc();
- Value c0 = constantIndex(rewriter, loc, 0);
- Value buffer = adaptor.getInBuffer();
- Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
- Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
- Value bufferSizes = adaptor.getBufferSizes();
- Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
- Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
- size, capacity);
- Value value = adaptor.getValue();
- auto bufferType =
- MemRefType::get({ShapedType::kDynamicSize}, value.getType());
- scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
- /*else=*/true);
- // True branch.
- rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value c2 = constantIndex(rewriter, loc, 2);
- capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
- Value newBuffer =
- rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
- rewriter.create<scf::YieldOp>(loc, newBuffer);
-
- // False branch.
- rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
- rewriter.create<scf::YieldOp>(loc, buffer);
-
- // Add the value to the end of the buffer.
- rewriter.setInsertionPointAfter(ifOp);
- buffer = ifOp.getResult(0);
- rewriter.create<memref::StoreOp>(loc, value, buffer, size);
-
- // Increment the size of the buffer by 1.
- Value c1 = constantIndex(rewriter, loc, 1);
- size = rewriter.create<arith::AddIOp>(loc, size, c1);
- rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
-
- rewriter.replaceOp(op, buffer);
- return success();
- }
-};
-
/// Base class for getter-like operations, e.g., to_indices, to_pointers.
template <typename SourceOp, typename Base>
class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -703,7 +648,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
- SparsePushBackConverter, SparseToPointersConverter,
- SparseToIndicesConverter, SparseToValuesConverter>(
- typeConverter, patterns.getContext());
+ SparseToPointersConverter, SparseToIndicesConverter,
+ SparseToValuesConverter>(typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 51bcc6aed73a9..b208dfeb5558b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -160,6 +160,7 @@ struct SparseTensorCodegenPass
// Most ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addLegalOp<SortOp>();
+ target.addLegalOp<PushBackOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
// provided that all sparse tensor types have been fully rewritten.
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index e40064b16ec87..5aef2be365667 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,5 +1,31 @@
// RUN: mlir-opt %s --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
+// push_back is expected to expand into the following sequence:
+// - a comparison of the size stored at index `idx` (= 2) of %arg0 against
+//   dim 0 of the buffer;
+// - a conditional realloc to twice that extent;
+// - a store at the old size;
+// - an increment of the stored size.
+// CHECK-LABEL: func @sparse_push_back(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
+// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
+// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+// CHECK: scf.yield %[[M2]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.yield %[[B]] : memref<?xf64>
+// CHECK: }
+// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
+// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+ %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ return %0 : memref<?xf64>
+}
+
// CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
// CHECK-SAME: %[[I:arg0]]: index,
// CHECK-SAME: %[[J:.*]]: index,
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 1263d3efc3c95..5bc7535cc13ed 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -425,29 +425,3 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
return
}
-
-// CHECK-LABEL: func @sparse_push_back(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
-// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
-// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
-// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
-// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
-// CHECK: scf.yield %[[M2]] : memref<?xf64>
-// CHECK: } else {
-// CHECK: scf.yield %[[B]] : memref<?xf64>
-// CHECK: }
-// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
-// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[M]] : memref<?xf64>
-func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
- %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
- return %0 : memref<?xf64>
-}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
similarity index 100%
rename from mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
rename to mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
More information about the Mlir-commits mailing list