[Mlir-commits] [mlir] a3672ad - [mlir][sparse] avoid unnecessary tmp COO buffer and convert when lowering ConcatenateOp.

Peiming Liu llvmlistbot at llvm.org
Fri Dec 16 10:26:46 PST 2022


Author: Peiming Liu
Date: 2022-12-16T18:26:39Z
New Revision: a3672add76051ed27486eb163a8d0efbdf76f898

URL: https://github.com/llvm/llvm-project/commit/a3672add76051ed27486eb163a8d0efbdf76f898
DIFF: https://github.com/llvm/llvm-project/commit/a3672add76051ed27486eb163a8d0efbdf76f898.diff

LOG: [mlir][sparse] avoid unnecessary tmp COO buffer and convert when lowering ConcatenateOp.

When concatenating along dimension 0, and all inputs/outputs are ordered with an identity dimension ordering,
the concatenated coordinates are yielded in lexicographic order, so there is no need to sort.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140228

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 70c74bf7f82f5..37691253e58d6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -173,6 +173,12 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// Constructs a new encoding with the dimOrdering and higherOrdering
     /// reset to the default/identity.
     SparseTensorEncodingAttr withoutOrdering() const;
+
+    /// Return true if every level is dense in the encoding.
+    bool isAllDense() const;
+
+    /// Return true if the encoding has an identity dimension ordering.
+    bool hasIdDimOrdering() const;
   }];
 
   let genVerifyDecl = 1;

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index aecde7dfd4e2d..3d7dca9bb1196 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -63,6 +63,14 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
       getPointerBitWidth(), getIndexBitWidth());
 }
 
+bool SparseTensorEncodingAttr::isAllDense() const {
+  return llvm::all_of(getDimLevelType(), isDenseDLT);
+}
+
+bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
+  return !getDimOrdering() || getDimOrdering().isIdentity();
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
@@ -172,7 +180,7 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
   }
   printer << " ]";
   // Print remaining members only for non-default values.
-  if (getDimOrdering() && !getDimOrdering().isIdentity())
+  if (!hasIdDimOrdering())
     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
   if (getHigherOrdering())
     printer << ", higherOrdering = affine_map<" << getHigherOrdering() << ">";

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b4d1491ad6af6..01ae0150e6505 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1349,8 +1349,7 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
     bool allDense = false;
     Value dstTensor;
     if (encDst) {
-      allDense = llvm::all_of(encDst.getDimLevelType(),
-                              [](DimLevelType dlt) { return isDenseDLT(dlt); });
+      allDense = encDst.isAllDense();
       // Start a new COO or an initialized annotated all dense sparse tensor.
       dst = params.genBuffers(encDst, sizes, dstTp)
                 .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fdb5f67311ffc..ebc4c8152b008 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -525,14 +525,35 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
     // %t = convert_to_dest_tensor(%tmp)
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     Value dst; // Destination tensor for inserting source tensor values.
-    bool allDense = false;
+    bool needTmpCOO = true;
     if (encDst) {
-      allDense = llvm::all_of(encDst.getDimLevelType(),
-                              [](DimLevelType dlt) { return isDenseDLT(dlt); });
+      bool allDense = encDst.isAllDense();
+      bool allOrdered = false;
+      // When concatenating on dimension 0, and the destination and all inputs
+      // have an identity dimOrdering (with all inputs sorted), concatenation
+      // yields coords in lexicographic order, so no tmp COO buffer is needed.
+      // TODO: When conDim != 0, as long as conDim is the first dimension
+      // in all input/output buffers, and all input/output buffers have the same
+      // dimOrdering, the tmp COO buffer is still unnecessary (e.g., concatenate
+      // CSC matrices along columns).
+      if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
+        for (auto i : op.getInputs()) {
+          auto rtp = i.getType().cast<RankedTensorType>();
+          auto srcEnc = getSparseTensorEncoding(rtp);
+          if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
+            allOrdered = true;
+            continue;
+          }
+          allOrdered = false;
+          break;
+        }
+      }
+
+      needTmpCOO = !allDense && !allOrdered;
       SmallVector<Value> dynSizes;
       getDynamicSizes(dstTp, sizes, dynSizes);
       RankedTensorType tp = dstTp;
-      if (!allDense) {
+      if (needTmpCOO) {
         tp = getUnorderedCOOFromType(dstTp);
         encDst = getSparseTensorEncoding(tp);
       }
@@ -596,7 +617,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     if (encDst) {
       dst = rewriter.create<LoadOp>(loc, dst, true);
-      if (!allDense) {
+      if (needTmpCOO) {
         Value tmpCoo = dst;
         dst = rewriter.create<ConvertOp>(loc, dstTp, tmpCoo).getResult();
         rewriter.create<DeallocTensorOp>(loc, tmpCoo);

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 1819579855b2e..310159ae369c4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -79,8 +79,7 @@
 //       CHECK:    scf.yield %[[RET_6]]
 //       CHECK:  }
 //       CHECK:  %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
-//       CHECK:  %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
-//       CHECK:  return %[[TMP_22]] : tensor<9x4xf64, #sparse_tensor
+//       CHECK:  return %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
 func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
                                 %arg1: tensor<3x4xf64, #DCSR>,
                                 %arg2: tensor<4x4xf64, #DCSR>)
@@ -166,8 +165,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
 //       CHECK:    scf.yield %[[RET_6]]
 //       CHECK:  }
 //       CHECK:  %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
-//       CHECK:  %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
-//       CHECK:  return %[[TMP_22]] : tensor<?x?xf64, #sparse_tensor
+//       CHECK:  return %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
 func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
                                 %arg1: tensor<3x4xf64, #DCSR>,
                                 %arg2: tensor<4x4xf64, #DCSR>)


        


More information about the Mlir-commits mailing list