[Mlir-commits] [mlir] 7a1077b - [mlir][sparse] Improving SparseTensorDimSliceAttr methods

wren romano llvmlistbot at llvm.org
Tue May 30 17:31:03 PDT 2023


Author: wren romano
Date: 2023-05-30T17:30:55-07:00
New Revision: 7a1077baa01cd66afa193276796ee6679954d4e5

URL: https://github.com/llvm/llvm-project/commit/7a1077baa01cd66afa193276796ee6679954d4e5
DIFF: https://github.com/llvm/llvm-project/commit/7a1077baa01cd66afa193276796ee6679954d4e5.diff

LOG: [mlir][sparse] Improving SparseTensorDimSliceAttr methods

This patch makes the following changes to `SparseTensorDimSliceAttr` methods:
* Marked `isDynamic` as `constexpr`.
* Added new helpers `getStatic` and `getStaticString` to avoid repetition.
* Moved the definitions for `getStatic{Offset,Stride,Size}` and `isCompletelyDynamic` out of the class declaration, because there is no benefit to inlining them.
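* Changed `parse` to use `kDynamic` rather than the literal `-1`.
* Changed `verify` to use the `isDynamic` helper, and to emit a separate diagnostic for each of offset, size, and stride (the offset message now says "non-negative", matching the `offset >= 0` check).
* Moved `acceptBitWidth` into a new "Additional convenience methods" section near the top of `SparseTensorDialect.cpp` and marked it `constexpr`.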

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150919

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 9fe425a40415..d6c971b0cd36 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -76,32 +76,14 @@ def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
   let extraClassDeclaration = [{
     /// Special value for dynamic offset/size/stride.
     static constexpr int64_t kDynamic = -1;
-
-    static bool isDynamic(int64_t v) {
-      return v == kDynamic;
-    }
-
-    std::optional<uint64_t> getStaticOffset() const {
-      if (isDynamic(getOffset()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getOffset());
-    };
-
-    std::optional<uint64_t> getStaticStride() const {
-      if (isDynamic(getStride()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getStride());
-    }
-
-    std::optional<uint64_t> getStaticSize() const {
-      if (isDynamic(getSize()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getSize());
-    }
-
-    bool isCompletelyDynamic() const {
-      return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
-    };
+    static constexpr bool isDynamic(int64_t v) { return v == kDynamic; }
+    static std::optional<uint64_t> getStatic(int64_t v);
+    static std::string getStaticString(int64_t v);
+
+    std::optional<uint64_t> getStaticOffset() const;
+    std::optional<uint64_t> getStaticStride() const;
+    std::optional<uint64_t> getStaticSize() const;
+    bool isCompletelyDynamic() const;
   }];
 
   let genVerifyDecl = 1;

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7f8dcba77fc8..490e35dfa2d0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -31,6 +31,23 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+static constexpr bool acceptBitWidth(unsigned bitWidth) {
+  switch (bitWidth) {
+  case 0:
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  default:
+    return false;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // StorageLayout
 //===----------------------------------------------------------------------===//
@@ -166,26 +183,39 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
 // TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
 
-static bool acceptBitWidth(unsigned bitWidth) {
-  switch (bitWidth) {
-  case 0:
-  case 8:
-  case 16:
-  case 32:
-  case 64:
-    return true;
-  default:
-    return false;
-  }
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
+  return isDynamic(v) ? std::nullopt
+                      : std::make_optional(static_cast<uint64_t>(v));
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
+  return getStatic(getOffset());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
+  return getStatic(getStride());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
+  return getStatic(getSize());
+}
+
+bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
+  return isDynamic(getOffset()) && isDynamic(getStride()) &&
+         isDynamic(getSize());
+}
+
+std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
+  return isDynamic(v) ? "?" : std::to_string(v);
 }
 
 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
   printer << "(";
-  printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+  printer << getStaticString(getOffset());
   printer << ", ";
-  printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+  printer << getStaticString(getSize());
   printer << ", ";
-  printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+  printer << getStaticString(getStride());
   printer << ")";
 }
 
@@ -208,7 +238,7 @@ static ParseResult parseOptionalStaticSlice(int64_t &result,
 }
 
 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
-  int64_t offset = -1, size = -1, stride = -1;
+  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
 
   if (failed(parser.parseLParen()) ||
       failed(parseOptionalStaticSlice(offset, parser)) ||
@@ -226,13 +256,13 @@ Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
 LogicalResult
 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                  int64_t offset, int64_t size, int64_t stride) {
-  if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
-      (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
-      (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
-    return success();
-  }
-  return emitError()
-         << "expect positive value or ? for slice offset/size/stride";
+  if (!isDynamic(offset) && offset < 0)
+    return emitError() << "expect non-negative value or ? for slice offset";
+  if (!isDynamic(size) && size <= 0)
+    return emitError() << "expect positive value or ? for slice size";
+  if (!isDynamic(stride) && stride <= 0)
+    return emitError() << "expect positive value or ? for slice stride";
+  return success();
 }
 
 Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,


        


More information about the Mlir-commits mailing list