[Mlir-commits] [mlir] 4e2f152 - [mlir][sparse] code cleanup, remove FIXMEs (#73575)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 14:57:14 PST 2023


Author: Peiming Liu
Date: 2023-11-27T14:57:08-08:00
New Revision: 4e2f1521ec4b03a5838f9a4b61cd59bb8d60d8e6

URL: https://github.com/llvm/llvm-project/commit/4e2f1521ec4b03a5838f9a4b61cd59bb8d60d8e6
DIFF: https://github.com/llvm/llvm-project/commit/4e2f1521ec4b03a5838f9a4b61cd59bb8d60d8e6.diff

LOG: [mlir][sparse] code cleanup, remove FIXMEs (#73575)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf8c..f102f0270154264 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 686bc021c2667ab..c4a828a20465a80 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+  return getStaticDimSliceOffset(toDim(*this, lvl));
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+  return getStaticDimSliceStride(toDim(*this, lvl));
 }
 
 SmallVector<int64_t>
@@ -398,10 +396,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
 
   if (isPermutation()) {
     for (unsigned r = 0; r < rank; r++) {
-      // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
-      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
-                           ? toOrigDim(*this, r)
-                           : toStoredDim(*this, r);
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+                                                             : toLvl(*this, r);
       ret.push_back(srcShape[trans]);
     }
     return ret;
@@ -922,31 +918,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
       ordered);
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
-                                         Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
+    assert(enc.isPermutation() && "Non permutation map not supported");
+    if (const auto dimToLvl = enc.getDimToLvl())
       return dimToLvl.getDimPosition(l);
-    }
   }
   return l;
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
-                                       Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
-      auto maybePos =
-          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
-      assert(maybePos.has_value());
-      return *maybePos;
-    }
+    assert(enc.isPermutation() && "Non permutation map not supported");
+    if (const auto lvlToDim = enc.getLvlToDim())
+      return lvlToDim.getDimPosition(d);
   }
   return d;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90ff..33d449aac5a3550 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   }
 }
 
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                                           SparseTensorEncodingAttr enc,
-                                           ValueRange dimSizes,
-                                           Value valuesBuffer,
-                                           Value lvlCoords) {
-  // Reuse the `lvlCoords` buffer to store the level-sizes.
-  const Level lvlRank = enc.getLvlRank();
-  SmallVector<Value> lvlSizes;
-  lvlSizes.reserve(lvlRank);
-  for (Level l = 0; l < lvlRank; l++)
-    // FIXME: `toOrigDim` is deprecated.
-    lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
-  storeAll(builder, loc, lvlCoords, lvlSizes);
-  // The memref ReshapeOp requires the sizes buffer to have a static
-  // shape.
-  const auto iTp = builder.getIndexType();
-  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
-  const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
-  lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
-  // Finally, create the ReshapeOp.
-  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
-  const Type elemTp = getMemRefType(valuesBuffer).getElementType();
-  const auto resTp = MemRefType::get(resShape, elemTp);
-  return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
 TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 910b605731e2573..57de437e2311c0b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
               size_t offsetIdx = 0, Value offsetVal = Value());
 
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, ValueRange dimSizes,
-                            Value valuesBuffer, Value lvlCoords);
-
 // Generates code to cast a tensor to a memref.
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
index 613a8609ac0973a..52ee11702930096 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
@@ -26,7 +26,6 @@ enum class SortMask : unsigned {
   // The individual mask bits.
   kIncludeDenseOutput = 0x1, // b001
   kIncludeDenseInput = 0x2,  // b010
-  kIncludeUndef = 0x4,       // b100
   // The subsets of mask bits.
   kIncludeAll = 0x7,   // b111
   kIncludeDense = 0x3, // b011

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index dd8091a4baa9689..69072b91b2fa523 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 /// Converts a coordinate relative to the slice to the coordinate relative

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 268bd8fbe27387f..c94ef8b96287766 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     // computation. Must be ordered from more strict to less strict.
     // Ideally (though might not be guaranteed), the earlier a constraint mask
     // can be satisfied, the faster the generated kernel will be.
-    const auto allMasks = {
-        SortMask::kIncludeAll,        SortMask::kIncludeDense,
-        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
-        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
+    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
+                           SortMask::kIncludeDenseInput,
+                           SortMask::kIncludeDenseOutput,
+                           SortMask::kSparseOnly};
     for (const SortMask mask : allMasks) {
       order = scheduler.sort(mask);
       if (order) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ca3d336fcf1843b..702666e9d40c31f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -660,8 +660,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
 
@@ -765,8 +764,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(dimRank);
           for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
           SmallVector<Value> dstDcvs;
@@ -871,9 +869,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
       return failure();
 
     if (stt.isPermutation()) {
-      // FIXME: `toStoredDim` is deprecated
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toStoredDim(stt.getEncoding(), *dim));
+                                         toLvl(stt.getEncoding(), *dim));
       return success();
     }
 


        


More information about the Mlir-commits mailing list