[Mlir-commits] [mlir] af7ac1d - [mlir][sparse] Sharing calls to adaptor.getOperands()[0]
wren romano
llvmlistbot at llvm.org
Fri Oct 1 14:20:37 PDT 2021
Author: wren romano
Date: 2021-10-01T14:20:31-07:00
New Revision: af7ac1d95b7daa7d758f69d0c117c4d91a21463e
URL: https://github.com/llvm/llvm-project/commit/af7ac1d95b7daa7d758f69d0c117c4d91a21463e
DIFF: https://github.com/llvm/llvm-project/commit/af7ac1d95b7daa7d758f69d0c117c4d91a21463e.diff
LOG: [mlir][sparse] Sharing calls to adaptor.getOperands()[0]
This is preliminary work towards D110790. Depends on D110883.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D110884
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 154fe3fb9be9..11f589c8b446 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -380,6 +380,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
Type resType = op.getType();
auto encDst = getSparseTensorEncoding(resType);
auto encSrc = getSparseTensorEncoding(op.source().getType());
+ auto src = adaptor.getOperands()[0];
if (encDst && encSrc) {
// This is a sparse => sparse conversion, which is handled as follows:
// t = src->toCOO(); ; src to COO in dst order
@@ -388,8 +389,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
// yield the fastest conversion but avoids the need for a full
// O(N^2) conversion matrix.
Value perm;
- Value coo =
- genNewCall(rewriter, op, encDst, 3, perm, adaptor.getOperands()[0]);
+ Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
return success();
}
@@ -433,8 +433,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<Value> st;
Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
- Value tensor = adaptor.getOperands()[0];
- auto indicesValues = genSplitSparseConstant(rewriter, op, tensor);
+ auto indicesValues = genSplitSparseConstant(rewriter, op, src);
bool isCOOConstant = indicesValues.hasValue();
Value indices;
Value values;
@@ -447,7 +446,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
} else {
for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
lo.push_back(zero);
- hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+ hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
st.push_back(one);
}
}
@@ -461,7 +460,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
val = genIndexAndValueForSparse(
rewriter, op, indices, values, ind, ivs, rank);
else
- val = genIndexAndValueForDense(rewriter, op, tensor,
+ val = genIndexAndValueForDense(rewriter, op, src,
ind, ivs);
genAddEltCall(rewriter, op, eltType, ptr, val, ind,
perm);
More information about the Mlir-commits
mailing list