[Mlir-commits] [mlir] 2ddfacd - [mlir][sparse] codegen for sparse dealloc
Aart Bik
llvmlistbot at llvm.org
Thu Sep 1 22:21:38 PDT 2022
Author: Aart Bik
Date: 2022-09-01T22:21:20-07:00
New Revision: 2ddfacd95ccf439e369d3f5daea9066903fe2f50
URL: https://github.com/llvm/llvm-project/commit/2ddfacd95ccf439e369d3f5daea9066903fe2f50
DIFF: https://github.com/llvm/llvm-project/commit/2ddfacd95ccf439e369d3f5daea9066903fe2f50.diff
LOG: [mlir][sparse] codegen for sparse dealloc
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D133171
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 65c1027fc0a8..48c07f42b3e5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -17,6 +17,7 @@
#include "CodegenUtils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -232,7 +233,31 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
}
};
-/// Sparse conversion rule for pointer accesses.
+/// Sparse codegen rule for the dealloc operator.
+class SparseTensorDeallocConverter
+ : public OpConversionPattern<bufferization::DeallocTensorOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto enc = getSparseTensorEncoding(op.getTensor().getType());
+ if (!enc)
+ return failure();
+ // Replace the tuple deallocation with field deallocations.
+ Location loc = op->getLoc();
+ Value tuple = adaptor.getTensor();
+ for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
+ i++) {
+ Value mem = createTupleGet(rewriter, loc, tuple, i);
+ rewriter.create<memref::DeallocOp>(loc, mem);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -251,7 +276,7 @@ class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
}
};
-/// Sparse conversion rule for index accesses.
+/// Sparse codegen rule for index accesses.
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -270,7 +295,7 @@ class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
}
};
-/// Sparse conversion rule for value accesses.
+/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -280,7 +305,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
// Replace the requested values access with corresponding field.
Location loc = op->getLoc();
Value tuple = adaptor.getTensor();
- unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
+ unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
return success();
}
@@ -306,6 +331,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToValuesConverter>(typeConverter, patterns.getContext());
+ SparseTensorDeallocConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToValuesConverter>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index c1a6b7a6e45c..538041a1b36a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -124,8 +124,7 @@ struct SparseTensorConversionPass
});
// The following operations and dialects may be introduced by the
// rewriting rules, and are therefore marked as legal.
- target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
- complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
+ target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
linalg::YieldOp, tensor::ExtractOp>();
target.addLegalDialect<
arith::ArithmeticDialect, bufferization::BufferizationDialect,
@@ -160,7 +159,9 @@ struct SparseTensorCodegenPass
// Almost everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addLegalOp<StorageGetOp, StorageSetOp>();
- // All dynamic rules below accept new function, call, return.
+ // All dynamic rules below accept new function, call, return, and various
+ // tensor and bufferization operations as legal output of the rewriting
+ // provided that all sparse tensor types have been fully rewritten.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
@@ -170,6 +171,10 @@ struct SparseTensorCodegenPass
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
+ target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
+ [&](bufferization::DeallocTensorOp op) {
+ return converter.isLegal(op.getTensor().getType());
+ });
// Legal dialects may occur in generated code.
target.addLegalDialect<arith::ArithmeticDialect,
bufferization::BufferizationDialect,
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 905278ca2665..6d95be81a8c2 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -141,3 +141,19 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
return %0 : memref<?xf64>
}
+
+// CHECK-LABEL: func @sparse_dealloc_csr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+// CHECK: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
+// CHECK: memref.dealloc %[[F0]] : memref<2xindex>
+// CHECK: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
+// CHECK: memref.dealloc %[[F1]] : memref<?xi32>
+// CHECK: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
+// CHECK: memref.dealloc %[[F2]] : memref<?xi64>
+// CHECK: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
+// CHECK: memref.dealloc %[[F3]] : memref<?xf64>
+// CHECK: return
+func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
+ bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
+ return
+}
More information about the Mlir-commits mailing list