[Mlir-commits] [mlir] 52028c1 - [mlir][sparse] Generate AOS subviews on-demand.

llvmlistbot at llvm.org
Wed Jan 11 08:57:07 PST 2023


Author: bixia1
Date: 2023-01-11T08:57:01-08:00
New Revision: 52028c1a48af87f3f56ca51fdfc13c8b89010302

URL: https://github.com/llvm/llvm-project/commit/52028c1a48af87f3f56ca51fdfc13c8b89010302
DIFF: https://github.com/llvm/llvm-project/commit/52028c1a48af87f3f56ca51fdfc13c8b89010302.diff

LOG: [mlir][sparse] Generate AOS subviews on-demand.

Previously, we generated AOS subviews for the indices buffers when constructing an
immutable sparse tensor descriptor. We now generate such subviews on demand,
only when getIdxMemRefOrView is called.
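
As a rough illustration (a hypothetical caller sketch, not code from this patch),
lowering code that previously read a pre-expanded subview out of the descriptor's
field list now asks the immutable descriptor for the index memref and only gets a
memref.subview back when the dimension lies in the trailing COO (AOS) region:

    // Hypothetical helper; assumes the declarations from
    // SparseTensorStorageLayout.h are in scope.
    static Value loadIndicesMemRef(OpBuilder &builder, Location loc,
                                   sparse_tensor::SparseTensorDescriptor desc,
                                   unsigned dim) {
      // The subview over the linear AOS index buffer is built here, on
      // demand, rather than eagerly in the descriptor constructor.
      return desc.getIdxMemRefOrView(builder, loc, dim);
    }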

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D141325

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index a34e66a2918bb..eaa4b420bbcd3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -138,7 +138,7 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
                           op.getDim().value().getZExtValue());
     } else {
       auto enc = op.getSpecifier().getType().getEncoding();
-      StorageLayout<true> layout(enc);
+      StorageLayout layout(enc);
       Optional<unsigned> dim = std::nullopt;
       if (op.getDim())
         dim = op.getDim().value().getZExtValue();

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index a680ddf06d88d..38a7e0e0610fb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -295,11 +295,10 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   unsigned idxIndex;
   unsigned idxStride;
   std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
-  unsigned ptrIndex = desc.getPtrMemRefIndex(d);
   Value one = constantIndex(builder, loc, 1);
   Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
-  Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
-  Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
+  Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos);
+  Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1);
   Value msz = desc.getIdxMemSize(builder, loc, d);
   Value idxStrideC;
   if (idxStride > 1) {
@@ -325,7 +324,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   builder.create<scf::YieldOp>(loc, eq);
   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
   if (d > 0)
-    genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos);
+    genStore(builder, loc, msz, desc.getPtrMemRef(d), pos);
   builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
   builder.setInsertionPointAfter(ifOp1);
   Value p = ifOp1.getResult(0);
@@ -352,7 +351,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   // If !present (changes fields, update next).
   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
   Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
-  genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1);
+  genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1);
   createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
                  indices[d]);
   // Prepare the next dimension "as needed".
@@ -638,10 +637,8 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
     if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
       return failure();
 
-    Location loc = op.getLoc();
-    auto desc =
-        getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
-    auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
 
     if (!sz)
       return failure();
@@ -756,8 +753,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc =
-        getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     RankedTensorType srcType =
         op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = srcType.getElementType();
@@ -900,8 +896,7 @@ class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
-                                             adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     uint64_t dim = op.getDimension().getZExtValue();
     rewriter.replaceOp(op, desc.getPtrMemRef(dim));
     return success();
@@ -919,17 +914,17 @@ class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
-                                             adaptor.getTensor());
+    Location loc = op.getLoc();
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     uint64_t dim = op.getDimension().getZExtValue();
-    Value field = desc.getIdxMemRef(dim);
+    Value field = desc.getIdxMemRefOrView(rewriter, loc, dim);
 
     // Insert a cast to bridge the actual type to the user expected type. If the
     // actual type and the user expected type aren't compatible, the compiler or
     // the runtime will issue an error.
     Type resType = op.getResult().getType();
     if (resType != field.getType())
-      field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
+      field = rewriter.create<memref::CastOp>(loc, resType, field);
     rewriter.replaceOp(op, field);
 
     return success();
@@ -967,8 +962,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
-                                             adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getValMemRef());
     return success();
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index b0eb72e6fd668..e24a38d3947db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -109,41 +109,24 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
 // SparseTensorDescriptor methods.
 //===----------------------------------------------------------------------===//
 
-sparse_tensor::SparseTensorDescriptor::SparseTensorDescriptor(
-    OpBuilder &builder, Location loc, Type tp, ValueArrayRef buffers)
-    : SparseTensorDescriptorImpl<false>(tp), expandedFields() {
-  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
-  unsigned rank = enc.getDimLevelType().size();
+Value sparse_tensor::SparseTensorDescriptor::getIdxMemRefOrView(
+    OpBuilder &builder, Location loc, unsigned idxDim) const {
+  auto enc = getSparseTensorEncoding(rType);
   unsigned cooStart = getCOOStart(enc);
-  if (cooStart < rank) {
-    ValueRange beforeFields = buffers.drop_back(3);
-    expandedFields.append(beforeFields.begin(), beforeFields.end());
-    Value buffer = buffers[buffers.size() - 3];
-
+  unsigned idx = idxDim >= cooStart ? cooStart : idxDim;
+  Value buffer = getMemRefField(SparseTensorFieldKind::IdxMemRef, idx);
+  if (idxDim >= cooStart) {
+    unsigned rank = enc.getDimLevelType().size();
     Value stride = constantIndex(builder, loc, rank - cooStart);
-    SmallVector<Value> buffersArray(buffers.begin(), buffers.end());
-    MutSparseTensorDescriptor mutDesc(tp, buffersArray);
-    // Calculate subbuffer size as memSizes[idx] / (stride).
-    Value subBufferSize = mutDesc.getIdxMemSize(builder, loc, cooStart);
-    subBufferSize = builder.create<arith::DivUIOp>(loc, subBufferSize, stride);
-
-    // Create views of the linear idx buffer for the COO indices.
-    for (unsigned i = cooStart; i < rank; i++) {
-      Value subBuffer = builder.create<memref::SubViewOp>(
-          loc, buffer,
-          /*offset=*/ValueRange{constantIndex(builder, loc, i - cooStart)},
-          /*size=*/ValueRange{subBufferSize},
-          /*step=*/ValueRange{stride});
-      expandedFields.push_back(subBuffer);
-    }
-    expandedFields.push_back(buffers[buffers.size() - 2]); // The Values memref.
-    expandedFields.push_back(buffers.back());              // The specifier.
-    fields = expandedFields;
-  } else {
-    fields = buffers;
+    Value size = getIdxMemSize(builder, loc, cooStart);
+    size = builder.create<arith::DivUIOp>(loc, size, stride);
+    buffer = builder.create<memref::SubViewOp>(
+        loc, buffer,
+        /*offset=*/ValueRange{constantIndex(builder, loc, idxDim - cooStart)},
+        /*size=*/ValueRange{size},
+        /*step=*/ValueRange{stride});
   }
-
-  sanityCheck();
+  return buffer;
 }
 
 //===----------------------------------------------------------------------===//
@@ -156,8 +139,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
     const SparseTensorEncodingAttr enc,
     llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
                             DimLevelType)>
-        callback,
-    bool isBuffer) {
+        callback) {
   assert(enc);
 
 #define RETURN_ON_FALSE(idx, kind, dim, dlt)                                   \
@@ -165,11 +147,13 @@ void sparse_tensor::foreachFieldInSparseTensor(
     return;
 
   unsigned rank = enc.getDimLevelType().size();
-  unsigned cooStart = isBuffer ? getCOOStart(enc) : rank;
+  unsigned end = getCOOStart(enc);
+  if (end != rank)
+    end += 1;
   static_assert(kDataFieldStartingIdx == 0);
   unsigned fieldIdx = kDataFieldStartingIdx;
   // Per-dimension storage.
-  for (unsigned r = 0; r < rank; r++) {
+  for (unsigned r = 0; r < end; r++) {
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order.
@@ -178,8 +162,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
     } else if (isSingletonDLT(dlt)) {
-      if (r < cooStart)
-        RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
     } else {
       assert(isDenseDLT(dlt)); // no fields
     }
@@ -231,38 +214,32 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
           return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
         };
         llvm_unreachable("unrecognized field kind");
-      },
-      /*isBuffer=*/true);
+      });
 }
 
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc,
-                                                 bool isBuffer) {
+unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
   unsigned numFields = 0;
-  foreachFieldInSparseTensor(
-      enc,
-      [&numFields](unsigned, SparseTensorFieldKind, unsigned,
-                   DimLevelType) -> bool {
-        numFields++;
-        return true;
-      },
-      isBuffer);
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               numFields++;
+                               return true;
+                             });
   return numFields;
 }
 
 unsigned
 sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
   unsigned numFields = 0; // one value memref
-  foreachFieldInSparseTensor(
-      enc,
-      [&numFields](unsigned fidx, SparseTensorFieldKind, unsigned,
-                   DimLevelType) -> bool {
-        if (fidx >= kDataFieldStartingIdx)
-          numFields++;
-        return true;
-      },
-      /*isBuffer=*/true);
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned fidx, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               if (fidx >= kDataFieldStartingIdx)
+                                 numFields++;
+                               return true;
+                             });
   numFields -= 1; // the last field is MetaData field
-  assert(numFields == getNumFieldsFromEncoding(enc, /*isBuffer=*/true) -
-                          kDataFieldStartingIdx - 1);
+  assert(numFields ==
+         getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
   return numFields;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 8d25ba2160e44..9ca7149056ddd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -77,8 +77,7 @@ void foreachFieldInSparseTensor(
     llvm::function_ref<bool(unsigned /*fieldIdx*/,
                             SparseTensorFieldKind /*fieldKind*/,
                             unsigned /*dim (if applicable)*/,
-                            DimLevelType /*DLT (if applicable)*/)>,
-    bool isBuffer = false);
+                            DimLevelType /*DLT (if applicable)*/)>);
 
 /// Same as above, except that it also builds the Type for the corresponding
 /// field.
@@ -90,7 +89,7 @@ void foreachFieldAndTypeInSparseTensor(
                             DimLevelType /*DLT (if applicable)*/)>);
 
 /// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc, bool isBuffer);
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
 
 /// Gets the total number of data fields (index arrays, pointer arrays, and a
 /// value array) for the given sparse tensor encoding.
@@ -107,12 +106,7 @@ inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
 }
 
 /// Provides methods to access fields of a sparse tensor with the given
-/// encoding. When isBuffer is true, the fields are the actual buffers of the
-/// sparse tensor storage. In particular, when a linear buffer is used to
-/// store the COO data as an array-of-structures, the fields include the
-/// linear buffer (isBuffer=true) or includes the subviews of the buffer for the
-/// indices (isBuffer=false).
-template <bool isBuffer>
+/// encoding.
 class StorageLayout {
 public:
   explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
@@ -132,7 +126,7 @@ class StorageLayout {
   }
 
   static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-    return sparse_tensor::getNumFieldsFromEncoding(enc, isBuffer);
+    return sparse_tensor::getNumFieldsFromEncoding(enc);
   }
 
   static void foreachFieldInSparseTensor(
@@ -140,7 +134,7 @@ class StorageLayout {
       llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
                               DimLevelType)>
           callback) {
-    return sparse_tensor::foreachFieldInSparseTensor(enc, callback, isBuffer);
+    return sparse_tensor::foreachFieldInSparseTensor(enc, callback);
   }
 
   std::pair<unsigned, unsigned>
@@ -148,7 +142,7 @@ class StorageLayout {
                          std::optional<unsigned> dim) const {
     unsigned fieldIdx = -1u;
     unsigned stride = 1;
-    if (isBuffer && kind == SparseTensorFieldKind::IdxMemRef) {
+    if (kind == SparseTensorFieldKind::IdxMemRef) {
       assert(dim.has_value());
       unsigned cooStart = getCOOStart(enc);
       unsigned rank = enc.getDimLevelType().size();
@@ -222,18 +216,11 @@ class SparseTensorDescriptorImpl {
   using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
                                                   ValueRange>::type;
 
-  SparseTensorDescriptorImpl(Type tp)
-      : rType(tp.cast<RankedTensorType>()), fields() {}
-
   SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
       : rType(tp.cast<RankedTensorType>()), fields(fields) {
-    sanityCheck();
-  }
-
-  void sanityCheck() {
-    assert(getSparseTensorEncoding(rType) &&
-           StorageLayout<mut>::getNumFieldsFromEncoding(
-               getSparseTensorEncoding(rType)) == fields.size());
+    assert(getSparseTensorEncoding(tp) &&
+           getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+               fields.size());
     // We should make sure the class is trivially copyable (and should be small
     // enough) such that we can pass it by value.
     static_assert(
@@ -244,22 +231,10 @@ class SparseTensorDescriptorImpl {
   unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
                                std::optional<unsigned> dim) const {
     // Delegates to storage layout.
-    StorageLayout<mut> layout(getSparseTensorEncoding(rType));
+    StorageLayout layout(getSparseTensorEncoding(rType));
     return layout.getMemRefFieldIndex(kind, dim);
   }
 
-  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
-    return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
-  }
-
-  unsigned getIdxMemRefIndex(unsigned idxDim) const {
-    return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
-  }
-
-  unsigned getValMemRefIndex() const {
-    return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
-  }
-
   unsigned getNumFields() const { return fields.size(); }
 
   ///
@@ -281,10 +256,6 @@ class SparseTensorDescriptorImpl {
     return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
   }
 
-  Value getIdxMemRef(unsigned idxDim) const {
-    return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
-  }
-
   Value getValMemRef() const {
     return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
   }
@@ -299,15 +270,19 @@ class SparseTensorDescriptorImpl {
     return fields[fidx];
   }
 
-  Value getField(unsigned fidx) const {
-    assert(fidx < fields.size());
-    return fields[fidx];
+  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+                             dim);
   }
 
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // Drop the last metadata fields.
-    return ret.slice(0, fields.size() - 1);
+  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+                             dim);
+  }
+
+  Value getValMemSize(OpBuilder &builder, Location loc) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+                             std::nullopt);
   }
 
   Type getMemRefElementType(SparseTensorFieldKind kind,
@@ -331,23 +306,15 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
   MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
       : SparseTensorDescriptorImpl<true>(tp, buffers) {}
 
-  ///
-  /// Getters: get the value for required field.
-  ///
-
-  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
-                             dim);
-  }
-
-  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
-                             dim);
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
   }
 
-  Value getValMemSize(OpBuilder &builder, Location loc) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
-                             std::nullopt);
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
   }
 
   ///
@@ -384,7 +351,7 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
 
   std::pair<unsigned, unsigned>
   getIdxMemRefIndexAndStride(unsigned idxDim) const {
-    StorageLayout<true> layout(getSparseTensorEncoding(rType));
+    StorageLayout layout(getSparseTensorEncoding(rType));
     return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
                                          idxDim);
   }
@@ -393,19 +360,17 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
     auto enc = getSparseTensorEncoding(rType);
     unsigned cooStart = getCOOStart(enc);
     assert(cooStart < enc.getDimLevelType().size());
-    return getIdxMemRef(cooStart);
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
   }
 };
 
 class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
 public:
-  SparseTensorDescriptor(OpBuilder &builder, Location loc, Type tp,
-                         ValueArrayRef buffers);
+  SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+      : SparseTensorDescriptorImpl<false>(tp, buffers) {}
 
-private:
-  // Store the fields passed to SparseTensorDescriptorImpl when the tensor has
-  // a COO region.
-  SmallVector<Value> expandedFields;
+  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+                           unsigned idxDim) const;
 };
 
 /// Returns the "tuple" value of the adapted tensor.
@@ -425,11 +390,9 @@ inline Value genTuple(OpBuilder &builder, Location loc,
   return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
 }
 
-inline SparseTensorDescriptor
-getDescriptorFromTensorTuple(OpBuilder &builder, Location loc, Value tensor) {
+inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
   auto tuple = getTuple(tensor);
-  return SparseTensorDescriptor(builder, loc, tuple.getResultTypes()[0],
-                                tuple.getInputs());
+  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
 }
 
 inline MutSparseTensorDescriptor


        

