[Mlir-commits] [mlir] 4132bce - [mlir][sparse] Add codegen rule for the push_back operator.

llvmlistbot at llvm.org
Thu Sep 22 09:09:57 PDT 2022


Author: bixia1
Date: 2022-09-22T09:09:49-07:00
New Revision: 4132bce9e56b00cdce8928e4ea67b136c93f46a2

URL: https://github.com/llvm/llvm-project/commit/4132bce9e56b00cdce8928e4ea67b136c93f46a2
DIFF: https://github.com/llvm/llvm-project/commit/4132bce9e56b00cdce8928e4ea67b136c93f46a2.diff

LOG: [mlir][sparse] Add codegen rule for the push_back operator.
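
For reference, a rough sketch of the lowering (adapted from the new codegen
test below; SSA names are illustrative, not part of the patch):

    // Input:
    %0 = sparse_tensor.push_back %sizes, %values, %v {idx = 2 : index}
        : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>

    // Approximate output (%0 becomes %buf):
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %cap  = memref.dim %values, %c0 : memref<?xf64>     // current capacity
    %size = memref.load %sizes[%c2] : memref<?xindex>   // current size
    %full = arith.cmpi uge, %size, %cap : index
    %buf = scf.if %full -> (memref<?xf64>) {
      // Grow the buffer by doubling its capacity.
      %newcap = arith.muli %cap, %c2 : index
      %new = memref.realloc %values(%newcap) : memref<?xf64> to memref<?xf64>
      scf.yield %new : memref<?xf64>
    } else {
      scf.yield %values : memref<?xf64>
    }
    // Append the value and bump the stored size.
    memref.store %v, %buf[%size] : memref<?xf64>
    %newsize = arith.addi %size, %c1 : index
    memref.store %newsize, %sizes[%c2] : memref<?xindex>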

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134372

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ad44b56e687..7dcc81c3c3ee 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -613,6 +613,61 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
   }
 };
 
+/// Sparse codegen rule for the push_back operator.
+class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lower push_back(buffer, value) to:
+    // if (size(buffer) >= capacity(buffer))
+    //    new_capacity = capacity(buffer)*2
+    //    new_buffer = realloc(buffer, new_capacity)
+    // buffer = new_buffer
+    // store(buffer, value)
+    // size(buffer)++
+    Location loc = op->getLoc();
+    Value c0 = constantIndex(rewriter, loc, 0);
+    Value buffer = adaptor.getInBuffer();
+    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
+    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
+    Value bufferSizes = adaptor.getBufferSizes();
+    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+                                                size, capacity);
+    Value value = adaptor.getValue();
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+                                                /*else=*/true);
+    // True branch.
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    Value c2 = constantIndex(rewriter, loc, 2);
+    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+    Value newBuffer =
+        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+    rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+    // False branch.
+    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    rewriter.create<scf::YieldOp>(loc, buffer);
+
+    // Add the value to the end of the buffer.
+    rewriter.setInsertionPointAfter(ifOp);
+    buffer = ifOp.getResult(0);
+    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+
+    // Increment the size of the buffer by 1.
+    Value c1 = constantIndex(rewriter, loc, 1);
+    size = rewriter.create<arith::AddIOp>(loc, size, c1);
+    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
 /// Base class for getter-like operations, e.g., to_indices, to_pointers.
 template <typename SourceOp, typename Base>
 class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -697,6 +752,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+               SparsePushBackConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
+      typeConverter, patterns.getContext());
 }

diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 7e241de59806..207eaa1324d4 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -384,3 +384,29 @@ func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
     : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
 }
+
+// CHECK-LABEL: func @sparse_push_back(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
+//       CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+//       CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+//       CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+//       CHECK:  %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
+//       CHECK:  %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+//       CHECK:  scf.yield %[[M2]] : memref<?xf64>
+//       CHECK: } else {
+//       CHECK:  scf.yield %[[B]] : memref<?xf64>
+//       CHECK: }
+//       CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
+//       CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+//       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+//       CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
new file mode 100644
index 000000000000..ff57bfee527d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = arith.constant 0.0 : f32
+    %d1 = arith.constant 1.0 : f32
+    %d2 = arith.constant 2.0 : f32
+
+    %bufferSizes = memref.alloc(%c1) : memref<?xindex>
+    %buffer = memref.alloc(%c1) : memref<?xf32>
+
+    memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
+    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+
+    // CHECK: ( 2 )
+    %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
+    vector.print %sizeValue : vector<1xindex>
+
+    // CHECK: ( 2, 1 )
+    %bufferValue = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<2xf32>
+    vector.print %bufferValue : vector<2xf32>
+
+    // Release the buffers.
+    memref.dealloc %bufferSizes : memref<?xindex>
+    memref.dealloc %buffer3 : memref<?xf32>
+    return
+  }
+}
+