[Mlir-commits] [mlir] e445349 - [mlir][sparse] Add rewrite rule for the sparse_tensor.out operator.
llvmlistbot at llvm.org
Fri Oct 21 12:20:59 PDT 2022
Author: bixia1
Date: 2022-10-21T12:20:53-07:00
New Revision: e445349d2cca15cf4581810ce298172a3939c453
URL: https://github.com/llvm/llvm-project/commit/e445349d2cca15cf4581810ce298172a3939c453
DIFF: https://github.com/llvm/llvm-project/commit/e445349d2cca15cf4581810ce298172a3939c453.diff
LOG: [mlir][sparse] Add rewrite rule for the sparse_tensor.out operator.
Also fix the rewrite rule for sparse_tensor.new to reflect the recent change to
the runtime C interface and to use the memref.alloca utilities.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D135891
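For context, a minimal sketch of the reader loop the revised NewRewriter now emits,
reconstructed from the CHECK lines in the updated test below. It is illustrative only:
the function and SSA names are made up, the real pattern inserts into a temporary COO
tensor and converts it afterwards (the #CSR encoding here is simply the one used in the
test), and %reader, %nnz, %d0 and %d1 are assumed to come from the earlier
createSparseTensorReader / getSparseTensorReaderDimSizes / getSparseTensorReaderNNZ
calls that the pattern also generates. The key change is that getSparseTensorReaderNextF32
now writes the element value through a stack-allocated memref<f32> out-parameter instead
of returning it, and both scratch buffers come from memref.alloca rather than memref.alloc.

  #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

  // Declaration emitted by createFuncCall with EmitCInterface::On.
  func.func private @getSparseTensorReaderNextF32(!llvm.ptr<i8>, memref<?xindex>, memref<f32>)
      attributes { llvm.emit_c_interface }

  func.func @read_all(%reader: !llvm.ptr<i8>, %nnz: index, %d0: index, %d1: index)
      -> tensor<?x?xf32, #CSR> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dst  = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32, #CSR>
    %dims = memref.alloca(%c2) : memref<?xindex>   // stack buffer, reused below for indices
    %val  = memref.alloca() : memref<f32>          // scalar out-parameter for the reader
    scf.for %i = %c0 to %nnz step %c1 {
      // The reader stores the element value into %val instead of returning it.
      func.call @getSparseTensorReaderNextF32(%reader, %dims, %val)
          : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
      %i0 = memref.load %dims[%c0] : memref<?xindex>
      %i1 = memref.load %dims[%c1] : memref<?xindex>
      %v  = memref.load %val[] : memref<f32>
      sparse_tensor.insert %v into %dst[%i0, %i1] : tensor<?x?xf32, #CSR>
    }
    return %dst : tensor<?x?xf32, #CSR>
  }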
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fc5fb767f516..b793495489db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -536,12 +536,10 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
{opaqueTp}, {fileName}, EmitCInterface::Off)
.getResult(0);
- // Allocate a buffer for storing dimension sizes and indices.
+ // Allocate a temporary buffer for storing dimension sizes and indices.
Type indexTp = rewriter.getIndexType();
- auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
uint64_t rank = dstTp.getRank();
- Value dimSizes = rewriter.create<memref::AllocOp>(
- loc, memTp, ValueRange{constantIndex(rewriter, loc, rank)});
+ Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
// If the result tensor has dynamic dimensions, get the dynamic sizes from
// the sparse tensor reader.
@@ -575,26 +573,27 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
{indexTp}, {reader}, EmitCInterface::Off)
.getResult(0);
+ Type eltTp = dstTp.getElementType();
+ Value value = genAllocaScalar(rewriter, loc, eltTp);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
rewriter.setInsertionPointToStart(forOp.getBody());
- Type eltTp = dstTp.getElementType();
SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
primaryTypeFunctionSuffix(eltTp)};
Value indices = dimSizes; // Reuse the indices memref to store indices.
- Value value = createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
- {reader, indices}, EmitCInterface::On)
- .getResult(0);
+ createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
+ {reader, indices, value}, EmitCInterface::On)
+ .getResult(0);
SmallVector<Value, 4> indicesArray;
for (uint64_t i = 0; i < rank; i++) {
indicesArray.push_back(rewriter.create<memref::LoadOp>(
loc, indices, constantIndex(rewriter, loc, i)));
}
- rewriter.create<InsertOp>(loc, value, cooBuffer, indicesArray);
+ Value v = rewriter.create<memref::LoadOp>(loc, value);
+ rewriter.create<InsertOp>(loc, v, cooBuffer, indicesArray);
rewriter.setInsertionPointAfter(forOp);
- // Release the indices buffer and the sparse tensor reader.
- rewriter.create<memref::DeallocOp>(loc, indices);
+ // Release the sparse tensor reader.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
@@ -608,6 +607,70 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
}
};
+struct OutRewriter : public OpRewritePattern<OutOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(OutOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // Calculate NNZ.
+ Value src = op.getTensor();
+ Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
+
+ // Allocate a temporary buffer for storing dimension sizes and indices.
+ auto srcTp = src.getType().template cast<RankedTensorType>();
+ uint64_t rank = srcTp.getRank();
+ Type indexTp = rewriter.getIndexType();
+ Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
+
+ // Generate code to calculate dimension size values and store the values to
+ // the buffer.
+ SmallVector<Value, 4> dims;
+ sizesForTensor(rewriter, dims, loc, srcTp, src);
+ for (int64_t i = 0; i < rank; i++) {
+ rewriter.create<memref::StoreOp>(loc, dims[i], dimSizes,
+ constantIndex(rewriter, loc, i));
+ }
+
+ // Create a sparse tensor writer and output meta data.
+ Type opaqueTp = getOpaquePointerType(rewriter);
+ Value writer =
+ createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
+ {op.getDest()}, EmitCInterface::Off)
+ .getResult(0);
+ Value rankValue = constantIndex(rewriter, loc, rank);
+ createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
+ {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
+
+ Value indices = dimSizes; // Reuse the dimSizes buffer for indices.
+ Type eltTp = srcTp.getElementType();
+ SmallString<18> outNextFuncName{"outSparseTensorWriterNext",
+ primaryTypeFunctionSuffix(eltTp)};
+ Value value = genAllocaScalar(rewriter, loc, eltTp);
+ ModuleOp module = op->getParentOfType<ModuleOp>();
+ // For each element in the source tensor, output the element.
+ rewriter.create<ForeachOp>(
+ loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+ for (int64_t i = 0; i < rank; i++) {
+ rewriter.create<memref::StoreOp>(loc, args[i], indices,
+ constantIndex(builder, loc, i));
+ }
+ rewriter.create<memref::StoreOp>(loc, args.back(), value);
+ SmallVector<Value, 4> operands{writer, rankValue, indices, value};
+ FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
+ EmitCInterface::On);
+ builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
+ builder.create<sparse_tensor::YieldOp>(loc);
+ });
+
+ // Release the writer.
+ createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
+ EmitCInterface::Off);
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
//===---------------------------------------------------------------------===//
@@ -624,7 +687,7 @@ void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT)
- patterns.add<ConcatenateRewriter, NewRewriter,
+ patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
patterns.getContext());
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index d77f8edba0f5..3d2a5e2f50c1 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
#CSR = #sparse_tensor.encoding<{
dimLevelType = ["dense", "compressed"]
@@ -10,19 +10,20 @@
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[R:.*]] = call @createSparseTensorReader(%[[A]])
-// CHECK: %[[DS:.*]] = memref.alloc(%[[C2]]) : memref<?xindex>
+// CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
// CHECK: call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
// CHECK: %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
// CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
// CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
// CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK: %[[VB:.*]] = memref.alloca()
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
-// CHECK: %[[V:.*]] = func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]])
+// CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
// CHECK: %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
// CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK: %[[V:.*]] = memref.load %[[VB]][]
// CHECK: sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
// CHECK: }
-// CHECK: memref.dealloc %[[DS]]
// CHECK: call @delSparseTensorReader(%[[R]])
// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T]]
// CHECK: bufferization.dealloc_tensor %[[T]]
@@ -32,3 +33,31 @@ func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
return %0 : tensor<?x?xf32, #CSR>
}
+
+// CHECK-LABEL: func.func @sparse_out(
+// CHECK-SAME: %[[A:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>,
+// CHECK-SAME: %[[B:.*]]: !llvm.ptr<i8>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+// CHECK: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]]
+// CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
+// CHECK: memref.store %[[C10]], %[[DS]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK: memref.store %[[C20]], %[[DS]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[W:.*]] = call @createSparseTensorWriter(%[[B]])
+// CHECK: call @outSparseTensorWriterMetaData(%[[W]], %[[C2]], %[[NNZ]], %[[DS]])
+// CHECK: %[[V:.*]] = memref.alloca() : memref<f32>
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK: scf.for {{.*}} {
+// CHECK: func.call @outSparseTensorWriterNextF32(%[[W]], %[[C2]], %[[DS]], %[[V]])
+// CHECK: }
+// CHECK: }
+// CHECK: call @delSparseTensorWriter(%[[W]])
+// CHECK: return
+// CHECK: }
+func.func @sparse_out( %arg0: tensor<10x20xf32, #CSR>, %arg1: !llvm.ptr<i8>) -> () {
+ sparse_tensor.out %arg0, %arg1 : tensor<10x20xf32, #CSR>, !llvm.ptr<i8>
+ return
+}