[Mlir-commits] [mlir] [mlir][sparse] code cleanup, remove FIXMEs (PR #73575)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 14:25:53 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
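<summary>Changes</summary>

Renames the deprecated `toOrigDim`/`toStoredDim` helpers to `toDim`/`toLvl` (both now documented as requiring a permutation dim2lvl map), drops the corresponding FIXME comments at their call sites, and removes the `reshapeValuesToLevels` utility from CodegenUtils.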


---
Full diff: https://github.com/llvm/llvm-project/pull/73575.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+9-7) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+11-25) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (-26) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h (-7) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+2-4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+3-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf8c..f102f0270154264 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 791aeebee5a328d..28e07e1669e7919 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+  return getStaticDimSliceOffset(toDim(*this, lvl));
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+  return getStaticDimSliceStride(toDim(*this, lvl));
 }
 
 SmallVector<int64_t>
@@ -399,9 +397,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
   if (isPermutation()) {
     for (unsigned r = 0; r < rank; r++) {
       // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
-      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
-                           ? toOrigDim(*this, r)
-                           : toStoredDim(*this, r);
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+                                                             : toLvl(*this, r);
       ret.push_back(srcShape[trans]);
     }
     return ret;
@@ -925,31 +922,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
       ordered);
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
-                                         Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
+    assert(enc.isPermutation() && "Non permutation map");
+    if (const auto dimToLvl = enc.getDimToLvl())
       return dimToLvl.getDimPosition(l);
-    }
   }
   return l;
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
-                                       Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
-      auto maybePos =
-          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
-      assert(maybePos.has_value());
-      return *maybePos;
-    }
+    assert(enc.isPermutation() && "");
+    if (const auto lvlToDim = enc.getLvlToDim())
+      return lvlToDim.getDimPosition(d);
   }
   return d;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90ff..33d449aac5a3550 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   }
 }
 
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                                           SparseTensorEncodingAttr enc,
-                                           ValueRange dimSizes,
-                                           Value valuesBuffer,
-                                           Value lvlCoords) {
-  // Reuse the `lvlCoords` buffer to store the level-sizes.
-  const Level lvlRank = enc.getLvlRank();
-  SmallVector<Value> lvlSizes;
-  lvlSizes.reserve(lvlRank);
-  for (Level l = 0; l < lvlRank; l++)
-    // FIXME: `toOrigDim` is deprecated.
-    lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
-  storeAll(builder, loc, lvlCoords, lvlSizes);
-  // The memref ReshapeOp requires the sizes buffer to have a static
-  // shape.
-  const auto iTp = builder.getIndexType();
-  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
-  const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
-  lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
-  // Finally, create the ReshapeOp.
-  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
-  const Type elemTp = getMemRefType(valuesBuffer).getElementType();
-  const auto resTp = MemRefType::get(resShape, elemTp);
-  return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
 TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index cb0acdd2be9f7b0..0ce33427281f598 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
               size_t offsetIdx = 0, Value offsetVal = Value());
 
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, ValueRange dimSizes,
-                            Value valuesBuffer, Value lvlCoords);
-
 // Generates code to cast a tensor to a memref.
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f8bcc0fe12a1093..413a835ff14d314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 /// Converts a coordinate relative to the slice to the coordinate relative
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5374ab55c5c0d9f..103908b2cf5bd89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
 
@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(dimRank);
           for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
           SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
       return failure();
 
     if (stt.isPermutation()) {
-      // FIXME: `toStoredDim` is deprecated
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toStoredDim(stt.getEncoding(), *dim));
+                                         toLvl(stt.getEncoding(), *dim));
       return success();
     }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/73575


More information about the Mlir-commits mailing list