[Mlir-commits] [mlir] [mlir][sparse] code cleanup (PR #73047)
Aart Bik
llvmlistbot at llvm.org
Tue Nov 21 14:19:45 PST 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/73047
Removed two unused methods (getStaticDimSliceSize and getStaticLvlSliceSize) and an obsolete FIXME.
>From c824dca95b1a1f05b338101ff0a95a4807eeb7ce Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 21 Nov 2023 13:56:30 -0800
Subject: [PATCH] [mlir][sparse] code cleanup
Removed two unused methods (getStaticDimSliceSize and getStaticLvlSliceSize) and an obsolete FIXME.
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 2 --
.../SparseTensor/IR/SparseTensorDialect.cpp | 32 ++++---------------
2 files changed, 7 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index a254f52aa86e7db..31bb9be50384e3a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -412,10 +412,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
::mlir::sparse_tensor::SparseTensorDimSliceAttr getDimSlice(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const;
- std::optional<uint64_t> getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticDimSliceStride(::mlir::sparse_tensor::Dimension dim) const;
std::optional<uint64_t> getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const;
- std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 28445ce2a8bc243..f49d0893246943a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -368,11 +368,6 @@ SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
return getDimSlice(dim).getStaticOffset();
}
-std::optional<uint64_t>
-SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
- return getDimSlice(dim).getStaticSize();
-}
-
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
return getDimSlice(dim).getStaticStride();
@@ -384,12 +379,6 @@ SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
return getStaticDimSliceOffset(toOrigDim(*this, lvl));
}
-std::optional<uint64_t>
-SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const {
- // FIXME: `toOrigDim` is deprecated.
- return getStaticDimSliceSize(toOrigDim(*this, lvl));
-}
-
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
// FIXME: `toOrigDim` is deprecated.
@@ -1744,33 +1733,26 @@ LogicalResult SortOp::verify() {
if (!xPerm.isPermutation())
emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
- std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
+ std::optional<int64_t> cn = getConstantIntValue(getN());
if (!cn)
return success();
- uint64_t n = cn.value();
- uint64_t ny = 0;
- if (auto nyAttr = getNyAttr()) {
- ny = nyAttr.getInt();
- }
-
- // FIXME: update the types of variables used in expressions bassed as
- // the `minSize` argument, to avoid implicit casting at the callsites
- // of this lambda.
+ // Verify dimensions.
const auto checkDim = [&](Value v, Size minSize, const char *message) {
const Size sh = getMemRefType(v).getShape()[0];
if (!ShapedType::isDynamic(sh) && sh < minSize)
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};
-
+ uint64_t n = cn.value();
+ uint64_t ny = 0;
+ if (auto nyAttr = getNyAttr())
+ ny = nyAttr.getInt();
checkDim(getXy(), n * (nx + ny),
"Expected dimension(xy) >= n * (rank(perm_map) + ny)");
-
- for (Value opnd : getYs()) {
+ for (Value opnd : getYs())
checkDim(opnd, n, "Expected dimension(y) >= n");
- }
return success();
}
More information about the Mlir-commits mailing list