[Mlir-commits] [mlir] 34c9c59 - [mlir][sparse] Using SparseTensorType in SparsePackOpConverter
wren romano
llvmlistbot at llvm.org
Mon Apr 3 16:37:04 PDT 2023
Author: wren romano
Date: 2023-04-03T16:36:56-07:00
New Revision: 34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e
URL: https://github.com/llvm/llvm-project/commit/34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e
DIFF: https://github.com/llvm/llvm-project/commit/34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e.diff
LOG: [mlir][sparse] Using SparseTensorType in SparsePackOpConverter
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D147465
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 31bf59552f4e0..b9f75e9ad0054 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1235,29 +1235,28 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
matchAndRewrite(PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- const auto rtp = getRankedTensorType(op.getResult());
- assert(isUniqueCOOType(rtp));
+ const auto stt = getSparseTensorType(op.getResult());
+ assert(isUniqueCOOType(stt));
SmallVector<Value> fields;
Location loc = op.getLoc();
foreachFieldAndTypeInSparseTensor(
- rtp,
- [&rewriter, &fields, &op, rtp,
+ stt,
+ [&rewriter, &fields, &op, stt,
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
assert(fields.size() == fIdx);
- auto enc = getSparseTensorEncoding(rtp);
Value field;
switch (fKind) {
case SparseTensorFieldKind::StorageSpec:
- field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
+ field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
break;
case SparseTensorFieldKind::PosMemRef: {
// TACO-style COO starts with a PosBuffer
// By creating a constant value for it, we avoid the complexity of
// memory management.
- const auto posTp = enc.getPosType();
+ const auto posTp = stt.getPosType();
auto tensorType = RankedTensorType::get({2}, posTp);
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
@@ -1306,13 +1305,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
return true;
});
- MutSparseTensorDescriptor desc(rtp, fields);
+ MutSparseTensorDescriptor desc(stt, fields);
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
- // FIXME: should use `SparseTensorType::getLvlRank` in lieu of
- // `RankedTensorType::getRank`, because the latter introduces dim/lvl
- // ambiguity.
- for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) {
- const auto sh = rtp.getShape()[lvl];
+ for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
+ // FIXME: dim/lvl confusion!
+ const auto sh = stt.getDimShape()[lvl];
assert(!ShapedType::isDynamic(sh));
desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
if (lvl == 0)
More information about the Mlir-commits mailing list