[Mlir-commits] [mlir] [mlir][sparse] code cleanup, remove FIXMEs (PR #73575)

Peiming Liu llvmlistbot at llvm.org
Mon Nov 27 14:33:58 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/73575

>From c56970af6ebb9f7dcc242166ba3591c26e07f988 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 27 Nov 2023 22:22:47 +0000
Subject: [PATCH 1/2] [mlir][sparse] code cleanup, remove FIXMEs

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    | 16 +++++----
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 36 ++++++-------------
 .../SparseTensor/Transforms/CodegenUtils.cpp  | 26 --------------
 .../SparseTensor/Transforms/CodegenUtils.h    |  7 ----
 .../SparseTensor/Transforms/LoopEmitter.cpp   |  6 ++--
 .../Transforms/SparseTensorRewriting.cpp      |  9 ++---
 6 files changed, 25 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf8c..f102f0270154264 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension; a null `enc` is treated as the identity mapping.
+/// Requires: `enc`'s dimToLvl map is a permutation and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level; a null `enc` is treated as the identity mapping.
+/// Requires: `enc`'s dimToLvl map is a permutation and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 791aeebee5a328d..28e07e1669e7919 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+  return getStaticDimSliceOffset(toDim(*this, lvl));
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+  return getStaticDimSliceStride(toDim(*this, lvl));
 }
 
 SmallVector<int64_t>
@@ -399,9 +397,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
   if (isPermutation()) {
     for (unsigned r = 0; r < rank; r++) {
       // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
-      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
-                           ? toOrigDim(*this, r)
-                           : toStoredDim(*this, r);
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+                                                             : toLvl(*this, r);
       ret.push_back(srcShape[trans]);
     }
     return ret;
@@ -925,31 +922,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
       ordered);
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
-                                         Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
+    assert(enc.isPermutation() && "Non permutation map");
+    if (const auto dimToLvl = enc.getDimToLvl())
       return dimToLvl.getDimPosition(l);
-    }
   }
   return l;
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
-                                       Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
-      auto maybePos =
-          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
-      assert(maybePos.has_value());
-      return *maybePos;
-    }
+    assert(enc.isPermutation() && "");
+    if (const auto lvlToDim = enc.getLvlToDim())
+      return lvlToDim.getDimPosition(d);
   }
   return d;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90ff..33d449aac5a3550 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   }
 }
 
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                                           SparseTensorEncodingAttr enc,
-                                           ValueRange dimSizes,
-                                           Value valuesBuffer,
-                                           Value lvlCoords) {
-  // Reuse the `lvlCoords` buffer to store the level-sizes.
-  const Level lvlRank = enc.getLvlRank();
-  SmallVector<Value> lvlSizes;
-  lvlSizes.reserve(lvlRank);
-  for (Level l = 0; l < lvlRank; l++)
-    // FIXME: `toOrigDim` is deprecated.
-    lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
-  storeAll(builder, loc, lvlCoords, lvlSizes);
-  // The memref ReshapeOp requires the sizes buffer to have a static
-  // shape.
-  const auto iTp = builder.getIndexType();
-  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
-  const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
-  lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
-  // Finally, create the ReshapeOp.
-  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
-  const Type elemTp = getMemRefType(valuesBuffer).getElementType();
-  const auto resTp = MemRefType::get(resShape, elemTp);
-  return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
 TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index cb0acdd2be9f7b0..0ce33427281f598 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
               size_t offsetIdx = 0, Value offsetVal = Value());
 
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, ValueRange dimSizes,
-                            Value valuesBuffer, Value lvlCoords);
-
 // Generates code to cast a tensor to a memref.
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f8bcc0fe12a1093..413a835ff14d314 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 /// Converts a coordinate relative to the slice to the coordinate relative
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5374ab55c5c0d9f..103908b2cf5bd89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
 
@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(dimRank);
           for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            Level lvl = toStoredDim(encSrc, d);
+            Level lvl = toLvl(encSrc, d);
             srcDcvs.push_back(srcLcvs[lvl]);
           }
           SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
       return failure();
 
     if (stt.isPermutation()) {
-      // FIXME: `toStoredDim` is deprecated
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toStoredDim(stt.getEncoding(), *dim));
+                                         toLvl(stt.getEncoding(), *dim));
       return success();
     }
 

>From 80ca0bb13c336c2ac76bd2d3a7236d005a05ffb5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 27 Nov 2023 22:33:43 +0000
Subject: [PATCH 2/2] remove unused variables

---
 .../SparseTensor/Transforms/IterationGraphSorter.h        | 1 -
 .../SparseTensor/Transforms/SparseReinterpretMap.cpp      | 8 ++++----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
index 613a8609ac0973a..52ee11702930096 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
@@ -26,7 +26,6 @@ enum class SortMask : unsigned {
   // The individual mask bits.
   kIncludeDenseOutput = 0x1, // b001
   kIncludeDenseInput = 0x2,  // b010
-  kIncludeUndef = 0x4,       // b100
   // The subsets of mask bits.
   kIncludeAll = 0x7,   // b111
   kIncludeDense = 0x3, // b011
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 268bd8fbe27387f..c94ef8b96287766 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     // computation. Must be ordered from more strict to less strict.
     // Ideally (though might not be guaranteed), the earlier a constraint mask
     // can be satisfied, the faster the generated kernel will be.
-    const auto allMasks = {
-        SortMask::kIncludeAll,        SortMask::kIncludeDense,
-        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
-        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
+    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
+                           SortMask::kIncludeDenseInput,
+                           SortMask::kIncludeDenseOutput,
+                           SortMask::kSparseOnly};
     for (const SortMask mask : allMasks) {
       order = scheduler.sort(mask);
       if (order) {



More information about the Mlir-commits mailing list