[Mlir-commits] [mlir] [mlir][sparse] avoid excessive macro magic (PR #70276)

Aart Bik llvmlistbot at llvm.org
Wed Oct 25 18:26:23 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/70276

The macro shorthands are not always shorter than the code they replace, and the result is less clear than when the checks are simply written out.

From 06b46d432e7595ce332ed781b3b73a9dca1f91a2 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 25 Oct 2023 18:23:48 -0700
Subject: [PATCH] [mlir][sparse] avoid excessive macro magic

The shorthands are not even always shorter and the
code is less clear than when simply written out.
---
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 137 ++++++++----------
 1 file changed, 64 insertions(+), 73 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f05cbd8d16d9a76..359c0a696858329 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -34,11 +34,6 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-#define RETURN_FAILURE_IF_FAILED(X)                                            \
-  if (failed(X)) {                                                             \
-    return failure();                                                          \
-  }
-
 //===----------------------------------------------------------------------===//
 // Local convenience methods.
 //===----------------------------------------------------------------------===//
@@ -68,10 +63,6 @@ void StorageLayout::foreachField(
     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                             DimLevelType)>
         callback) const {
-#define RETURN_ON_FALSE(fidx, kind, lvl, dlt)                                  \
-  if (!(callback(fidx, kind, lvl, dlt)))                                       \
-    return;
-
   const auto lvlTypes = enc.getLvlTypes();
   const Level lvlRank = enc.getLvlRank();
   const Level cooStart = getCOOStart(enc);
@@ -81,21 +72,22 @@ void StorageLayout::foreachField(
   for (Level l = 0; l < end; l++) {
     const auto dlt = lvlTypes[l];
     if (isDLTWithPos(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
+      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
+        return;
     }
     if (isDLTWithCrd(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
+      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
+        return;
     }
   }
   // The values array.
-  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
-                  DimLevelType::Undef);
-
+  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
+                 DimLevelType::Undef)))
+    return;
   // Put metadata at the end.
-  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
-                  DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
+  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
+                 DimLevelType::Undef)))
+    return;
 }
 
 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
@@ -435,18 +427,11 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
 }
 
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
-#define RETURN_ON_FAIL(stmt)                                                   \
-  if (failed(stmt)) {                                                          \
-    return {};                                                                 \
-  }
-#define ERROR_IF(COND, MSG)                                                    \
-  if (COND) {                                                                  \
-    parser.emitError(parser.getNameLoc(), MSG);                                \
-    return {};                                                                 \
-  }
-
-  RETURN_ON_FAIL(parser.parseLess())
-  RETURN_ON_FAIL(parser.parseLBrace())
+  // Open "<{" part.
+  if (failed(parser.parseLess()))
+    return {};
+  if (failed(parser.parseLBrace()))
+    return {};
 
   // Process the data from the parsed dictionary value into struct-like data.
   SmallVector<DimLevelType> lvlTypes;
@@ -466,13 +451,15 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
     }
     unsigned keyWordIndex = it - keys.begin();
     // Consume the `=` after keys
-    RETURN_ON_FAIL(parser.parseEqual())
+    if (failed(parser.parseEqual()))
+      return {};
     // Dispatch on keyword.
     switch (keyWordIndex) {
     case 0: { // map
       ir_detail::DimLvlMapParser cParser(parser);
       auto res = cParser.parseDimLvlMap();
-      RETURN_ON_FAIL(res);
+      if (failed(res))
+        return {};
       const auto &dlm = *res;
 
       const Level lvlRank = dlm.getLvlRank();
@@ -504,17 +491,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
     }
     case 1: { // posWidth
       Attribute attr;
-      RETURN_ON_FAIL(parser.parseAttribute(attr))
+      if (failed(parser.parseAttribute(attr)))
+        return {};
       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
-      ERROR_IF(!intAttr, "expected an integral position bitwidth")
+      if (!intAttr) {
+        parser.emitError(parser.getNameLoc(),
+                         "expected an integral position bitwidth");
+        return {};
+      }
       posWidth = intAttr.getInt();
       break;
     }
     case 2: { // crdWidth
       Attribute attr;
-      RETURN_ON_FAIL(parser.parseAttribute(attr))
+      if (failed(parser.parseAttribute(attr)))
+        return {};
       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
-      ERROR_IF(!intAttr, "expected an integral index bitwidth")
+      if (!intAttr) {
+        parser.emitError(parser.getNameLoc(),
+                         "expected an integral index bitwidth");
+        return {};
+      }
       crdWidth = intAttr.getInt();
       break;
     }
@@ -524,10 +521,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       break;
   }
 
-  RETURN_ON_FAIL(parser.parseRBrace())
-  RETURN_ON_FAIL(parser.parseGreater())
-#undef ERROR_IF
-#undef RETURN_ON_FAIL
+  // Close "}>" part.
+  if (failed(parser.parseRBrace()))
+    return {};
+  if (failed(parser.parseGreater()))
+    return {};
 
   // Construct struct-like storage for attribute.
   if (!lvlToDim || lvlToDim.isEmpty()) {
@@ -668,9 +666,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     function_ref<InFlightDiagnostic()> emitError) const {
   // Check structural integrity.  In particular, this ensures that the
   // level-rank is coherent across all the fields.
-  RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
-                                  getLvlToDim(), getPosWidth(), getCrdWidth(),
-                                  getDimSlices()))
+  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
+                    getPosWidth(), getCrdWidth(), getDimSlices())))
+    return failure();
   // Check integrity with tensor type specifics.  In particular, we
   // need only check that the dimension-rank of the tensor agrees with
   // the dimension-rank of the encoding.
@@ -926,10 +924,6 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
   return toStoredDim(getSparseTensorEncoding(type), d);
 }
 
-//===----------------------------------------------------------------------===//
-// SparseTensorDialect Types.
-//===----------------------------------------------------------------------===//
-
 /// We normalized sparse tensor encoding attribute by always using
 /// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
 /// as other variants) lead to the same storage specifier type, and stripping
@@ -1340,9 +1334,8 @@ LogicalResult ToSliceStrideOp::verify() {
 }
 
 LogicalResult GetStorageSpecifierOp::verify() {
-  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
-      getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
-  return success();
+  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
+                                      getSpecifier(), getOperation());
 }
 
 template <typename SpecifierOp>
@@ -1360,9 +1353,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult SetStorageSpecifierOp::verify() {
-  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
-      getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
-  return success();
+  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
+                                      getSpecifier(), getOperation());
 }
 
 template <class T>
@@ -1404,20 +1396,23 @@ LogicalResult BinaryOp::verify() {
   // Check correct number of block arguments and return type for each
   // non-empty region.
   if (!overlap.empty()) {
-    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
-        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
+    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
+                                  TypeRange{leftType, rightType}, outputType)))
+      return failure();
   }
   if (!left.empty()) {
-    RETURN_FAILURE_IF_FAILED(
-        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
+    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
+                                  outputType)))
+      return failure();
   } else if (getLeftIdentity()) {
     if (leftType != outputType)
       return emitError("left=identity requires first argument to have the same "
                        "type as the output");
   }
   if (!right.empty()) {
-    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
-        this, right, "right", TypeRange{rightType}, outputType))
+    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
+                                  outputType)))
+      return failure();
   } else if (getRightIdentity()) {
     if (rightType != outputType)
       return emitError("right=identity requires second argument to have the "
@@ -1434,13 +1429,15 @@ LogicalResult UnaryOp::verify() {
   // non-empty region.
   Region &present = getPresentRegion();
   if (!present.empty()) {
-    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
-        this, present, "present", TypeRange{inputType}, outputType))
+    if (failed(verifyNumBlockArgs(this, present, "present",
+                                  TypeRange{inputType}, outputType)))
+      return failure();
   }
   Region &absent = getAbsentRegion();
   if (!absent.empty()) {
-    RETURN_FAILURE_IF_FAILED(
-        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
+    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
+                                  outputType)))
+      return failure();
     // Absent branch can only yield invariant values.
     Block *absentBlock = &absent.front();
     Block *parent = getOperation()->getBlock();
@@ -1655,22 +1652,18 @@ LogicalResult ReorderCOOOp::verify() {
 
 LogicalResult ReduceOp::verify() {
   Type inputType = getX().getType();
-  // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
-      this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
-  return success();
+  return verifyNumBlockArgs(this, formula, "reduce",
+                            TypeRange{inputType, inputType}, inputType);
 }
 
 LogicalResult SelectOp::verify() {
   Builder b(getContext());
   Type inputType = getX().getType();
   Type boolType = b.getI1Type();
-  // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
-                                              TypeRange{inputType}, boolType))
-  return success();
+  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
+                            boolType);
 }
 
 LogicalResult SortOp::verify() {
@@ -1725,8 +1718,6 @@ LogicalResult YieldOp::verify() {
                      "reduce, select or foreach");
 }
 
-#undef RETURN_FAILURE_IF_FAILED
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,



More information about the Mlir-commits mailing list