[Mlir-commits] [mlir] 2ddfacd - [mlir][sparse] codegen for sparse dealloc

Aart Bik llvmlistbot at llvm.org
Thu Sep 1 22:21:38 PDT 2022


Author: Aart Bik
Date: 2022-09-01T22:21:20-07:00
New Revision: 2ddfacd95ccf439e369d3f5daea9066903fe2f50

URL: https://github.com/llvm/llvm-project/commit/2ddfacd95ccf439e369d3f5daea9066903fe2f50
DIFF: https://github.com/llvm/llvm-project/commit/2ddfacd95ccf439e369d3f5daea9066903fe2f50.diff

LOG: [mlir][sparse] codegen for sparse dealloc

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133171
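
As a rough sketch of what the new rule does (mirroring the CSR test case added to codegen.mlir below, and assuming the #CSR encoding defined in that test file): once sparse tensor storage has been lowered to a tuple of memrefs, the tensor deallocation is replaced by one memref.dealloc per storage field.

    // Before codegen: deallocate a CSR sparse tensor.
    func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
      bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
      return
    }

    // After codegen (roughly): each field of the tuple-typed storage is
    // extracted and its buffer is freed individually.
    func.func @sparse_dealloc_csr(
        %arg0: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) {
      %0 = sparse_tensor.storage_get %arg0[0]
             : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
      memref.dealloc %0 : memref<2xindex>
      %1 = sparse_tensor.storage_get %arg0[1]
             : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
      memref.dealloc %1 : memref<?xi32>
      %2 = sparse_tensor.storage_get %arg0[2]
             : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
      memref.dealloc %2 : memref<?xi64>
      %3 = sparse_tensor.storage_get %arg0[3]
             : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
      memref.dealloc %3 : memref<?xf64>
      return
    }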

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 65c1027fc0a8..48c07f42b3e5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -17,6 +17,7 @@
 
 #include "CodegenUtils.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -232,7 +233,31 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
   }
 };
 
-/// Sparse conversion rule for pointer accesses.
+/// Sparse codegen rule for the dealloc operator.
+class SparseTensorDeallocConverter
+    : public OpConversionPattern<bufferization::DeallocTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto enc = getSparseTensorEncoding(op.getTensor().getType());
+    if (!enc)
+      return failure();
+    // Replace the tuple deallocation with field deallocations.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
+         i++) {
+      Value mem = createTupleGet(rewriter, loc, tuple, i);
+      rewriter.create<memref::DeallocOp>(loc, mem);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -251,7 +276,7 @@ class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
   }
 };
 
-/// Sparse conversion rule for index accesses.
+/// Sparse codegen rule for index accesses.
 class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -270,7 +295,7 @@ class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
   }
 };
 
-/// Sparse conversion rule for value accesses.
+/// Sparse codegen rule for value accesses.
 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -280,7 +305,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
     // Replace the requested values access with corresponding field.
     Location loc = op->getLoc();
     Value tuple = adaptor.getTensor();
-    unsigned i = tuple.getType().cast<TupleType>().size() - 1;  // last
+    unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
     rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
     return success();
   }
@@ -306,6 +331,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+               SparseTensorDeallocConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
+      typeConverter, patterns.getContext());
 }

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index c1a6b7a6e45c..538041a1b36a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -124,8 +124,7 @@ struct SparseTensorConversionPass
         });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
-                      complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
+    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                       linalg::YieldOp, tensor::ExtractOp>();
     target.addLegalDialect<
         arith::ArithmeticDialect, bufferization::BufferizationDialect,
@@ -160,7 +159,9 @@ struct SparseTensorCodegenPass
     // Almost everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
     target.addLegalOp<StorageGetOp, StorageSetOp>();
-    // All dynamic rules below accept new function, call, return.
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting
+    // provided that all sparse tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -170,6 +171,10 @@ struct SparseTensorCodegenPass
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
+        [&](bufferization::DeallocTensorOp op) {
+          return converter.isLegal(op.getTensor().getType());
+        });
     // Legal dialects may occur in generated code.
     target.addLegalDialect<arith::ArithmeticDialect,
                            bufferization::BufferizationDialect,

diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 905278ca2665..6d95be81a8c2 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -141,3 +141,19 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
   return %0 : memref<?xf64>
 }
+
+// CHECK-LABEL: func @sparse_dealloc_csr(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+//       CHECK: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
+//       CHECK: memref.dealloc %[[F0]] : memref<2xindex>
+//       CHECK: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
+//       CHECK: memref.dealloc %[[F1]] : memref<?xi32>
+//       CHECK: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
+//       CHECK: memref.dealloc %[[F2]] : memref<?xi64>
+//       CHECK: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
+//       CHECK: memref.dealloc %[[F3]] : memref<?xf64>
+//       CHECK: return
+func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
+  bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
+  return
+}


        

