[Mlir-commits] [mlir] 06aa6ec - [mlir][sparse] refactor handling of merger leafs and ops
Aart Bik
llvmlistbot at llvm.org
Thu Jun 9 11:36:07 PDT 2022
Author: Aart Bik
Date: 2022-06-09T11:35:54-07:00
New Revision: 06aa6ec87dba8ee34e4ea4bd47e6c9f2c06ebd7c
URL: https://github.com/llvm/llvm-project/commit/06aa6ec87dba8ee34e4ea4bd47e6c9f2c06ebd7c
DIFF: https://github.com/llvm/llvm-project/commit/06aa6ec87dba8ee34e4ea4bd47e6c9f2c06ebd7c.diff
LOG: [mlir][sparse] refactor handling of merger leafs and ops
Using "default:" in the switch statements that handle all our
merger ops has become a bit cumbersome since it is easy to overlook
parts of the code that need to handle ops specifically. By enforcing
full switch statements without "default:", we get a compiler warning
when cases are overlooked.
Reviewed By: wrengr
Differential Revision: https://reviews.llvm.org/D127263
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 80d2dbba187b8..1d81dafcd0eb8 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -25,6 +25,7 @@ namespace sparse_tensor {
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
: kind(k), val(v), op(o) {
switch (kind) {
+ // Leaf.
case kTensor:
assert(x != -1u && y == -1u && !v && !o);
tensor = x;
@@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
assert(x != -1u && y == -1u && !v && !o);
index = x;
break;
+ // Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
@@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
children.e0 = x;
children.e1 = y;
break;
- case kBinary:
- assert(x != -1u && y != -1u && !v && o);
+ // Binary operations.
+ case kMulF:
+ case kMulC:
+ case kMulI:
+ case kDivF:
+ case kDivC:
+ case kDivS:
+ case kDivU:
+ case kAddF:
+ case kAddC:
+ case kAddI:
+ case kSubF:
+ case kSubC:
+ case kSubI:
+ case kAndI:
+ case kOrI:
+ case kXorI:
+ case kShrS:
+ case kShrU:
+ case kShlI:
+ assert(x != -1u && y != -1u && !v && !o);
children.e0 = x;
children.e1 = y;
break;
- default:
- assert(x != -1u && y != -1u && !v && !o);
+ case kBinary:
+ assert(x != -1u && y != -1u && !v && o);
children.e0 = x;
children.e1 = y;
break;
@@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
switch (tensorExps[e].kind) {
+ // Leaf.
case kTensor:
return tensorExps[e].tensor == t;
+ case kInvariant:
+ case kIndex:
+ return false;
+ // Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
@@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCRe:
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
+ case kBinaryBranch:
+ case kUnary:
+ return false;
+ // Binary operations.
case kDivF: // note: x / c only
case kDivC:
case kDivS:
@@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kAddI:
return isSingleCondition(t, tensorExps[e].children.e0) &&
isSingleCondition(t, tensorExps[e].children.e1);
- default:
+ case kSubF:
+ case kSubC:
+ case kSubI:
+ case kOrI:
+ case kXorI:
+ case kBinary:
return false;
}
}
@@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
static const char *kindToOpSymbol(Kind kind) {
switch (kind) {
+ // Leaf.
case kTensor:
return "tensor";
case kInvariant:
return "invariant";
case kIndex:
return "index";
+ // Unary operations.
case kAbsF:
case kAbsC:
return "abs";
@@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) {
return "binary_branch";
case kUnary:
return "unary";
+ // Binary operations.
case kMulF:
case kMulC:
case kMulI:
@@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) {
void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) {
+ // Leaf.
case kTensor:
if (tensorExps[e].tensor == syntheticTensor)
llvm::dbgs() << "synthetic_";
@@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const {
case kIndex:
llvm::dbgs() << "index_" << tensorExps[e].index;
break;
+ // Unary operations.
case kAbsF:
+ case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
@@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const {
case kExpm1F:
case kExpm1C:
case kLog1pF:
+ case kLog1pC:
case kSinF:
+ case kSinC:
case kTanhF:
case kTanhC:
case kNegF:
+ case kNegC:
case kNegI:
case kTruncF:
case kExtF:
@@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const {
case kCastU:
case kCastIdx:
case kTruncI:
+ case kCIm:
+ case kCRe:
case kBitCast:
+ case kBinaryBranch:
+ case kUnary:
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
dumpExp(tensorExps[e].children.e0);
break;
- default:
+ // Binary operations.
+ case kMulF:
+ case kMulC:
+ case kMulI:
+ case kDivF:
+ case kDivC:
+ case kDivS:
+ case kDivU:
+ case kAddF:
+ case kAddC:
+ case kAddI:
+ case kSubF:
+ case kSubC:
+ case kSubI:
+ case kAndI:
+ case kOrI:
+ case kXorI:
+ case kShrS:
+ case kShrU:
+ case kShlI:
+ case kBinary:
llvm::dbgs() << "(";
dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const {
unsigned Merger::buildLattices(unsigned e, unsigned i) {
Kind kind = tensorExps[e].kind;
switch (kind) {
+ // Leaf.
case kTensor:
case kInvariant:
case kIndex: {
@@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
latSets[s].push_back(addLat(t, i, e));
return s;
}
+ // Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
- case kCIm:
- case kCRe:
case kFloorF:
case kSqrtF:
case kSqrtC:
@@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kCastU:
case kCastIdx:
case kTruncI:
+ case kCIm:
+ case kCRe:
case kBitCast:
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
// lattice set of the operand through the operator into a new set.
@@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
unsigned rhs = addExp(kInvariant, absentVal);
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
}
+ // Binary operations.
case kMulF:
case kMulC:
case kMulI:
@@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
Value v0, Value v1) {
switch (tensorExps[e].kind) {
+ // Leaf.
case kTensor:
case kInvariant:
case kIndex:
llvm_unreachable("unexpected non-op");
- // Unary ops.
+ // Unary operations.
case kAbsF:
return rewriter.create<math::AbsOp>(loc, v0);
case kAbsC: {
- auto type = v0.getType().template cast<ComplexType>();
- auto eltType = type.getElementType().template cast<FloatType>();
+ auto type = v0.getType().cast<ComplexType>();
+ auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
}
case kCeilF:
@@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
case kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
- case kCIm:
+ case kCIm: {
+ auto type = v0.getType().cast<ComplexType>();
+ auto eltType = type.getElementType().cast<FloatType>();
+ return rewriter.create<complex::ImOp>(loc, eltType, v0);
+ }
case kCRe: {
- auto type = v0.getType().template cast<ComplexType>();
- auto eltType = type.getElementType().template cast<FloatType>();
- if (tensorExps[e].kind == kCIm)
- return rewriter.create<complex::ImOp>(loc, eltType, v0);
-
+ auto type = v0.getType().cast<ComplexType>();
+ auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::ReOp>(loc, eltType, v0);
}
case kBitCast:
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
- // Binary ops.
+ // Binary operations.
case kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1);
case kMulC:
@@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
case kShlI:
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
- // Semiring ops with custom logic.
- case kBinaryBranch:
+ case kBinaryBranch: // semi-ring ops with custom logic.
return insertYieldOp(rewriter, loc,
*tensorExps[e].op->getBlock()->getParent(), {v0});
case kUnary:
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 4bdfa71d8bc49..f64251953c9f5 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -136,43 +136,78 @@ class MergerTestBase : public ::testing::Test {
}
/// Compares expressions for equality. Equality is defined recursively as:
- /// - Two expressions can only be equal if they have the same Kind.
- /// - Two binary expressions are equal if they have the same Kind and their
- /// children are equal.
- /// - Expressions with Kind invariant or tensor are equal if they have the
- /// same expression id.
+ /// - Operations are equal if they have the same kind and children.
+ /// - Leaf tensors are equal if they refer to the same tensor (invariant and index leaves are not supported yet).
bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
auto tensorExp = merger.exp(e);
if (tensorExp.kind != pattern->kind)
return false;
- assert(tensorExp.kind != Kind::kInvariant &&
- "Invariant comparison not yet supported");
switch (tensorExp.kind) {
- case Kind::kTensor:
+ // Leaf.
+ case kTensor:
return tensorExp.tensor == pattern->tensorNum;
- case Kind::kAbsF:
- case Kind::kCeilF:
- case Kind::kFloorF:
- case Kind::kNegF:
- case Kind::kNegI:
+ case kInvariant:
+ case kIndex:
+ llvm_unreachable("invariant not handled yet");
+ // Unary operations.
+ case kAbsF:
+ case kAbsC:
+ case kCeilF:
+ case kFloorF:
+ case kSqrtF:
+ case kSqrtC:
+ case kExpm1F:
+ case kExpm1C:
+ case kLog1pF:
+ case kLog1pC:
+ case kSinF:
+ case kSinC:
+ case kTanhF:
+ case kTanhC:
+ case kNegF:
+ case kNegC:
+ case kNegI:
+ case kTruncF:
+ case kExtF:
+ case kCastFS:
+ case kCastFU:
+ case kCastSF:
+ case kCastUF:
+ case kCastS:
+ case kCastU:
+ case kCastIdx:
+ case kTruncI:
+ case kCIm:
+ case kCRe:
+ case kBitCast:
+ case kBinaryBranch:
+ case kUnary:
return compareExpression(tensorExp.children.e0, pattern->e0);
- case Kind::kMulF:
- case Kind::kMulI:
- case Kind::kDivF:
- case Kind::kDivS:
- case Kind::kDivU:
- case Kind::kAddF:
- case Kind::kAddI:
- case Kind::kSubF:
- case Kind::kSubI:
- case Kind::kAndI:
- case Kind::kOrI:
- case Kind::kXorI:
+ // Binary operations.
+ case kMulF:
+ case kMulC:
+ case kMulI:
+ case kDivF:
+ case kDivC:
+ case kDivS:
+ case kDivU:
+ case kAddF:
+ case kAddC:
+ case kAddI:
+ case kSubF:
+ case kSubC:
+ case kSubI:
+ case kAndI:
+ case kOrI:
+ case kXorI:
+ case kShrS:
+ case kShrU:
+ case kShlI:
+ case kBinary:
return compareExpression(tensorExp.children.e0, pattern->e0) &&
compareExpression(tensorExp.children.e1, pattern->e1);
- default:
- llvm_unreachable("Unhandled Kind");
}
+ llvm_unreachable("unexpected kind");
}
unsigned numTensors;
More information about the Mlir-commits
mailing list