[Mlir-commits] [mlir] 7a1077b - [mlir][sparse] Improving SparseTensorDimSliceAttr methods
wren romano
llvmlistbot at llvm.org
Tue May 30 17:31:03 PDT 2023
Author: wren romano
Date: 2023-05-30T17:30:55-07:00
New Revision: 7a1077baa01cd66afa193276796ee6679954d4e5
URL: https://github.com/llvm/llvm-project/commit/7a1077baa01cd66afa193276796ee6679954d4e5
DIFF: https://github.com/llvm/llvm-project/commit/7a1077baa01cd66afa193276796ee6679954d4e5.diff
LOG: [mlir][sparse] Improving SparseTensorDimSliceAttr methods
This patch makes the following changes to `SparseTensorDimSliceAttr` methods:
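* Mark `isDynamic` constexpr (and likewise the file-local `acceptBitWidth`, which moves to a new "Additional convenience methods" section of `SparseTensorDialect.cpp`).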
* Add new helpers `getStatic` and `getStaticString` to avoid repetition.
* Move the definitions of `getStatic{Offset,Stride,Size}` and `isCompletelyDynamic` out of the class declaration, since there is no benefit to inlining them.
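* Change `parse` to initialize offset/size/stride with `kDynamic` rather than the literal `-1`.
* Change `verify` to use the `isDynamic` helper, and to emit a separate diagnostic for each of offset, size, and stride (offset must be non-negative or `?`; size and stride must be positive or `?`).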
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D150919
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 9fe425a40415..d6c971b0cd36 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -76,32 +76,14 @@ def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
let extraClassDeclaration = [{
/// Special value for dynamic offset/size/stride.
static constexpr int64_t kDynamic = -1;
-
- static bool isDynamic(int64_t v) {
- return v == kDynamic;
- }
-
- std::optional<uint64_t> getStaticOffset() const {
- if (isDynamic(getOffset()))
- return std::nullopt;
- return static_cast<uint64_t>(getOffset());
- };
-
- std::optional<uint64_t> getStaticStride() const {
- if (isDynamic(getStride()))
- return std::nullopt;
- return static_cast<uint64_t>(getStride());
- }
-
- std::optional<uint64_t> getStaticSize() const {
- if (isDynamic(getSize()))
- return std::nullopt;
- return static_cast<uint64_t>(getSize());
- }
-
- bool isCompletelyDynamic() const {
- return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
- };
+ static constexpr bool isDynamic(int64_t v) { return v == kDynamic; }
+ static std::optional<uint64_t> getStatic(int64_t v);
+ static std::string getStaticString(int64_t v);
+
+ std::optional<uint64_t> getStaticOffset() const;
+ std::optional<uint64_t> getStaticStride() const;
+ std::optional<uint64_t> getStaticSize() const;
+ bool isCompletelyDynamic() const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7f8dcba77fc8..490e35dfa2d0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -31,6 +31,23 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
+//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+static constexpr bool acceptBitWidth(unsigned bitWidth) {
+ switch (bitWidth) {
+ case 0:
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+}
+
//===----------------------------------------------------------------------===//
// StorageLayout
//===----------------------------------------------------------------------===//
@@ -166,26 +183,39 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//
-static bool acceptBitWidth(unsigned bitWidth) {
- switch (bitWidth) {
- case 0:
- case 8:
- case 16:
- case 32:
- case 64:
- return true;
- default:
- return false;
- }
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
+ return isDynamic(v) ? std::nullopt
+ : std::make_optional(static_cast<uint64_t>(v));
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
+ return getStatic(getOffset());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
+ return getStatic(getStride());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
+ return getStatic(getSize());
+}
+
+bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
+ return isDynamic(getOffset()) && isDynamic(getStride()) &&
+ isDynamic(getSize());
+}
+
+std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
+ return isDynamic(v) ? "?" : std::to_string(v);
}
void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
printer << "(";
- printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+ printer << getStaticString(getOffset());
printer << ", ";
- printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+ printer << getStaticString(getSize());
printer << ", ";
- printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+ printer << getStaticString(getStride());
printer << ")";
}
@@ -208,7 +238,7 @@ static ParseResult parseOptionalStaticSlice(int64_t &result,
}
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
- int64_t offset = -1, size = -1, stride = -1;
+ int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
if (failed(parser.parseLParen()) ||
failed(parseOptionalStaticSlice(offset, parser)) ||
@@ -226,13 +256,13 @@ Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t offset, int64_t size, int64_t stride) {
- if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
- (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
- (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
- return success();
- }
- return emitError()
- << "expect positive value or ? for slice offset/size/stride";
+ if (!isDynamic(offset) && offset < 0)
+ return emitError() << "expect non-negative value or ? for slice offset";
+ if (!isDynamic(size) && size <= 0)
+ return emitError() << "expect positive value or ? for slice size";
+ if (!isDynamic(stride) && stride <= 0)
+ return emitError() << "expect positive value or ? for slice stride";
+ return success();
}
Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
More information about the Mlir-commits
mailing list