[Mlir-commits] [mlir] [mlir][sparse] code cleanup (PR #73047)

Aart Bik llvmlistbot at llvm.org
Tue Nov 21 14:19:45 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/73047

Removed two unused methods and an obsolete FIXME.

>From c824dca95b1a1f05b338101ff0a95a4807eeb7ce Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 21 Nov 2023 13:56:30 -0800
Subject: [PATCH] [mlir][sparse] code cleanup

Removed two unused methods and an obsolete FIXME.
---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  2 --
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 32 ++++---------------
 2 files changed, 7 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index a254f52aa86e7db..31bb9be50384e3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -412,10 +412,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     ::mlir::sparse_tensor::SparseTensorDimSliceAttr getDimSlice(::mlir::sparse_tensor::Dimension dim) const;
 
     std::optional<uint64_t> getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const;
-    std::optional<uint64_t> getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const;
     std::optional<uint64_t> getStaticDimSliceStride(::mlir::sparse_tensor::Dimension dim) const;
     std::optional<uint64_t> getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const;
-    std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
 
     //
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 28445ce2a8bc243..f49d0893246943a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -368,11 +368,6 @@ SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
   return getDimSlice(dim).getStaticOffset();
 }
 
-std::optional<uint64_t>
-SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
-  return getDimSlice(dim).getStaticSize();
-}
-
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
   return getDimSlice(dim).getStaticStride();
@@ -384,12 +379,6 @@ SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
   return getStaticDimSliceOffset(toOrigDim(*this, lvl));
 }
 
-std::optional<uint64_t>
-SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceSize(toOrigDim(*this, lvl));
-}
-
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   // FIXME: `toOrigDim` is deprecated.
@@ -1744,33 +1733,26 @@ LogicalResult SortOp::verify() {
   if (!xPerm.isPermutation())
     emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
 
-  std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
+  std::optional<int64_t> cn = getConstantIntValue(getN());
   if (!cn)
     return success();
 
-  uint64_t n = cn.value();
-  uint64_t ny = 0;
-  if (auto nyAttr = getNyAttr()) {
-    ny = nyAttr.getInt();
-  }
-
-  // FIXME: update the types of variables used in expressions bassed as
-  // the `minSize` argument, to avoid implicit casting at the callsites
-  // of this lambda.
+  // Verify dimensions.
   const auto checkDim = [&](Value v, Size minSize, const char *message) {
     const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };
-
+  uint64_t n = cn.value();
+  uint64_t ny = 0;
+  if (auto nyAttr = getNyAttr())
+    ny = nyAttr.getInt();
   checkDim(getXy(), n * (nx + ny),
            "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
-
-  for (Value opnd : getYs()) {
+  for (Value opnd : getYs())
     checkDim(opnd, n, "Expected dimension(y) >= n");
-  }
 
   return success();
 }



More information about the Mlir-commits mailing list