[Mlir-commits] [mlir] 743fbcb - [mlir][sparse] IR/SparseTensorDialect.cpp: misc code cleanup

wren romano llvmlistbot at llvm.org
Fri Jan 20 13:31:46 PST 2023


Author: wren romano
Date: 2023-01-20T13:31:39-08:00
New Revision: 743fbcb79d9af759377df5f5929ffdd38ff52b09

URL: https://github.com/llvm/llvm-project/commit/743fbcb79d9af759377df5f5929ffdd38ff52b09
DIFF: https://github.com/llvm/llvm-project/commit/743fbcb79d9af759377df5f5929ffdd38ff52b09.diff

LOG: [mlir][sparse] IR/SparseTensorDialect.cpp: misc code cleanup

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D142072

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5ea4f7ca63eae..0ad21a1729970 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -101,16 +101,18 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
          << "expect positive value or ? for slice offset/size/stride";
 }
 
+static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) {
+  if (bitwidth)
+    return IntegerType::get(ctx, bitwidth);
+  return IndexType::get(ctx);
+}
+
 Type SparseTensorEncodingAttr::getPointerType() const {
-  unsigned ptrWidth = getPointerBitWidth();
-  Type indexType = IndexType::get(getContext());
-  return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType;
+  return getIntegerOrIndexType(getContext(), getPointerBitWidth());
 }
 
 Type SparseTensorEncodingAttr::getIndexType() const {
-  unsigned idxWidth = getIndexBitWidth();
-  Type indexType = IndexType::get(getContext());
-  return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
+  return getIntegerOrIndexType(getContext(), getIndexBitWidth());
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
@@ -157,11 +159,30 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+const static DimLevelType validDLTs[] = {
+    DimLevelType::Dense,          DimLevelType::Compressed,
+    DimLevelType::CompressedNu,   DimLevelType::CompressedNo,
+    DimLevelType::CompressedNuNo, DimLevelType::Singleton,
+    DimLevelType::SingletonNu,    DimLevelType::SingletonNo,
+    DimLevelType::SingletonNuNo};
+
+static std::optional<DimLevelType> parseDLT(StringRef str) {
+  for (DimLevelType dlt : validDLTs)
+    if (str == toMLIRString(dlt))
+      return dlt;
+  return std::nullopt;
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #define RETURN_ON_FAIL(stmt)                                                   \
   if (failed(stmt)) {                                                          \
     return {};                                                                 \
   }
+#define ERROR_IF(COND, MSG)                                                    \
+  if (COND) {                                                                  \
+    parser.emitError(parser.getNameLoc(), MSG);                                \
+    return {};                                                                 \
+  }
 
   RETURN_ON_FAIL(parser.parseLess())
   RETURN_ON_FAIL(parser.parseLBrace())
@@ -191,37 +212,13 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr));
       auto arrayAttr = attr.dyn_cast<ArrayAttr>();
-      if (!arrayAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an array for dimension level types");
-        return {};
-      }
+      ERROR_IF(!arrayAttr, "expected an array for dimension level types")
       for (auto i : arrayAttr) {
         auto strAttr = i.dyn_cast<StringAttr>();
-        if (!strAttr) {
-          parser.emitError(parser.getNameLoc(),
-                           "expected a string value in dimension level types");
-          return {};
-        }
+        ERROR_IF(!strAttr, "expected a string value in dimension level types")
         auto strVal = strAttr.getValue();
-        if (strVal == "dense") {
-          dlt.push_back(DimLevelType::Dense);
-        } else if (strVal == "compressed") {
-          dlt.push_back(DimLevelType::Compressed);
-        } else if (strVal == "compressed-nu") {
-          dlt.push_back(DimLevelType::CompressedNu);
-        } else if (strVal == "compressed-no") {
-          dlt.push_back(DimLevelType::CompressedNo);
-        } else if (strVal == "compressed-nu-no") {
-          dlt.push_back(DimLevelType::CompressedNuNo);
-        } else if (strVal == "singleton") {
-          dlt.push_back(DimLevelType::Singleton);
-        } else if (strVal == "singleton-nu") {
-          dlt.push_back(DimLevelType::SingletonNu);
-        } else if (strVal == "singleton-no") {
-          dlt.push_back(DimLevelType::SingletonNo);
-        } else if (strVal == "singleton-nu-no") {
-          dlt.push_back(DimLevelType::SingletonNuNo);
+        if (auto optDLT = parseDLT(strVal)) {
+          dlt.push_back(optDLT.value());
         } else {
           parser.emitError(parser.getNameLoc(),
                            "unexpected dimension level type: ")
@@ -232,46 +229,26 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
     } else if (attrName == "dimOrdering") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
-
       auto affineAttr = attr.dyn_cast<AffineMapAttr>();
-      if (!affineAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an affine map for dimension ordering");
-        return {};
-      }
+      ERROR_IF(!affineAttr, "expected an affine map for dimension ordering")
       dimOrd = affineAttr.getValue();
     } else if (attrName == "higherOrdering") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
-
       auto affineAttr = attr.dyn_cast<AffineMapAttr>();
-      if (!affineAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an affine map for higher ordering");
-        return {};
-      }
+      ERROR_IF(!affineAttr, "expected an affine map for higher ordering")
       higherOrd = affineAttr.getValue();
     } else if (attrName == "pointerBitWidth") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
-
       auto intAttr = attr.dyn_cast<IntegerAttr>();
-      if (!intAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an integral pointer bitwidth");
-        return {};
-      }
+      ERROR_IF(!intAttr, "expected an integral pointer bitwidth")
       ptr = intAttr.getInt();
     } else if (attrName == "indexBitWidth") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
-
       auto intAttr = attr.dyn_cast<IntegerAttr>();
-      if (!intAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an integral index bitwidth");
-        return {};
-      }
+      ERROR_IF(!intAttr, "expected an integral index bitwidth")
       ind = intAttr.getInt();
     } else if (attrName == "slice") {
       RETURN_ON_FAIL(parser.parseLSquare())
@@ -298,6 +275,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 
   RETURN_ON_FAIL(parser.parseRBrace())
   RETURN_ON_FAIL(parser.parseGreater())
+#undef ERROR_IF
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
@@ -367,18 +345,21 @@ LogicalResult SparseTensorEncodingAttr::verify(
     return emitError() << "unexpected mismatch in dimension slices and "
                           "dimension level type size";
   }
-
   return success();
 }
 
+#define RETURN_FAILURE_IF_FAILED(X)                                            \
+  if (failed(X)) {                                                             \
+    return failure();                                                          \
+  }
+
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     ArrayRef<int64_t> shape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
   // Check structural integrity.
-  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
-                    getHigherOrdering(), getPointerBitWidth(),
-                    getIndexBitWidth(), getDimSlices())))
-    return failure();
+  RETURN_FAILURE_IF_FAILED(verify(
+      emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(),
+      getPointerBitWidth(), getIndexBitWidth(), getDimSlices()))
   // Check integrity with tensor type specifics. Dimension ordering is optional,
   // but we always should have dimension level types for the full rank.
   unsigned size = shape.size();
@@ -435,23 +416,17 @@ static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) {
 
 bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
   SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
-  if (!enc)
-    return false;
-
-  return isCOOType(enc, 0, /*isUnique=*/true);
+  return enc && isCOOType(enc, 0, /*isUnique=*/true);
 }
 
 unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
-  unsigned rank = enc.getDimLevelType().size();
-  if (rank <= 1)
-    return rank;
-
+  const unsigned rank = enc.getDimLevelType().size();
   // We only consider COO region with at least two dimensions for the purpose
   // of AOS storage optimization.
-  for (unsigned r = 0; r < rank - 1; r++) {
-    if (isCOOType(enc, r, /*isUnique=*/false))
-      return r;
-  }
+  if (rank > 1)
+    for (unsigned r = 0; r < rank - 1; r++)
+      if (isCOOType(enc, r, /*isUnique=*/false))
+        return r;
 
   return rank;
 }
@@ -541,10 +516,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
 
 Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
                                         std::optional<APInt> dim) const {
-  std::optional<unsigned> intDim;
-  if (dim)
-    intDim = dim.value().getZExtValue();
-  return getFieldType(kind, intDim);
+  return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue())
+                                : std::nullopt);
 }
 
 //===----------------------------------------------------------------------===//
@@ -552,17 +525,12 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult isInBounds(uint64_t dim, Value tensor) {
-  uint64_t rank = tensor.getType().cast<RankedTensorType>().getRank();
-  if (dim >= rank)
-    return failure();
-  return success(); // in bounds
+  return success(dim < tensor.getType().cast<RankedTensorType>().getRank());
 }
 
 static LogicalResult isMatchingWidth(Value result, unsigned width) {
-  Type etp = result.getType().cast<MemRefType>().getElementType();
-  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
-    return success();
-  return failure();
+  const Type etp = result.getType().cast<MemRefType>().getElementType();
+  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
 }
 
 static LogicalResult verifySparsifierGetterSetter(
@@ -663,11 +631,8 @@ LogicalResult ToValuesOp::verify() {
 }
 
 LogicalResult GetStorageSpecifierOp::verify() {
-  if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
-                                          getSpecifier(), getOperation()))) {
-    return failure();
-  }
-
+  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
+      getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
   // Checks the result type
   if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
       getResult().getType()) {
@@ -692,11 +657,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult SetStorageSpecifierOp::verify() {
-  if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
-                                          getSpecifier(), getOperation()))) {
-    return failure();
-  }
-
+  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
+      getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
   // Checks the input type
   if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
       getValue().getType()) {
@@ -748,59 +710,45 @@ LogicalResult BinaryOp::verify() {
 
   // Check correct number of block arguments and return type for each
   // non-empty region.
-  LogicalResult regionResult = success();
   if (!overlap.empty()) {
-    regionResult = verifyNumBlockArgs(
-        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
   }
   if (!left.empty()) {
-    regionResult =
-        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(
+        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
   } else if (getLeftIdentity()) {
     if (leftType != outputType)
       return emitError("left=identity requires first argument to have the same "
                        "type as the output");
   }
   if (!right.empty()) {
-    regionResult = verifyNumBlockArgs(this, right, "right",
-                                      TypeRange{rightType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+        this, right, "right", TypeRange{rightType}, outputType))
   } else if (getRightIdentity()) {
     if (rightType != outputType)
       return emitError("right=identity requires second argument to have the "
                        "same type as the output");
   }
-
   return success();
 }
 
 LogicalResult UnaryOp::verify() {
   Type inputType = getX().getType();
   Type outputType = getOutput().getType();
-  LogicalResult regionResult = success();
 
   // Check correct number of block arguments and return type for each
   // non-empty region.
   Region &present = getPresentRegion();
   if (!present.empty()) {
-    regionResult = verifyNumBlockArgs(this, present, "present",
-                                      TypeRange{inputType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+        this, present, "present", TypeRange{inputType}, outputType))
   }
   Region &absent = getAbsentRegion();
   if (!absent.empty()) {
-    regionResult =
-        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(
+        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
   }
-
   return success();
 }
 
@@ -880,8 +828,7 @@ void PushBackOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult PushBackOp::verify() {
-  Value n = getN();
-  if (n) {
+  if (Value n = getN()) {
     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
     if (nValue && nValue.value() < 1)
       return emitOpError("n must be not less than 1");
@@ -972,32 +919,21 @@ LogicalResult ForeachOp::verify() {
 
 LogicalResult ReduceOp::verify() {
   Type inputType = getX().getType();
-  LogicalResult regionResult = success();
-
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  regionResult = verifyNumBlockArgs(this, formula, "reduce",
-                                    TypeRange{inputType, inputType}, inputType);
-  if (failed(regionResult))
-    return regionResult;
-
+  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+      this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
   return success();
 }
 
 LogicalResult SelectOp::verify() {
   Builder b(getContext());
-
   Type inputType = getX().getType();
   Type boolType = b.getI1Type();
-  LogicalResult regionResult = success();
-
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  regionResult = verifyNumBlockArgs(this, formula, "select",
-                                    TypeRange{inputType}, boolType);
-  if (failed(regionResult))
-    return regionResult;
-
+  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
+                                              TypeRange{inputType}, boolType))
   return success();
 }
 
@@ -1025,15 +961,8 @@ LogicalResult SortOp::verify() {
     }
     return success();
   };
-
-  LogicalResult result = checkTypes(getXs());
-  if (failed(result))
-    return result;
-
-  if (n)
-    return checkTypes(getYs(), false);
-
-  return success();
+  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
+  return n ? checkTypes(getYs(), false) : success();
 }
 
 LogicalResult SortCooOp::verify() {
@@ -1084,6 +1013,8 @@ LogicalResult YieldOp::verify() {
                      "reduce, select or foreach");
 }
 
+#undef RETURN_FAILURE_IF_FAILED
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list