[Mlir-commits] [mlir] f2696e4 - [mlir][sparse] Cleaning up some usage of SparseTensorType

wren romano llvmlistbot at llvm.org
Thu Mar 30 12:00:15 PDT 2023


Author: wren romano
Date: 2023-03-30T12:00:00-07:00
New Revision: f2696e469a5ca1fa3efeebef56e77507e73b5047

URL: https://github.com/llvm/llvm-project/commit/f2696e469a5ca1fa3efeebef56e77507e73b5047
DIFF: https://github.com/llvm/llvm-project/commit/f2696e469a5ca1fa3efeebef56e77507e73b5047.diff

LOG: [mlir][sparse] Cleaning up some usage of SparseTensorType

This is a followup to D147192.
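
In short, the cleanup constructs the SparseTensorType wrapper once via getSparseTensorType() and queries it directly, rather than fetching a RankedTensorType and re-wrapping it. A minimal sketch of the before/after pattern, using only the helpers that appear in the diff below (not a verbatim excerpt of the patch):

    // Before: unwrap to RankedTensorType, then wrap again for sparse queries.
    auto srcRtt = getRankedTensorType(srcTensor);
    SparseTensorType srcStt(srcRtt);
    if (!srcStt.hasEncoding())
      return failure();

    // After: build the wrapper directly and use its dim-oriented queries.
    const auto srcTp = getSparseTensorType(srcTensor);
    if (!srcTp.hasEncoding())
      return failure();
    const Dimension dimRank = srcTp.getDimRank();
    // Recover the RankedTensorType only where an mlir::Type is needed.
    auto rtt = srcTp.getRankedTensorType();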

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D147196

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index dc5755bcf5012..52281bfa94ae9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -356,16 +356,10 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSrc();
-    auto srcTp = getRankedTensorType(srcTensor);
-    auto dstTp = getRankedTensorType(op.getResult());
-
-    SparseTensorType srcStt(srcTp);
-    SparseTensorType dstStt(dstTp);
-
-    const auto encSrc = srcStt.getEncoding();
-    if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
       return failure();
-    }
 
     // Generate code to represent the static dimension constants or compute
     // the dynamic dimension values.
@@ -373,11 +367,11 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
     SmallVector<Value> dstSizes;
     SmallVector<Value> dstDynSizes;
-    if (dstTp.hasStaticShape()) {
-      for (auto d : dstTp.getShape())
+    if (dstTp.hasStaticDimShape()) {
+      for (Dimension d : dstTp.getDimShape())
         dstSizes.push_back(constantIndex(rewriter, loc, d));
     } else {
-      ArrayRef<int64_t> dstShape = dstTp.getShape();
+      ArrayRef<DynSize> dstShape = dstTp.getDimShape();
       genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
                          op.getReassociationIndices());
       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
@@ -389,8 +383,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     // Only need a unordered COO buffer if input and output are not sorted
     // in the same way.
     Type bufferTp =
-        srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
-            ? dstTp
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
             : getUnorderedCOOFromType(dstTp);
 
     Value buffer =
@@ -406,11 +400,12 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     // followed by an optional
     //   %t = sparse_tensor.cast %tmp
     // depending on whether the input/output are sorted in the same way.
+    const auto encSrc = srcTp.getEncoding();
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
         loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
-          const Dimension dimRank = srcTp.getRank();
+          const Dimension dimRank = srcTp.getDimRank();
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(dimRank);
           for (Dimension d = 0; d < dimRank; d++) {
@@ -427,7 +422,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
     if (bufferTp != dstTp) {
-      Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
       rewriter.create<DeallocTensorOp>(loc, t);
       t = converted;
     }

