[Mlir-commits] [mlir] af7ac1d - [mlir][sparse] Sharing calls to adaptor.getOperands()[0]

wren romano llvmlistbot at llvm.org
Fri Oct 1 14:20:37 PDT 2021


Author: wren romano
Date: 2021-10-01T14:20:31-07:00
New Revision: af7ac1d95b7daa7d758f69d0c117c4d91a21463e

URL: https://github.com/llvm/llvm-project/commit/af7ac1d95b7daa7d758f69d0c117c4d91a21463e
DIFF: https://github.com/llvm/llvm-project/commit/af7ac1d95b7daa7d758f69d0c117c4d91a21463e.diff

LOG: [mlir][sparse] Sharing calls to adaptor.getOperands()[0]

This is preliminary work towards D110790. Depends on D110883.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110884

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 154fe3fb9be9..11f589c8b446 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -380,6 +380,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     Type resType = op.getType();
     auto encDst = getSparseTensorEncoding(resType);
     auto encSrc = getSparseTensorEncoding(op.source().getType());
+    auto src = adaptor.getOperands()[0];
     if (encDst && encSrc) {
       // This is a sparse => sparse conversion, which is handled as follows:
       //   t = src->toCOO();         ; src to COO in dst order
@@ -388,8 +389,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       // yield the fastest conversion but avoids the need for a full
       // O(N^2) conversion matrix.
       Value perm;
-      Value coo =
-          genNewCall(rewriter, op, encDst, 3, perm, adaptor.getOperands()[0]);
+      Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
       rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
       return success();
     }
@@ -433,8 +433,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     SmallVector<Value> st;
     Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
     Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
-    Value tensor = adaptor.getOperands()[0];
-    auto indicesValues = genSplitSparseConstant(rewriter, op, tensor);
+    auto indicesValues = genSplitSparseConstant(rewriter, op, src);
     bool isCOOConstant = indicesValues.hasValue();
     Value indices;
     Value values;
@@ -447,7 +446,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     } else {
       for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
         lo.push_back(zero);
-        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
         st.push_back(one);
       }
     }
@@ -461,7 +460,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
                            val = genIndexAndValueForSparse(
                                rewriter, op, indices, values, ind, ivs, rank);
                          else
-                           val = genIndexAndValueForDense(rewriter, op, tensor,
+                           val = genIndexAndValueForDense(rewriter, op, src,
                                                           ind, ivs);
                          genAddEltCall(rewriter, op, eltType, ptr, val, ind,
                                        perm);


        


More information about the Mlir-commits mailing list