[Mlir-commits] [mlir] d950bdc - [mlir][sparse] misc code cleanup
wren romano
llvmlistbot at llvm.org
Wed Feb 15 13:29:07 PST 2023
Author: wren romano
Date: 2023-02-15T13:29:00-08:00
New Revision: d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0
URL: https://github.com/llvm/llvm-project/commit/d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0
DIFF: https://github.com/llvm/llvm-project/commit/d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0.diff
LOG: [mlir][sparse] misc code cleanup
* Flattening/simplifying some nested conditionals
* const-ifying some local variables
Depends On D143800
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143949
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index cfd7bca8824e..7622554479af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1414,20 +1414,18 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim);
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
}
- if (dstTp.hasEncoding()) {
- if (!allDense) {
- // In sparse output case, the destination holds the COO.
- Value coo = dst;
- dst = params.genNewCall(Action::kFromCOO, coo);
- // Release resources.
- genDelCOOCall(rewriter, loc, elemTp, coo);
- } else {
- dst = dstTensor;
- }
- rewriter.replaceOp(op, dst);
- } else {
+ if (!dstTp.hasEncoding()) {
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
op, dstTp.getRankedTensorType(), dst);
+ } else if (allDense) {
+ rewriter.replaceOp(op, dstTensor);
+ } else {
+ // In sparse output case, the destination holds the COO.
+ Value coo = dst;
+ dst = params.genNewCall(Action::kFromCOO, coo);
+ // Release resources.
+ genDelCOOCall(rewriter, loc, elemTp, coo);
+ rewriter.replaceOp(op, dst);
}
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7046306ef17e..350649426d82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -513,13 +513,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
}
needTmpCOO = !allDense && !allOrdered;
+ const RankedTensorType tp = needTmpCOO ? getUnorderedCOOFromType(dstTp)
+ : dstTp.getRankedTensorType();
+ encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
- RankedTensorType tp = dstTp;
- if (needTmpCOO) {
- tp = getUnorderedCOOFromType(dstTp);
- encDst = getSparseTensorEncoding(tp);
- }
dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
if (allDense) {
// Create a view of the values buffer to match the unannotated dense
@@ -592,21 +590,20 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// Temp variable to avoid needing to call `getRankedTensorType`
// in the three use-sites below.
const RankedTensorType dstRTT = dstTp;
- if (encDst) {
- if (!allDense) {
- dst = rewriter.create<LoadOp>(loc, dst, true);
- if (needTmpCOO) {
- Value tmpCoo = dst;
- dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
- rewriter.create<DeallocTensorOp>(loc, tmpCoo);
- }
- } else {
- dst = rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
- .getResult();
+ if (!encDst) {
+ rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
+ } else if (allDense) {
+ rewriter.replaceOp(
+ op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
+ .getResult());
+ } else {
+ dst = rewriter.create<LoadOp>(loc, dst, true);
+ if (needTmpCOO) {
+ Value tmpCoo = dst;
+ dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
+ rewriter.create<DeallocTensorOp>(loc, tmpCoo);
}
rewriter.replaceOp(op, dst);
- } else {
- rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
}
return success();
}
More information about the Mlir-commits mailing list