[Mlir-commits] [mlir] 22212ca - [mlir][sparse] simplify some header code (#70989)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 09:31:15 PDT 2023


Author: Aart Bik
Date: 2023-11-02T09:31:11-07:00
New Revision: 22212ca745cfaa6e8e55d808fb83c6dd94791f74

URL: https://github.com/llvm/llvm-project/commit/22212ca745cfaa6e8e55d808fb83c6dd94791f74
DIFF: https://github.com/llvm/llvm-project/commit/22212ca745cfaa6e8e55d808fb83c6dd94791f74.diff

LOG: [mlir][sparse] simplify some header code (#70989)

This is the first revision in a small series of changes that removes
duplication between direct encoding methods and sparse tensor type
wrapper methods (in favor of the latter abstraction, since it provides
more safety). The goal is simply to end up with "just" SparseTensorType.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 9776361da480920..94e7d12b9ee915f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -22,39 +22,24 @@
 
 //===----------------------------------------------------------------------===//
 //
-// Type aliases to help code be more self-documenting.  Unfortunately
+// Type aliases to help code be more self-documenting. Unfortunately
 // these are not type-checked, so they only provide documentation rather
 // than doing anything to prevent mixups.
 //
-// We must include these here (rather than in "SparseTensorType.h")
-// because they are used by methods declared in the tablegen files.
-//
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
 namespace sparse_tensor {
 
-/// The type of dimension identifiers, and dimension-ranks.  We use the
-/// same type for both identifiers and ranks because the latter are used
-/// mainly for ordering-comparisons against the former (just like how the
-/// one-past-the-end iterators are used).
+/// The type of dimension identifiers and dimension-ranks.
 using Dimension = uint64_t;
 
-/// The type of level identifiers, and level-ranks.  We use the same
-/// type for both identifiers and ranks because the latter are used
-/// mainly for ordering-comparisons against the former (just like how
-/// the one-past-the-end iterators are used).
+/// The type of level identifiers and level-ranks.
 using Level = uint64_t;
 
-/// The type for individual components of a compile-time shape.  We avoid
-/// calling this "size" because we use the term "sizes" to indicate the
-/// actual run-time sizes, whereas this type also allows the value
-/// `ShapedType::kDynamic`.
-using DynSize = int64_t;
-
-/// The type for individual components of a compile-time shape which
-/// are known not to be `ShapedType::kDynamic`.
-using StaticSize = int64_t;
+/// The type for individual components of a compile-time shape,
+/// including the value `ShapedType::kDynamic` (for shapes).
+using Size = int64_t;
 
 } // namespace sparse_tensor
 } // namespace mlir
@@ -63,9 +48,6 @@ using StaticSize = int64_t;
 // TableGen-defined classes
 //===----------------------------------------------------------------------===//
 
-// We must include Enums.h.inc before AttrDefs.h.inc due to dependency between
-// StorageSpecifierKindAttr and StorageSpeciferKind Enum.
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc"
 
@@ -87,11 +69,6 @@ using StaticSize = int64_t;
 namespace mlir {
 namespace sparse_tensor {
 
-// NOTE: `Value::getType` doesn't check for null before trying to
-// dereference things.  Therefore we check, because an assertion-failure
-// is easier to debug than a segfault.  Presumably other `T::getType`
-// methods are similarly susceptible.
-
 /// Convenience method to abbreviate casting `getType()`.
 template <typename T>
 inline RankedTensorType getRankedTensorType(T &&t) {
@@ -192,33 +169,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-// This CPP guard is to disable deprecation warnings for the LLVM
-// build-bot, while making it easy to re-enable it for local development.
-#if 0
-#define DEPRECATED                                                             \
-  LLVM_DEPRECATED("The toOrigDim/toStoredDim functions are deprecated "        \
-                  "because they only work for permutations; therefore any "    \
-                  "code using them cannot support non-permutations.",          \
-                  "")
-#else
-#define DEPRECATED
-#endif
-
 /// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension.  Requires: `0 <= l < lvlRank`.
-DEPRECATED Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-DEPRECATED Dimension toOrigDim(RankedTensorType type, Level l);
+/// corresponding dimension. Requires: `0 <= l < lvlRank`.
+Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
+Dimension toOrigDim(RankedTensorType type, Level l);
 
 /// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level.  Requires: `0 <= d < dimRank`.
-DEPRECATED Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
-DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d);
-
-#undef DEPRECATED
-
-namespace detail {
-Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth);
-} // namespace detail
+/// the corresponding level. Requires: `0 <= d < dimRank`.
+Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+Level toStoredDim(RankedTensorType type, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 7e2ad11752b34d8..3c73b19319e588c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -403,20 +403,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// always have the identity mapping).
     bool isPermutation() const;
 
-    //
-    // posWidth/crdWidth methods.
-    //
-
-    /// Returns the type for position storage based on posWidth.
-    /// Asserts that the encoding is non-null (since there's nowhere
-    /// to get the `MLIRContext` from).
-    Type getPosType() const;
-
-    /// Returns the type for coordinate storage based on crdWidth.
-    /// Asserts that the encoding is non-null (since there's nowhere
-    /// to get the `MLIRContext` from).
-    Type getCrdType() const;
-
     //
     // dimSlices methods.
     //
@@ -571,5 +557,4 @@ def SparseTensorCrdTransDirectionAttr
                "CrdTransDirection"> {
 }
 
-
 #endif // SPARSETENSOR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 3e9cada83c6d50b..e808057cf6b0a67 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -60,10 +60,7 @@ class SparseTensorType {
       : SparseTensorType(
             RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
 
-  // Copy-assignment would be implicitly deleted (because our fields
-  // are const), so we explicitly delete it for clarity.
   SparseTensorType &operator=(const SparseTensorType &) = delete;
-  // So we must explicitly define the copy-ctor to silence -Wdeprecated-copy.
   SparseTensorType(const SparseTensorType &) = default;
 
   //
@@ -243,10 +240,10 @@ class SparseTensorType {
   Level getLvlRank() const { return lvlRank; }
 
   /// Returns the dimension-shape.
-  ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }
+  ArrayRef<Size> getDimShape() const { return rtp.getShape(); }
 
   /// Returns the Level-shape.
-  SmallVector<DynSize> getLvlShape() const {
+  SmallVector<Size> getLvlShape() const {
     return getEncoding().tranlateShape(getDimShape(),
                                        CrdTransDirectionKind::dim2lvl);
   }
@@ -260,19 +257,11 @@ class SparseTensorType {
   /// Safely looks up the requested dimension-DynSize.  If you intend
   /// to check the result with `ShapedType::isDynamic`, then see the
   /// `getStaticDimSize` method instead.
-  DynSize getDynamicDimSize(Dimension d) const {
+  Size getDynamicDimSize(Dimension d) const {
     assert(d < getDimRank() && "Dimension is out of bounds");
     return getDimShape()[d];
   }
 
-  /// Safely looks up the requested dimension-size, mapping dynamic
-  /// sizes to `std::nullopt`.
-  std::optional<StaticSize> getStaticDimSize(Dimension d) const {
-    const DynSize sh = getDynamicDimSize(d);
-    return ShapedType::isDynamic(sh) ? std::nullopt
-                                     : std::optional<StaticSize>(sh);
-  }
-
   /// Returns true if no dimension has dynamic size.
   bool hasStaticDimShape() const { return rtp.hasStaticShape(); }
 
@@ -318,12 +307,16 @@ class SparseTensorType {
 
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
   Type getCrdType() const {
-    return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
+    if (getCrdWidth())
+      return IntegerType::get(getContext(), getCrdWidth());
+    return IndexType::get(getContext());
   }
 
   /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
   Type getPosType() const {
-    return detail::getIntegerOrIndexType(getContext(), getPosWidth());
+    if (getPosWidth())
+      return IntegerType::get(getContext(), getPosWidth());
+    return IndexType::get(getContext());
   }
 
 private:
@@ -336,14 +329,13 @@ class SparseTensorType {
   const AffineMap lvlToDim;
 };
 
-/// Convenience methods to abbreviate wrapping `getRankedTensorType`.
-template <typename T>
-inline SparseTensorType getSparseTensorType(T t) {
-  return SparseTensorType(getRankedTensorType(t));
+/// Convenience methods to obtain a SparseTensorType from a Value.
+inline SparseTensorType getSparseTensorType(Value val) {
+  return SparseTensorType(cast<RankedTensorType>(val.getType()));
 }
-inline std::optional<SparseTensorType> tryGetSparseTensorType(Value v) {
-  if (isa<RankedTensorType>(v.getType()))
-    return getSparseTensorType(v);
+inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
+  if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
+    return SparseTensorType(rtp);
   return std::nullopt;
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9a6d3161be3d6e4..6080317d07a64e0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -270,23 +270,6 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
-                                                        unsigned bitwidth) {
-  if (bitwidth)
-    return IntegerType::get(ctx, bitwidth);
-  return IndexType::get(ctx);
-}
-
-Type SparseTensorEncodingAttr::getPosType() const {
-  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return detail::getIntegerOrIndexType(getContext(), getPosWidth());
-}
-
-Type SparseTensorEncodingAttr::getCrdType() const {
-  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
-}
-
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
@@ -722,7 +705,7 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
-    ArrayRef<DynSize> dimShape, Type elementType,
+    ArrayRef<Size> dimShape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
   // Check structural integrity.  In particular, this ensures that the
   // level-rank is coherent across all the fields.
@@ -1312,7 +1295,7 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
 
   // TODO: we can remove this after SparseTensorEncoding always returns non-null
   // dimToLvl map.
-  ArrayRef<DynSize> shape = stt.getDimShape();
+  ArrayRef<Size> shape = stt.getDimShape();
   if (stt.isPermutation()) {
     Dimension dim = toOrigDim(stt, lvl);
     if (!ShapedType::isDynamic(shape[dim])) {
@@ -1378,8 +1361,8 @@ LogicalResult ReinterpretMapOp::verify() {
   if (srcStt.getElementType() != dstStt.getElementType())
     return emitError("Element type mismatch between source/dest tensors");
 
-  SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
-  SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
     if (srcLvlSz != dstLvlSz) {
       // Should we allow one side to be dynamic size, e.g., <?x?> should be
@@ -1616,13 +1599,13 @@ LogicalResult ConcatenateOp::verify() {
   }
 
   for (Dimension d = 0; d < dimRank; d++) {
-    const DynSize dstSh = dstTp.getDimShape()[d];
+    const Size dstSh = dstTp.getDimShape()[d];
     if (d == concatDim) {
       if (!ShapedType::isDynamic(dstSh)) {
         // If we reach here, then all inputs have static shapes.  So we
         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
         // to avoid redundant assertions in the loop.
-        StaticSize sumSz = 0;
+        Size sumSz = 0;
         for (const auto src : getInputs())
           sumSz += getSparseTensorType(src).getDimShape()[d];
         // If all dimension are statically known, the sum of all the input
@@ -1633,7 +1616,7 @@ LogicalResult ConcatenateOp::verify() {
               "sum of all the concatenation dimensions of the input tensors.");
       }
     } else {
-      DynSize prev = dstSh;
+      Size prev = dstSh;
       for (const auto src : getInputs()) {
         const auto sh = getSparseTensorType(src).getDimShape()[d];
         if (!ShapedType::isDynamic(prev) && sh != prev)
@@ -1808,8 +1791,8 @@ LogicalResult SortOp::verify() {
   // FIXME: update the types of variables used in expressions bassed as
   // the `minSize` argument, to avoid implicit casting at the callsites
   // of this lambda.
-  const auto checkDim = [&](Value v, StaticSize minSize, const char *message) {
-    const DynSize sh = getMemRefType(v).getShape()[0];
+  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+    const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 743476259619996..f6fb59fa2c3b84b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -51,8 +51,6 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
   llvm_unreachable("Unknown overhead type");
 }
 
-// TODO: should offer an overload of this that takes a `MLIRContext*`
-// instead of the builder, similar to `detail::getIntegerOrIndexType`.
 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
   switch (ot) {
   case OverheadType::kIndex:
@@ -209,7 +207,7 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
 
 void mlir::sparse_tensor::genReshapeDstShape(
     OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
-    ArrayRef<Value> srcShape, ArrayRef<StaticSize> staticDstShape,
+    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
     ArrayRef<ReassociationIndices> reassociation) {
   // Collapse shape.
   if (reassociation.size() < srcShape.size()) {
@@ -242,7 +240,7 @@ void mlir::sparse_tensor::genReshapeDstShape(
       if (staticDstShape[j] == ShapedType::kDynamic) {
         // The expanded dimension has dynamic size. We compute the dimension
         // by dividing srcDim by the product of the static dimensions.
-        StaticSize product = 1;
+        Size product = 1;
         for (unsigned k = start; k < start + map.size(); k++) {
           if (staticDstShape[k] != ShapedType::kDynamic) {
             product *= staticDstShape[k];
@@ -423,7 +421,8 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 void sparse_tensor::foreachInSparseConstant(
     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
     function_ref<void(ArrayRef<Value>, Value)> callback) {
-  const Dimension dimRank = getSparseTensorType(attr).getDimRank();
+  const Dimension dimRank =
+      SparseTensorType(getRankedTensorType(attr)).getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();
 
@@ -494,8 +493,8 @@ SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
 #ifndef NDEBUG
   const auto memTp = cast<MemRefType>(mem.getType());
   assert(memTp.getRank() == 1);
-  const DynSize memSh = memTp.getDimSize(0);
-  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(size));
+  const Size memSh = memTp.getDimSize(0);
+  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
   assert(offsetIdx == 0 || offsetIdx < size);
 #endif // NDEBUG
   SmallVector<Value> vs;
@@ -516,8 +515,8 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   const size_t vsize = vs.size();
   const auto memTp = cast<MemRefType>(mem.getType());
   assert(memTp.getRank() == 1);
-  const DynSize memSh = memTp.getDimSize(0);
-  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(vsize));
+  const Size memSh = memTp.getDimSize(0);
+  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
   assert(offsetIdx == 0 || offsetIdx < vsize);
 #endif // NDEBUG
   for (const auto &v : llvm::enumerate(vs)) {
@@ -546,11 +545,11 @@ Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
   // The memref ReshapeOp requires the sizes buffer to have a static
   // shape.
   const auto iTp = builder.getIndexType();
-  const SmallVector<DynSize, 1> lvlSizesShape{static_cast<DynSize>(lvlRank)};
+  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
   const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
   lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
   // Finally, create the ReshapeOp.
-  const SmallVector<DynSize> resShape(lvlRank, ShapedType::kDynamic);
+  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
   const Type elemTp = getMemRefType(valuesBuffer).getElementType();
   const auto resTp = MemRefType::get(resShape, elemTp);
   return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
@@ -628,7 +627,7 @@ void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
                                  SmallVectorImpl<Value> &out) {
   out.clear();
   out.reserve(stt.getDimRank());
-  for (const DynSize sh : stt.getDimShape()) {
+  for (const Size sh : stt.getDimShape()) {
     const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
     out.push_back(constantIndex(builder, loc, s));
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 4673d24fc81f3f9..1f53f3525203c70 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -163,8 +163,7 @@ Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
 /// stored into dstShape.
 void genReshapeDstShape(OpBuilder &builder, Location loc,
                         SmallVectorImpl<Value> &dstShape,
-                        ArrayRef<Value> srcShape,
-                        ArrayRef<StaticSize> staticDstShape,
+                        ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
                         ArrayRef<ReassociationIndices> reassociation);
 
 /// Reshape coordinates during a reshaping operation.

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5cdf8cd7ccc9d8b..8c6312150f4c832 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -178,7 +178,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
   SmallVector<Value> dimSizes;
   dimSizes.reserve(dimRank);
   unsigned i = 0; // cumulative index into `dynSizes`.
-  for (const DynSize sh : stt.getDimShape())
+  for (const Size sh : stt.getDimShape())
     dimSizes.push_back(ShapedType::isDynamic(sh)
                            ? dynSizes[i++]
                            : constantIndex(builder, loc, sh));

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 79f6013640440df..e9d4005feaee389 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -96,8 +96,9 @@ static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
   // which is all we care about (for supporting permutations).
   const Dimension dim =
       stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
-  if (const auto sz = stt.getStaticDimSize(dim))
-    return constantIndex(builder, loc, *sz);
+  const Size sz = stt.getDynamicDimSize(dim);
+  if (!ShapedType::isDynamic(sz))
+    return constantIndex(builder, loc, sz);
   // If we cannot statically compute the size from the shape, then we
   // must dynamically query it.  (In principle we could also dynamically
   // compute it, but since we already did so to construct the `tensor`
@@ -112,8 +113,9 @@ static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
 static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                  SparseTensorType stt, Value tensor,
                                  Dimension dim) {
-  if (const auto sz = stt.getStaticDimSize(dim))
-    return constantIndex(builder, loc, *sz);
+  const Size sz = stt.getDynamicDimSize(dim);
+  if (!ShapedType::isDynamic(sz))
+    return constantIndex(builder, loc, sz);
   if (stt.hasEncoding())
     return genDimSizeCall(builder, loc, tensor, dim);
   return linalg::createOrFoldDimOp(builder, loc, tensor, dim);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c00f19916e49fbe..13388dce6bbb5ec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -729,7 +729,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
       for (Dimension d : dstTp.getDimShape())
         dstSizes.push_back(constantIndex(rewriter, loc, d));
     } else {
-      ArrayRef<DynSize> dstShape = dstTp.getDimShape();
+      ArrayRef<Size> dstShape = dstTp.getDimShape();
       genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                          op.getReassociationIndices());
       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
@@ -970,11 +970,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
       // dynamically.
-      const auto sh = getSparseTensorType(input).getStaticDimSize(conDim);
-      assert(sh.has_value());
-      offset = rewriter.create<arith::AddIOp>(
-          loc, offset, constantIndex(rewriter, loc, *sh));
-
+      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
+      assert(!ShapedType::isDynamic(sz));
+      offset = rewriter.create<arith::AddIOp>(loc, offset,
+                                              constantIndex(rewriter, loc, sz));
       iterArg = foreachOp.getResult(0);
       dstBuf.val = iterArg;
     }


        


More information about the Mlir-commits mailing list