[Mlir-commits] [mlir] f2696e4 - [mlir][sparse] Cleaning up some usage of SparseTensorType
wren romano
llvmlistbot at llvm.org
Thu Mar 30 12:00:15 PDT 2023
Author: wren romano
Date: 2023-03-30T12:00:00-07:00
New Revision: f2696e469a5ca1fa3efeebef56e77507e73b5047
URL: https://github.com/llvm/llvm-project/commit/f2696e469a5ca1fa3efeebef56e77507e73b5047
DIFF: https://github.com/llvm/llvm-project/commit/f2696e469a5ca1fa3efeebef56e77507e73b5047.diff
LOG: [mlir][sparse] Cleaning up some usage of SparseTensorType
This is a followup to D147192.
Reviewed By: aartbik, Peiming
Differential Revision: https://reviews.llvm.org/D147196
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index dc5755bcf5012..52281bfa94ae9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -356,16 +356,10 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSrc();
- auto srcTp = getRankedTensorType(srcTensor);
- auto dstTp = getRankedTensorType(op.getResult());
-
- SparseTensorType srcStt(srcTp);
- SparseTensorType dstStt(dstTp);
-
- const auto encSrc = srcStt.getEncoding();
- if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
+ const auto srcTp = getSparseTensorType(srcTensor);
+ const auto dstTp = getSparseTensorType(op.getResult());
+ if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
return failure();
- }
// Generate code to represent the static dimension constants or compute
// the dynamic dimension values.
@@ -373,11 +367,11 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
SmallVector<Value> dstSizes;
SmallVector<Value> dstDynSizes;
- if (dstTp.hasStaticShape()) {
- for (auto d : dstTp.getShape())
+ if (dstTp.hasStaticDimShape()) {
+ for (Dimension d : dstTp.getDimShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));
} else {
- ArrayRef<int64_t> dstShape = dstTp.getShape();
+ ArrayRef<DynSize> dstShape = dstTp.getDimShape();
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
for (auto [idx, shape] : llvm::enumerate(dstShape)) {
@@ -389,8 +383,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
// Only need a unordered COO buffer if input and output are not sorted
// in the same way.
Type bufferTp =
- srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
- ? dstTp
+ srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+ ? dstTp.getRankedTensorType()
: getUnorderedCOOFromType(dstTp);
Value buffer =
@@ -406,11 +400,12 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
// followed by an optional
// %t = sparse_tensor.cast %tmp
// depending on whether the input/output are sorted in the same way.
+ const auto encSrc = srcTp.getEncoding();
ForeachOp foreachOp = rewriter.create<ForeachOp>(
loc, srcTensor, buffer,
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
ValueRange reduc) {
- const Dimension dimRank = srcTp.getRank();
+ const Dimension dimRank = srcTp.getDimRank();
SmallVector<Value> srcDcvs;
srcDcvs.reserve(dimRank);
for (Dimension d = 0; d < dimRank; d++) {
@@ -427,7 +422,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
if (bufferTp != dstTp) {
- Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
+ auto dstRTT = dstTp.getRankedTensorType();
+ Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
rewriter.create<DeallocTensorOp>(loc, t);
t = converted;
}
More information about the Mlir-commits mailing list