[Mlir-commits] [mlir] 4132bce - [mlir][sparse] Add codegen rule for the push_back operator.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 22 09:09:57 PDT 2022
Author: bixia1
Date: 2022-09-22T09:09:49-07:00
New Revision: 4132bce9e56b00cdce8928e4ea67b136c93f46a2
URL: https://github.com/llvm/llvm-project/commit/4132bce9e56b00cdce8928e4ea67b136c93f46a2
DIFF: https://github.com/llvm/llvm-project/commit/4132bce9e56b00cdce8928e4ea67b136c93f46a2.diff
LOG: [mlir][sparse] Add codegen rule for the push_back operator.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D134372
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ad44b56e687..7dcc81c3c3ee 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -613,6 +613,61 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
}
};
+/// Sparse codegen rule for the push_back operator.
+///
+/// Appends `value` to the 1-D buffer `inBuffer`, whose current element count
+/// is kept in `bufferSizes[idx]`. When the count has reached the buffer's
+/// dimension (its capacity), the buffer is first reallocated to twice that
+/// capacity. The incremented count is stored back into `bufferSizes[idx]`,
+/// and the op is replaced by the resulting (possibly reallocated) buffer.
+class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Lower push_back(buffer, value) to:
+ // if (size(buffer) >= capacity(buffer))
+ // new_capacity = capacity(buffer)*2
+ // new_buffer = realloc(buffer, new_capacity)
+ // buffer = new_buffer
+ // store(buffer, value)
+ // size(buffer)++
+ //
+ // NOTE(review): if capacity(buffer) is 0, doubling keeps it at 0 and the
+ // store below writes out of bounds. This presumably relies on buffers
+ // being allocated with a non-zero capacity -- confirm with the callers.
+ Location loc = op->getLoc();
+ Value c0 = constantIndex(rewriter, loc, 0);
+ Value buffer = adaptor.getInBuffer();
+ // The capacity is the buffer's (dynamic) dimension 0.
+ Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
+ // `idx` selects which entry of `bufferSizes` holds this buffer's size.
+ Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
+ Value bufferSizes = adaptor.getBufferSizes();
+ Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+ // Full when size >= capacity, compared as unsigned (uge) values.
+ Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+ size, capacity);
+ Value value = adaptor.getValue();
+ // Result type of the scf.if below: a 1-D dynamically sized memref whose
+ // element type is the pushed value's type.
+ auto bufferType =
+ MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+ /*else=*/true);
+ // True branch: the buffer is full; reallocate it to twice its capacity.
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value c2 = constantIndex(rewriter, loc, 2);
+ capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+ Value newBuffer =
+ rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+ rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+ // False branch: there is room; yield the buffer unchanged.
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ rewriter.create<scf::YieldOp>(loc, buffer);
+
+ // Store the value at index `size` (the old element count) of the buffer
+ // yielded by the scf.if.
+ rewriter.setInsertionPointAfter(ifOp);
+ buffer = ifOp.getResult(0);
+ rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+
+ // Write the incremented element count back to bufferSizes[idx].
+ Value c1 = constantIndex(rewriter, loc, 1);
+ size = rewriter.create<arith::AddIOp>(loc, size, c1);
+ rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+
+ // The op's result is the (possibly reallocated) buffer.
+ rewriter.replaceOp(op, buffer);
+ return success();
+ }
+};
+
/// Base class for getter-like operations, e.g., to_indices, to_pointers.
template <typename SourceOp, typename Base>
class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -697,6 +752,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToValuesConverter>(typeConverter, patterns.getContext());
+ SparsePushBackConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToValuesConverter>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 7e241de59806..207eaa1324d4 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -384,3 +384,29 @@ func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
: tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
return
}
+
+// push_back with idx = 2. The element count is loaded from %arg0[2] and
+// compared (unsigned) against dimension 0 of %arg1. When the buffer is full,
+// it is reallocated to twice that dimension inside an scf.if. The value is
+// then stored at the old count, and count + 1 is written back to %arg0[2].
+// CHECK-LABEL: func @sparse_push_back(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
+// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
+// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+// CHECK: scf.yield %[[M2]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.yield %[[B]] : memref<?xf64>
+// CHECK: }
+// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
+// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+ %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+ return %0 : memref<?xf64>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
new file mode 100644
index 000000000000..ff57bfee527d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  // Pushes 2.0 and then 1.0 onto a value buffer of capacity 1, whose element
+  // count lives in %bufferSizes[0]. The first push_back fits in place; the
+  // second finds the buffer full and must grow it before storing.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = arith.constant 0.0 : f32
+    %d1 = arith.constant 1.0 : f32
+    %d2 = arith.constant 2.0 : f32
+
+    %bufferSizes = memref.alloc(%c1) : memref<?xindex>
+    %buffer = memref.alloc(%c1) : memref<?xf32>
+
+    // Start with an empty buffer: element count bufferSizes[0] = 0.
+    memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
+    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+
+    // The element count must now be 2.
+    // CHECK: ( 2 )
+    %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
+    vector.print %sizeValue : vector<1xindex>
+
+    // Both values must be present, in push order.
+    // CHECK: ( 2, 1 )
+    %bufferValue = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<2xf32>
+    vector.print %bufferValue : vector<2xf32>
+
+    // Release the buffers. %buffer (which %buffer2 still aliases, since the
+    // first push_back did not grow it) is not freed here: the second
+    // push_back reallocated it into %buffer3.
+    memref.dealloc %bufferSizes : memref<?xindex>
+    memref.dealloc %buffer3 : memref<?xf32>
+    return
+  }
+}
+
More information about the Mlir-commits
mailing list