[Mlir-commits] [mlir] e445349 - [mlir][sparse] Add rewrite rule for the sparse_tensor.out operator.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 21 12:20:59 PDT 2022

Author: bixia1
Date: 2022-10-21T12:20:53-07:00
New Revision: e445349d2cca15cf4581810ce298172a3939c453

URL: https://github.com/llvm/llvm-project/commit/e445349d2cca15cf4581810ce298172a3939c453
DIFF: https://github.com/llvm/llvm-project/commit/e445349d2cca15cf4581810ce298172a3939c453.diff

LOG: [mlir][sparse] Add rewrite rule for the sparse_tensor.out operator.

Also fix the rewrite rule for sparse_tensor.new to reflect the recent change of
the runtime C interface and to use utilities for memref.alloca.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135891




diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fc5fb767f516..b793495489db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -536,12 +536,10 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
                                   {opaqueTp}, {fileName}, EmitCInterface::Off)
-    // Allocate a buffer for storing dimension sizes and indices.
+    // Allocate a temporary buffer for storing dimension sizes and indices.
     Type indexTp = rewriter.getIndexType();
-    auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
     uint64_t rank = dstTp.getRank();
-    Value dimSizes = rewriter.create<memref::AllocOp>(
-        loc, memTp, ValueRange{constantIndex(rewriter, loc, rank)});
+    Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
     // If the result tensor has dynamic dimensions, get the dynamic sizes from
     // the sparse tensor reader.
@@ -575,26 +573,27 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
                                {indexTp}, {reader}, EmitCInterface::Off)
+    Type eltTp = dstTp.getElementType();
+    Value value = genAllocaScalar(rewriter, loc, eltTp);
     scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
-    Type eltTp = dstTp.getElementType();
     SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
     Value indices = dimSizes; // Reuse the indices memref to store indices.
-    Value value = createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
-                                 {reader, indices}, EmitCInterface::On)
-                      .getResult(0);
+    createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
+                   {reader, indices, value}, EmitCInterface::On)
+        .getResult(0);
     SmallVector<Value, 4> indicesArray;
     for (uint64_t i = 0; i < rank; i++) {
           loc, indices, constantIndex(rewriter, loc, i)));
-    rewriter.create<InsertOp>(loc, value, cooBuffer, indicesArray);
+    Value v = rewriter.create<memref::LoadOp>(loc, value);
+    rewriter.create<InsertOp>(loc, v, cooBuffer, indicesArray);
-    // Release the indices buffer and the sparse tensor reader.
-    rewriter.create<memref::DeallocOp>(loc, indices);
+    // Release the sparse tensor reader.
     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
@@ -608,6 +607,70 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
+struct OutRewriter : public OpRewritePattern<OutOp> {
+  using OpRewritePattern::OpRewritePattern;
+  /// Rewrites `sparse_tensor.out` into calls to the sparse tensor writer
+  /// runtime functions: create a writer for the destination, output the
+  /// metadata (rank, number of entries, dimension sizes), output every stored
+  /// element of the source tensor, and finally release the writer.
+  LogicalResult matchAndRewrite(OutOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Calculate NNZ.
+    Value src = op.getTensor();
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
+    // Allocate a temporary buffer for storing dimension sizes and indices.
+    auto srcTp = src.getType().cast<RankedTensorType>();
+    uint64_t rank = srcTp.getRank();
+    Type indexTp = rewriter.getIndexType();
+    Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
+    // Generate code to calculate dimension size values and store the values to
+    // the buffer. Loop indices are unsigned to match `rank` (avoids a
+    // signed/unsigned comparison).
+    SmallVector<Value, 4> dims;
+    sizesForTensor(rewriter, dims, loc, srcTp, src);
+    for (uint64_t i = 0; i < rank; i++) {
+      rewriter.create<memref::StoreOp>(loc, dims[i], dimSizes,
+                                       constantIndex(rewriter, loc, i));
+    }
+    // Create a sparse tensor writer and output meta data.
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Value writer =
+        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
+                       {op.getDest()}, EmitCInterface::Off)
+            .getResult(0);
+    Value rankValue = constantIndex(rewriter, loc, rank);
+    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
+                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
+    Value indices = dimSizes; // Reuse the dimSizes buffer for indices.
+    Type eltTp = srcTp.getElementType();
+    SmallString<18> outNextFuncName{"outSparseTensorWriterNext",
+                                    primaryTypeFunctionSuffix(eltTp)};
+    // Scalar buffer through which each element value is passed to the writer.
+    Value value = genAllocaScalar(rewriter, loc, eltTp);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+    // For each element in the source tensor, output the element. The body
+    // callback receives the element's indices in args[0..rank-1] and its
+    // value in args.back().
+    rewriter.create<ForeachOp>(
+        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          // Create body ops through the callback's `builder` rather than the
+          // outer `rewriter`, so correctness does not depend on the two
+          // sharing an insertion point.
+          for (uint64_t i = 0; i < rank; i++) {
+            builder.create<memref::StoreOp>(loc, args[i], indices,
+                                            constantIndex(builder, loc, i));
+          }
+          builder.create<memref::StoreOp>(loc, args.back(), value);
+          SmallVector<Value, 4> operands{writer, rankValue, indices, value};
+          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
+                                         EmitCInterface::On);
+          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+    // Release the writer.
+    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
+                   EmitCInterface::Off);
+    rewriter.eraseOp(op);
+    return success();
+  }
 } // namespace
@@ -624,7 +687,7 @@ void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
-    patterns.add<ConcatenateRewriter, NewRewriter,
+    patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter,

diff  --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index d77f8edba0f5..3d2a5e2f50c1 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false  | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
 #CSR = #sparse_tensor.encoding<{
   dimLevelType = ["dense", "compressed"]
@@ -10,19 +10,20 @@
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK:         %[[R:.*]] = call @createSparseTensorReader(%[[A]])
-// CHECK:         %[[DS:.*]] = memref.alloc(%[[C2]]) : memref<?xindex>
+// CHECK:         %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
 // CHECK:         call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
 // CHECK:         %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK:         %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK:         %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
 // CHECK:         %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK:         %[[VB:.*]] = memref.alloca()
 // CHECK:         scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
-// CHECK:           %[[V:.*]] = func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]])
+// CHECK:           func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
 // CHECK:           %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK:           %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK:           %[[V:.*]] = memref.load %[[VB]][]
 // CHECK:           sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
 // CHECK:         }
-// CHECK:         memref.dealloc %[[DS]]
 // CHECK:         call @delSparseTensorReader(%[[R]])
 // CHECK:         %[[R:.*]] = sparse_tensor.convert %[[T]]
 // CHECK:         bufferization.dealloc_tensor %[[T]]
@@ -32,3 +33,31 @@ func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
   return %0 : tensor<?x?xf32, #CSR>
+// CHECK-LABEL:   func.func @sparse_out(
+// CHECK-SAME:    %[[A:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>,
+// CHECK-SAME:    %[[B:.*]]: !llvm.ptr<i8>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG:     %[[C20:.*]] = arith.constant 20 : index
+// CHECK:         %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]]
+// CHECK:         %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
+// CHECK:         memref.store %[[C10]], %[[DS]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:         memref.store %[[C20]], %[[DS]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:         %[[W:.*]] = call @createSparseTensorWriter(%[[B]])
+// CHECK:         call @outSparseTensorWriterMetaData(%[[W]], %[[C2]], %[[NNZ]], %[[DS]])
+// CHECK:         %[[V:.*]] = memref.alloca() : memref<f32>
+// CHECK:         scf.for  %{{.*}} = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK:           scf.for  {{.*}} {
+// CHECK:             func.call @outSparseTensorWriterNextF32(%[[W]], %[[C2]], %[[DS]], %[[V]])
+// CHECK:           }
+// CHECK:         }
+// CHECK:         call @delSparseTensorWriter(%[[W]])
+// CHECK:         return
+// CHECK:         }
+func.func @sparse_out( %arg0: tensor<10x20xf32, #CSR>, %arg1: !llvm.ptr<i8>) -> () {
+  sparse_tensor.out %arg0, %arg1 : tensor<10x20xf32, #CSR>, !llvm.ptr<i8> 
+  return


More information about the Mlir-commits mailing list