[Mlir-commits] [mlir] 743fbcb - [mlir][sparse] IR/SparseTensorDialect.cpp: misc code cleanup
wren romano
llvmlistbot at llvm.org
Fri Jan 20 13:31:46 PST 2023
Author: wren romano
Date: 2023-01-20T13:31:39-08:00
New Revision: 743fbcb79d9af759377df5f5929ffdd38ff52b09
URL: https://github.com/llvm/llvm-project/commit/743fbcb79d9af759377df5f5929ffdd38ff52b09
DIFF: https://github.com/llvm/llvm-project/commit/743fbcb79d9af759377df5f5929ffdd38ff52b09.diff
LOG: [mlir][sparse] IR/SparseTensorDialect.cpp: misc code cleanup
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D142072
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 5ea4f7ca63eae..0ad21a1729970 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -101,16 +101,18 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
<< "expect positive value or ? for slice offset/size/stride";
}
+static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) {
+ if (bitwidth)
+ return IntegerType::get(ctx, bitwidth);
+ return IndexType::get(ctx);
+}
+
Type SparseTensorEncodingAttr::getPointerType() const {
- unsigned ptrWidth = getPointerBitWidth();
- Type indexType = IndexType::get(getContext());
- return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType;
+ return getIntegerOrIndexType(getContext(), getPointerBitWidth());
}
Type SparseTensorEncodingAttr::getIndexType() const {
- unsigned idxWidth = getIndexBitWidth();
- Type indexType = IndexType::get(getContext());
- return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
+ return getIntegerOrIndexType(getContext(), getIndexBitWidth());
}
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
@@ -157,11 +159,30 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
return getStaticDimSliceStride(toOrigDim(*this, lvl));
}
+const static DimLevelType validDLTs[] = {
+ DimLevelType::Dense, DimLevelType::Compressed,
+ DimLevelType::CompressedNu, DimLevelType::CompressedNo,
+ DimLevelType::CompressedNuNo, DimLevelType::Singleton,
+ DimLevelType::SingletonNu, DimLevelType::SingletonNo,
+ DimLevelType::SingletonNuNo};
+
+static std::optional<DimLevelType> parseDLT(StringRef str) {
+ for (DimLevelType dlt : validDLTs)
+ if (str == toMLIRString(dlt))
+ return dlt;
+ return std::nullopt;
+}
+
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
#define RETURN_ON_FAIL(stmt) \
if (failed(stmt)) { \
return {}; \
}
+#define ERROR_IF(COND, MSG) \
+ if (COND) { \
+ parser.emitError(parser.getNameLoc(), MSG); \
+ return {}; \
+ }
RETURN_ON_FAIL(parser.parseLess())
RETURN_ON_FAIL(parser.parseLBrace())
@@ -191,37 +212,13 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr));
auto arrayAttr = attr.dyn_cast<ArrayAttr>();
- if (!arrayAttr) {
- parser.emitError(parser.getNameLoc(),
- "expected an array for dimension level types");
- return {};
- }
+ ERROR_IF(!arrayAttr, "expected an array for dimension level types")
for (auto i : arrayAttr) {
auto strAttr = i.dyn_cast<StringAttr>();
- if (!strAttr) {
- parser.emitError(parser.getNameLoc(),
- "expected a string value in dimension level types");
- return {};
- }
+ ERROR_IF(!strAttr, "expected a string value in dimension level types")
auto strVal = strAttr.getValue();
- if (strVal == "dense") {
- dlt.push_back(DimLevelType::Dense);
- } else if (strVal == "compressed") {
- dlt.push_back(DimLevelType::Compressed);
- } else if (strVal == "compressed-nu") {
- dlt.push_back(DimLevelType::CompressedNu);
- } else if (strVal == "compressed-no") {
- dlt.push_back(DimLevelType::CompressedNo);
- } else if (strVal == "compressed-nu-no") {
- dlt.push_back(DimLevelType::CompressedNuNo);
- } else if (strVal == "singleton") {
- dlt.push_back(DimLevelType::Singleton);
- } else if (strVal == "singleton-nu") {
- dlt.push_back(DimLevelType::SingletonNu);
- } else if (strVal == "singleton-no") {
- dlt.push_back(DimLevelType::SingletonNo);
- } else if (strVal == "singleton-nu-no") {
- dlt.push_back(DimLevelType::SingletonNuNo);
+ if (auto optDLT = parseDLT(strVal)) {
+ dlt.push_back(optDLT.value());
} else {
parser.emitError(parser.getNameLoc(),
"unexpected dimension level type: ")
@@ -232,46 +229,26 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
} else if (attrName == "dimOrdering") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
-
auto affineAttr = attr.dyn_cast<AffineMapAttr>();
- if (!affineAttr) {
- parser.emitError(parser.getNameLoc(),
- "expected an affine map for dimension ordering");
- return {};
- }
+ ERROR_IF(!affineAttr, "expected an affine map for dimension ordering")
dimOrd = affineAttr.getValue();
} else if (attrName == "higherOrdering") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
-
auto affineAttr = attr.dyn_cast<AffineMapAttr>();
- if (!affineAttr) {
- parser.emitError(parser.getNameLoc(),
- "expected an affine map for higher ordering");
- return {};
- }
+ ERROR_IF(!affineAttr, "expected an affine map for higher ordering")
higherOrd = affineAttr.getValue();
} else if (attrName == "pointerBitWidth") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
-
auto intAttr = attr.dyn_cast<IntegerAttr>();
- if (!intAttr) {
- parser.emitError(parser.getNameLoc(),
- "expected an integral pointer bitwidth");
- return {};
- }
+ ERROR_IF(!intAttr, "expected an integral pointer bitwidth")
ptr = intAttr.getInt();
} else if (attrName == "indexBitWidth") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
-
auto intAttr = attr.dyn_cast<IntegerAttr>();
- if (!intAttr) {
- parser.emitError(parser.getNameLoc(),
- "expected an integral index bitwidth");
- return {};
- }
+ ERROR_IF(!intAttr, "expected an integral index bitwidth")
ind = intAttr.getInt();
} else if (attrName == "slice") {
RETURN_ON_FAIL(parser.parseLSquare())
@@ -298,6 +275,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
RETURN_ON_FAIL(parser.parseRBrace())
RETURN_ON_FAIL(parser.parseGreater())
+#undef ERROR_IF
#undef RETURN_ON_FAIL
// Construct struct-like storage for attribute.
@@ -367,18 +345,21 @@ LogicalResult SparseTensorEncodingAttr::verify(
return emitError() << "unexpected mismatch in dimension slices and "
"dimension level type size";
}
-
return success();
}
+#define RETURN_FAILURE_IF_FAILED(X) \
+ if (failed(X)) { \
+ return failure(); \
+ }
+
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<int64_t> shape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const {
// Check structural integrity.
- if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
- getHigherOrdering(), getPointerBitWidth(),
- getIndexBitWidth(), getDimSlices())))
- return failure();
+ RETURN_FAILURE_IF_FAILED(verify(
+ emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(),
+ getPointerBitWidth(), getIndexBitWidth(), getDimSlices()))
// Check integrity with tensor type specifics. Dimension ordering is optional,
// but we always should have dimension level types for the full rank.
unsigned size = shape.size();
@@ -435,23 +416,17 @@ static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) {
bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
- if (!enc)
- return false;
-
- return isCOOType(enc, 0, /*isUnique=*/true);
+ return enc && isCOOType(enc, 0, /*isUnique=*/true);
}
unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
- unsigned rank = enc.getDimLevelType().size();
- if (rank <= 1)
- return rank;
-
+ const unsigned rank = enc.getDimLevelType().size();
// We only consider COO region with at least two dimensions for the purpose
// of AOS storage optimization.
- for (unsigned r = 0; r < rank - 1; r++) {
- if (isCOOType(enc, r, /*isUnique=*/false))
- return r;
- }
+ if (rank > 1)
+ for (unsigned r = 0; r < rank - 1; r++)
+ if (isCOOType(enc, r, /*isUnique=*/false))
+ return r;
return rank;
}
@@ -541,10 +516,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
std::optional<APInt> dim) const {
- std::optional<unsigned> intDim;
- if (dim)
- intDim = dim.value().getZExtValue();
- return getFieldType(kind, intDim);
+ return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue())
+ : std::nullopt);
}
//===----------------------------------------------------------------------===//
@@ -552,17 +525,12 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
//===----------------------------------------------------------------------===//
static LogicalResult isInBounds(uint64_t dim, Value tensor) {
- uint64_t rank = tensor.getType().cast<RankedTensorType>().getRank();
- if (dim >= rank)
- return failure();
- return success(); // in bounds
+ return success(dim < tensor.getType().cast<RankedTensorType>().getRank());
}
static LogicalResult isMatchingWidth(Value result, unsigned width) {
- Type etp = result.getType().cast<MemRefType>().getElementType();
- if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
- return success();
- return failure();
+ const Type etp = result.getType().cast<MemRefType>().getElementType();
+ return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}
static LogicalResult verifySparsifierGetterSetter(
@@ -663,11 +631,8 @@ LogicalResult ToValuesOp::verify() {
}
LogicalResult GetStorageSpecifierOp::verify() {
- if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
- getSpecifier(), getOperation()))) {
- return failure();
- }
-
+ RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
+ getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
// Checks the result type
if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
getResult().getType()) {
@@ -692,11 +657,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
}
LogicalResult SetStorageSpecifierOp::verify() {
- if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
- getSpecifier(), getOperation()))) {
- return failure();
- }
-
+ RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
+ getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
// Checks the input type
if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
getValue().getType()) {
@@ -748,59 +710,45 @@ LogicalResult BinaryOp::verify() {
// Check correct number of block arguments and return type for each
// non-empty region.
- LogicalResult regionResult = success();
if (!overlap.empty()) {
- regionResult = verifyNumBlockArgs(
- this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
- if (failed(regionResult))
- return regionResult;
+ RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+ this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
}
if (!left.empty()) {
- regionResult =
- verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
- if (failed(regionResult))
- return regionResult;
+ RETURN_FAILURE_IF_FAILED(
+ verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
} else if (getLeftIdentity()) {
if (leftType != outputType)
return emitError("left=identity requires first argument to have the same "
"type as the output");
}
if (!right.empty()) {
- regionResult = verifyNumBlockArgs(this, right, "right",
- TypeRange{rightType}, outputType);
- if (failed(regionResult))
- return regionResult;
+ RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+ this, right, "right", TypeRange{rightType}, outputType))
} else if (getRightIdentity()) {
if (rightType != outputType)
return emitError("right=identity requires second argument to have the "
"same type as the output");
}
-
return success();
}
LogicalResult UnaryOp::verify() {
Type inputType = getX().getType();
Type outputType = getOutput().getType();
- LogicalResult regionResult = success();
// Check correct number of block arguments and return type for each
// non-empty region.
Region &present = getPresentRegion();
if (!present.empty()) {
- regionResult = verifyNumBlockArgs(this, present, "present",
- TypeRange{inputType}, outputType);
- if (failed(regionResult))
- return regionResult;
+ RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+ this, present, "present", TypeRange{inputType}, outputType))
}
Region &absent = getAbsentRegion();
if (!absent.empty()) {
- regionResult =
- verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
- if (failed(regionResult))
- return regionResult;
+ RETURN_FAILURE_IF_FAILED(
+ verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
}
-
return success();
}
@@ -880,8 +828,7 @@ void PushBackOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult PushBackOp::verify() {
- Value n = getN();
- if (n) {
+ if (Value n = getN()) {
auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
if (nValue && nValue.value() < 1)
return emitOpError("n must be not less than 1");
@@ -972,32 +919,21 @@ LogicalResult ForeachOp::verify() {
LogicalResult ReduceOp::verify() {
Type inputType = getX().getType();
- LogicalResult regionResult = success();
-
// Check correct number of block arguments and return type.
Region &formula = getRegion();
- regionResult = verifyNumBlockArgs(this, formula, "reduce",
- TypeRange{inputType, inputType}, inputType);
- if (failed(regionResult))
- return regionResult;
-
+ RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+ this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
return success();
}
LogicalResult SelectOp::verify() {
Builder b(getContext());
-
Type inputType = getX().getType();
Type boolType = b.getI1Type();
- LogicalResult regionResult = success();
-
// Check correct number of block arguments and return type.
Region &formula = getRegion();
- regionResult = verifyNumBlockArgs(this, formula, "select",
- TypeRange{inputType}, boolType);
- if (failed(regionResult))
- return regionResult;
-
+ RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
+ TypeRange{inputType}, boolType))
return success();
}
@@ -1025,15 +961,8 @@ LogicalResult SortOp::verify() {
}
return success();
};
-
- LogicalResult result = checkTypes(getXs());
- if (failed(result))
- return result;
-
- if (n)
- return checkTypes(getYs(), false);
-
- return success();
+ RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
+ return n ? checkTypes(getYs(), false) : success();
}
LogicalResult SortCooOp::verify() {
@@ -1084,6 +1013,8 @@ LogicalResult YieldOp::verify() {
"reduce, select or foreach");
}
+#undef RETURN_FAILURE_IF_FAILED
+
//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list