[Mlir-commits] [mlir] 34c9c59 - [mlir][sparse] Using SparseTensorType in SparsePackOpConverter

wren romano llvmlistbot at llvm.org
Mon Apr 3 16:37:04 PDT 2023


Author: wren romano
Date: 2023-04-03T16:36:56-07:00
New Revision: 34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e

URL: https://github.com/llvm/llvm-project/commit/34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e
DIFF: https://github.com/llvm/llvm-project/commit/34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e.diff

LOG: [mlir][sparse] Using SparseTensorType in SparsePackOpConverter

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D147465

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 31bf59552f4e0..b9f75e9ad0054 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1235,29 +1235,28 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   matchAndRewrite(PackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    const auto rtp = getRankedTensorType(op.getResult());
-    assert(isUniqueCOOType(rtp));
+    const auto stt = getSparseTensorType(op.getResult());
+    assert(isUniqueCOOType(stt));
 
     SmallVector<Value> fields;
     Location loc = op.getLoc();
 
     foreachFieldAndTypeInSparseTensor(
-        rtp,
-        [&rewriter, &fields, &op, rtp,
+        stt,
+        [&rewriter, &fields, &op, stt,
          loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
               Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
           assert(fields.size() == fIdx);
-          auto enc = getSparseTensorEncoding(rtp);
           Value field;
           switch (fKind) {
           case SparseTensorFieldKind::StorageSpec:
-            field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
+            field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
             break;
           case SparseTensorFieldKind::PosMemRef: {
             // TACO-style COO starts with a PosBuffer
             // By creating a constant value for it, we avoid the complexity of
             // memory management.
-            const auto posTp = enc.getPosType();
+            const auto posTp = stt.getPosType();
             auto tensorType = RankedTensorType::get({2}, posTp);
             auto memrefType = MemRefType::get(tensorType.getShape(),
                                               tensorType.getElementType());
@@ -1306,13 +1305,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
           return true;
         });
 
-    MutSparseTensorDescriptor desc(rtp, fields);
+    MutSparseTensorDescriptor desc(stt, fields);
     auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
-    // FIXME: should use `SparseTensorType::getLvlRank` in lieu of
-    // `RankedTensorType::getRank`, because the latter introduces dim/lvl
-    // ambiguity.
-    for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) {
-      const auto sh = rtp.getShape()[lvl];
+    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
+      // FIXME: dim/lvl confusion!
+      const auto sh = stt.getDimShape()[lvl];
       assert(!ShapedType::isDynamic(sh));
       desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
       if (lvl == 0)


        


More information about the Mlir-commits mailing list