[Mlir-commits] [mlir] [mlir][sparse] avoid excessive macro magic (PR #70276)
Aart Bik
llvmlistbot at llvm.org
Wed Oct 25 18:26:23 PDT 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/70276
The shorthands are not even always shorter, and the code is less clear than when simply written out.
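For context on what "macro magic" means here, below is a minimal, self-contained C++ sketch contrasting the macro shorthand this patch removes with the written-out early return it switches to. The LogicalResult/failed helpers and parseSomething are simplified stand-ins for illustration only, not MLIR's actual declarations.

#include <cstdio>

// Simplified stand-ins for mlir::LogicalResult and mlir::failed (assumption:
// not MLIR's real declarations, just enough to make the sketch compile).
struct LogicalResult {
  bool isFailure;
  static LogicalResult success() { return {false}; }
  static LogicalResult failure() { return {true}; }
};
static bool failed(LogicalResult r) { return r.isFailure; }

// Hypothetical parse step standing in for calls like parser.parseLess().
static LogicalResult parseSomething(bool ok) {
  return ok ? LogicalResult::success() : LogicalResult::failure();
}

// The shorthand style removed by this patch: the early return is hidden
// behind a macro, so the control flow is not visible at the call site.
#define RETURN_FAILURE_IF_FAILED(X)                                           \
  if (failed(X)) {                                                            \
    return LogicalResult::failure();                                          \
  }

static LogicalResult verifyWithMacro(bool ok) {
  RETURN_FAILURE_IF_FAILED(parseSomething(ok))
  return LogicalResult::success();
}

// The written-out style the patch prefers: barely longer, and the early
// return is explicit.
static LogicalResult verifyWrittenOut(bool ok) {
  if (failed(parseSomething(ok)))
    return LogicalResult::failure();
  return LogicalResult::success();
}

int main() {
  std::printf("macro form fails:       %d\n", failed(verifyWithMacro(false)));
  std::printf("written-out form fails: %d\n", failed(verifyWrittenOut(false)));
  return 0;
}

As the diff below shows, the same written-out form also replaces the RETURN_ON_FALSE, RETURN_ON_FAIL, and ERROR_IF macros.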
From 06b46d432e7595ce332ed781b3b73a9dca1f91a2 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 25 Oct 2023 18:23:48 -0700
Subject: [PATCH] [mlir][sparse] avoid excessive macro magic
The shorthands are not even always shorter and the
code is less clear than when simply written out.
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 137 ++++++++----------
1 file changed, 64 insertions(+), 73 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f05cbd8d16d9a76..359c0a696858329 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -34,11 +34,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-#define RETURN_FAILURE_IF_FAILED(X) \
- if (failed(X)) { \
- return failure(); \
- }
-
//===----------------------------------------------------------------------===//
// Local convenience methods.
//===----------------------------------------------------------------------===//
@@ -68,10 +63,6 @@ void StorageLayout::foreachField(
llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
DimLevelType)>
callback) const {
-#define RETURN_ON_FALSE(fidx, kind, lvl, dlt) \
- if (!(callback(fidx, kind, lvl, dlt))) \
- return;
-
const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
const Level cooStart = getCOOStart(enc);
@@ -81,21 +72,22 @@ void StorageLayout::foreachField(
for (Level l = 0; l < end; l++) {
const auto dlt = lvlTypes[l];
if (isDLTWithPos(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
+ if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
+ return;
}
if (isDLTWithCrd(dlt)) {
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
+ if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
+ return;
}
}
// The values array.
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
- DimLevelType::Undef);
-
+ if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
+ DimLevelType::Undef)))
+ return;
// Put metadata at the end.
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
- DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
+ if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
+ DimLevelType::Undef)))
+ return;
}
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
@@ -435,18 +427,11 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
}
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
-#define RETURN_ON_FAIL(stmt) \
- if (failed(stmt)) { \
- return {}; \
- }
-#define ERROR_IF(COND, MSG) \
- if (COND) { \
- parser.emitError(parser.getNameLoc(), MSG); \
- return {}; \
- }
-
- RETURN_ON_FAIL(parser.parseLess())
- RETURN_ON_FAIL(parser.parseLBrace())
+ // Open "<{" part.
+ if (failed(parser.parseLess()))
+ return {};
+ if (failed(parser.parseLBrace()))
+ return {};
// Process the data from the parsed dictionary value into struct-like data.
SmallVector<DimLevelType> lvlTypes;
@@ -466,13 +451,15 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
}
unsigned keyWordIndex = it - keys.begin();
// Consume the `=` after keys
- RETURN_ON_FAIL(parser.parseEqual())
+ if (failed(parser.parseEqual()))
+ return {};
// Dispatch on keyword.
switch (keyWordIndex) {
case 0: { // map
ir_detail::DimLvlMapParser cParser(parser);
auto res = cParser.parseDimLvlMap();
- RETURN_ON_FAIL(res);
+ if (failed(res))
+ return {};
const auto &dlm = *res;
const Level lvlRank = dlm.getLvlRank();
@@ -504,17 +491,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
}
case 1: { // posWidth
Attribute attr;
- RETURN_ON_FAIL(parser.parseAttribute(attr))
+ if (failed(parser.parseAttribute(attr)))
+ return {};
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
- ERROR_IF(!intAttr, "expected an integral position bitwidth")
+ if (!intAttr) {
+ parser.emitError(parser.getNameLoc(),
+ "expected an integral position bitwidth");
+ return {};
+ }
posWidth = intAttr.getInt();
break;
}
case 2: { // crdWidth
Attribute attr;
- RETURN_ON_FAIL(parser.parseAttribute(attr))
+ if (failed(parser.parseAttribute(attr)))
+ return {};
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
- ERROR_IF(!intAttr, "expected an integral index bitwidth")
+ if (!intAttr) {
+ parser.emitError(parser.getNameLoc(),
+ "expected an integral index bitwidth");
+ return {};
+ }
crdWidth = intAttr.getInt();
break;
}
@@ -524,10 +521,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
break;
}
- RETURN_ON_FAIL(parser.parseRBrace())
- RETURN_ON_FAIL(parser.parseGreater())
-#undef ERROR_IF
-#undef RETURN_ON_FAIL
+ // Close "}>" part.
+ if (failed(parser.parseRBrace()))
+ return {};
+ if (failed(parser.parseGreater()))
+ return {};
// Construct struct-like storage for attribute.
if (!lvlToDim || lvlToDim.isEmpty()) {
@@ -668,9 +666,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
function_ref<InFlightDiagnostic()> emitError) const {
// Check structural integrity. In particular, this ensures that the
// level-rank is coherent across all the fields.
- RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
- getLvlToDim(), getPosWidth(), getCrdWidth(),
- getDimSlices()))
+ if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
+ getPosWidth(), getCrdWidth(), getDimSlices())))
+ return failure();
// Check integrity with tensor type specifics. In particular, we
// need only check that the dimension-rank of the tensor agrees with
// the dimension-rank of the encoding.
@@ -926,10 +924,6 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
return toStoredDim(getSparseTensorEncoding(type), d);
}
-//===----------------------------------------------------------------------===//
-// SparseTensorDialect Types.
-//===----------------------------------------------------------------------===//
-
/// We normalized sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and stripping
@@ -1340,9 +1334,8 @@ LogicalResult ToSliceStrideOp::verify() {
}
LogicalResult GetStorageSpecifierOp::verify() {
- RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
- getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
- return success();
+ return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
+ getSpecifier(), getOperation());
}
template <typename SpecifierOp>
@@ -1360,9 +1353,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
}
LogicalResult SetStorageSpecifierOp::verify() {
- RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
- getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
- return success();
+ return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
+ getSpecifier(), getOperation());
}
template <class T>
@@ -1404,20 +1396,23 @@ LogicalResult BinaryOp::verify() {
// Check correct number of block arguments and return type for each
// non-empty region.
if (!overlap.empty()) {
- RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
- this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
+ if (failed(verifyNumBlockArgs(this, overlap, "overlap",
+ TypeRange{leftType, rightType}, outputType)))
+ return failure();
}
if (!left.empty()) {
- RETURN_FAILURE_IF_FAILED(
- verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
+ if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
+ outputType)))
+ return failure();
} else if (getLeftIdentity()) {
if (leftType != outputType)
return emitError("left=identity requires first argument to have the same "
"type as the output");
}
if (!right.empty()) {
- RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
- this, right, "right", TypeRange{rightType}, outputType))
+ if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
+ outputType)))
+ return failure();
} else if (getRightIdentity()) {
if (rightType != outputType)
return emitError("right=identity requires second argument to have the "
@@ -1434,13 +1429,15 @@ LogicalResult UnaryOp::verify() {
// non-empty region.
Region &present = getPresentRegion();
if (!present.empty()) {
- RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
- this, present, "present", TypeRange{inputType}, outputType))
+ if (failed(verifyNumBlockArgs(this, present, "present",
+ TypeRange{inputType}, outputType)))
+ return failure();
}
Region &absent = getAbsentRegion();
if (!absent.empty()) {
- RETURN_FAILURE_IF_FAILED(
- verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
+ if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
+ outputType)))
+ return failure();
// Absent branch can only yield invariant values.
Block *absentBlock = &absent.front();
Block *parent = getOperation()->getBlock();
@@ -1655,22 +1652,18 @@ LogicalResult ReorderCOOOp::verify() {
LogicalResult ReduceOp::verify() {
Type inputType = getX().getType();
- // Check correct number of block arguments and return type.
Region &formula = getRegion();
- RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
- this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
- return success();
+ return verifyNumBlockArgs(this, formula, "reduce",
+ TypeRange{inputType, inputType}, inputType);
}
LogicalResult SelectOp::verify() {
Builder b(getContext());
Type inputType = getX().getType();
Type boolType = b.getI1Type();
- // Check correct number of block arguments and return type.
Region &formula = getRegion();
- RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
- TypeRange{inputType}, boolType))
- return success();
+ return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
+ boolType);
}
LogicalResult SortOp::verify() {
@@ -1725,8 +1718,6 @@ LogicalResult YieldOp::verify() {
"reduce, select or foreach");
}
-#undef RETURN_FAILURE_IF_FAILED
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,