[Mlir-commits] [mlir] a3672ad - [mlir][sparse] avoid unnecessary tmp COO buffer and convert when lowering ConcatenateOp.
Peiming Liu
llvmlistbot at llvm.org
Fri Dec 16 10:26:46 PST 2022
Author: Peiming Liu
Date: 2022-12-16T18:26:39Z
New Revision: a3672add76051ed27486eb163a8d0efbdf76f898
URL: https://github.com/llvm/llvm-project/commit/a3672add76051ed27486eb163a8d0efbdf76f898
DIFF: https://github.com/llvm/llvm-project/commit/a3672add76051ed27486eb163a8d0efbdf76f898.diff
LOG: [mlir][sparse] avoid unnecessary tmp COO buffer and convert when lowering ConcatenateOp.
When concatenating along dim 0, and all inputs/outputs are ordered with an identity dimension ordering,
the concatenated coordinates will be yielded in lexicographic order, so there is no need to sort.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140228
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 70c74bf7f82f5..37691253e58d6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -173,6 +173,12 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// Constructs a new encoding with the dimOrdering and higherOrdering
/// reset to the default/identity.
SparseTensorEncodingAttr withoutOrdering() const;
+
+ /// Return true if every level is dense in the encoding.
+ bool isAllDense() const;
+
+ /// Return true if the encoding has an identity dimension ordering.
+ bool hasIdDimOrdering() const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index aecde7dfd4e2d..3d7dca9bb1196 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -63,6 +63,14 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
getPointerBitWidth(), getIndexBitWidth());
}
+bool SparseTensorEncodingAttr::isAllDense() const {
+ return llvm::all_of(getDimLevelType(), isDenseDLT);
+}
+
+bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
+ return !getDimOrdering() || getDimOrdering().isIdentity();
+}
+
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
@@ -172,7 +180,7 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
}
printer << " ]";
// Print remaining members only for non-default values.
- if (getDimOrdering() && !getDimOrdering().isIdentity())
+ if (!hasIdDimOrdering())
printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
if (getHigherOrdering())
printer << ", higherOrdering = affine_map<" << getHigherOrdering() << ">";
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b4d1491ad6af6..01ae0150e6505 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1349,8 +1349,7 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
bool allDense = false;
Value dstTensor;
if (encDst) {
- allDense = llvm::all_of(encDst.getDimLevelType(),
- [](DimLevelType dlt) { return isDenseDLT(dlt); });
+ allDense = encDst.isAllDense();
// Start a new COO or an initialized annotated all dense sparse tensor.
dst = params.genBuffers(encDst, sizes, dstTp)
.genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fdb5f67311ffc..ebc4c8152b008 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -525,14 +525,35 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// %t = convert_to_dest_tensor(%tmp)
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
Value dst; // Destination tensor for inserting source tensor values.
- bool allDense = false;
+ bool needTmpCOO = true;
if (encDst) {
- allDense = llvm::all_of(encDst.getDimLevelType(),
- [](DimLevelType dlt) { return isDenseDLT(dlt); });
+ bool allDense = encDst.isAllDense();
+ bool allOrdered = false;
+ // When concatenating on dimension 0, and all inputs are sorted and have
+ // an identity dimOrdering, the concatenate will generate coords in
+ // lexOrder thus no need for the tmp COO buffer.
+ // TODO: When conDim != 0, as long as conDim is the first dimension
+ // in all input/output buffers, and all input/output buffers have the same
+ // dimOrdering, the tmp COO buffer is still unnecessary (e.g, concatenate
+ // CSC matrices along column).
+ if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
+ for (auto i : op.getInputs()) {
+ auto rtp = i.getType().cast<RankedTensorType>();
+ auto srcEnc = getSparseTensorEncoding(rtp);
+ if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
+ allOrdered = true;
+ continue;
+ }
+ allOrdered = false;
+ break;
+ }
+ }
+
+ needTmpCOO = !allDense && !allOrdered;
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
RankedTensorType tp = dstTp;
- if (!allDense) {
+ if (needTmpCOO) {
tp = getUnorderedCOOFromType(dstTp);
encDst = getSparseTensorEncoding(tp);
}
@@ -596,7 +617,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
if (encDst) {
dst = rewriter.create<LoadOp>(loc, dst, true);
- if (!allDense) {
+ if (needTmpCOO) {
Value tmpCoo = dst;
dst = rewriter.create<ConvertOp>(loc, dstTp, tmpCoo).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
index 1819579855b2e..310159ae369c4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -79,8 +79,7 @@
// CHECK: scf.yield %[[RET_6]]
// CHECK: }
// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
-// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
-// CHECK: return %[[TMP_22]] : tensor<9x4xf64, #sparse_tensor
+// CHECK: return %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
%arg1: tensor<3x4xf64, #DCSR>,
%arg2: tensor<4x4xf64, #DCSR>)
@@ -166,8 +165,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: scf.yield %[[RET_6]]
// CHECK: }
// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
-// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
-// CHECK: return %[[TMP_22]] : tensor<?x?xf64, #sparse_tensor
+// CHECK: return %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
%arg1: tensor<3x4xf64, #DCSR>,
%arg2: tensor<4x4xf64, #DCSR>)
More information about the Mlir-commits
mailing list