[Mlir-commits] [mlir] d950bdc - [mlir][sparse] misc code cleanup

wren romano llvmlistbot at llvm.org
Wed Feb 15 13:29:07 PST 2023


Author: wren romano
Date: 2023-02-15T13:29:00-08:00
New Revision: d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0

URL: https://github.com/llvm/llvm-project/commit/d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0
DIFF: https://github.com/llvm/llvm-project/commit/d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0.diff

LOG: [mlir][sparse] misc code cleanup

* Flattening/simplifying some nested conditionals in the `ConcatenateOp`
  conversion and rewriting patterns
* const-ifying some local variables (e.g. the allocated tensor type `tp`)

Depends on D143800

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143949

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index cfd7bca8824e..7622554479af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1414,20 +1414,18 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
           createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim);
       offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
     }
-    if (dstTp.hasEncoding()) {
-      if (!allDense) {
-        // In sparse output case, the destination holds the COO.
-        Value coo = dst;
-        dst = params.genNewCall(Action::kFromCOO, coo);
-        // Release resources.
-        genDelCOOCall(rewriter, loc, elemTp, coo);
-      } else {
-        dst = dstTensor;
-      }
-      rewriter.replaceOp(op, dst);
-    } else {
+    if (!dstTp.hasEncoding()) {
       rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(
           op, dstTp.getRankedTensorType(), dst);
+    } else if (allDense) {
+      rewriter.replaceOp(op, dstTensor);
+    } else {
+      // In sparse output case, the destination holds the COO.
+      Value coo = dst;
+      dst = params.genNewCall(Action::kFromCOO, coo);
+      // Release resources.
+      genDelCOOCall(rewriter, loc, elemTp, coo);
+      rewriter.replaceOp(op, dst);
     }
     return success();
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7046306ef17e..350649426d82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -513,13 +513,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       }
 
       needTmpCOO = !allDense && !allOrdered;
+      const RankedTensorType tp = needTmpCOO ? getUnorderedCOOFromType(dstTp)
+                                             : dstTp.getRankedTensorType();
+      encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
       SmallVector<Value> dynSizes;
       getDynamicSizes(dstTp, sizes, dynSizes);
-      RankedTensorType tp = dstTp;
-      if (needTmpCOO) {
-        tp = getUnorderedCOOFromType(dstTp);
-        encDst = getSparseTensorEncoding(tp);
-      }
       dst = rewriter.create<AllocTensorOp>(loc, tp, dynSizes).getResult();
       if (allDense) {
         // Create a view of the values buffer to match the unannotated dense
@@ -592,21 +590,20 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
     // Temp variable to avoid needing to call `getRankedTensorType`
     // in the three use-sites below.
     const RankedTensorType dstRTT = dstTp;
-    if (encDst) {
-      if (!allDense) {
-        dst = rewriter.create<LoadOp>(loc, dst, true);
-        if (needTmpCOO) {
-          Value tmpCoo = dst;
-          dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
-          rewriter.create<DeallocTensorOp>(loc, tmpCoo);
-        }
-      } else {
-        dst = rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
-                  .getResult();
+    if (!encDst) {
+      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
+    } else if (allDense) {
+      rewriter.replaceOp(
+          op, rewriter.create<ConvertOp>(loc, dstRTT, annotatedDenseDst)
+                  .getResult());
+    } else {
+      dst = rewriter.create<LoadOp>(loc, dst, true);
+      if (needTmpCOO) {
+        Value tmpCoo = dst;
+        dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
+        rewriter.create<DeallocTensorOp>(loc, tmpCoo);
       }
       rewriter.replaceOp(op, dst);
-    } else {
-      rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstRTT, dst);
     }
     return success();
   }


        


More information about the Mlir-commits mailing list