[Mlir-commits] [mlir] 06aa6ec - [mlir][sparse] refactor handling of merger leafs and ops

Aart Bik llvmlistbot at llvm.org
Thu Jun 9 11:36:07 PDT 2022


Author: Aart Bik
Date: 2022-06-09T11:35:54-07:00
New Revision: 06aa6ec87dba8ee34e4ea4bd47e6c9f2c06ebd7c

URL: https://github.com/llvm/llvm-project/commit/06aa6ec87dba8ee34e4ea4bd47e6c9f2c06ebd7c
DIFF: https://github.com/llvm/llvm-project/commit/06aa6ec87dba8ee34e4ea4bd47e6c9f2c06ebd7c.diff

LOG: [mlir][sparse] refactor handling of merger leafs and ops

Using "default:" in the switch statements that handle all our
merger ops has become a bit cumbersome since it is easy to overlook
parts of the code that need to handle ops specifically. By enforcing
full switch statements without "default:", we get a compiler warning
when cases are overlooked.

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D127263

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 80d2dbba187b8..1d81dafcd0eb8 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -25,6 +25,7 @@ namespace sparse_tensor {
 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
     : kind(k), val(v), op(o) {
   switch (kind) {
+  // Leaf.
   case kTensor:
     assert(x != -1u && y == -1u && !v && !o);
     tensor = x;
@@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
     assert(x != -1u && y == -1u && !v && !o);
     index = x;
     break;
+  // Unary operations.
   case kAbsF:
   case kAbsC:
   case kCeilF:
@@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
     children.e0 = x;
     children.e1 = y;
     break;
-  case kBinary:
-    assert(x != -1u && y != -1u && !v && o);
+  // Binary operations.
+  case kMulF:
+  case kMulC:
+  case kMulI:
+  case kDivF:
+  case kDivC:
+  case kDivS:
+  case kDivU:
+  case kAddF:
+  case kAddC:
+  case kAddI:
+  case kSubF:
+  case kSubC:
+  case kSubI:
+  case kAndI:
+  case kOrI:
+  case kXorI:
+  case kShrS:
+  case kShrU:
+  case kShlI:
+    assert(x != -1u && y != -1u && !v && !o);
     children.e0 = x;
     children.e1 = y;
     break;
-  default:
-    assert(x != -1u && y != -1u && !v && !o);
+  case kBinary:
+    assert(x != -1u && y != -1u && !v && o);
     children.e0 = x;
     children.e1 = y;
     break;
@@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
 
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
+  // Leaf.
   case kTensor:
     return tensorExps[e].tensor == t;
+  case kInvariant:
+  case kIndex:
+    return false;
+  // Unary operations.
   case kAbsF:
   case kAbsC:
   case kCeilF:
@@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kCRe:
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
+  case kBinaryBranch:
+  case kUnary:
+    return false;
+  // Binary operations.
   case kDivF: // note: x / c only
   case kDivC:
   case kDivS:
@@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kAddI:
     return isSingleCondition(t, tensorExps[e].children.e0) &&
            isSingleCondition(t, tensorExps[e].children.e1);
-  default:
+  case kSubF:
+  case kSubC:
+  case kSubI:
+  case kOrI:
+  case kXorI:
+  case kBinary:
     return false;
   }
 }
@@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
 
 static const char *kindToOpSymbol(Kind kind) {
   switch (kind) {
+  // Leaf.
   case kTensor:
     return "tensor";
   case kInvariant:
     return "invariant";
   case kIndex:
     return "index";
+  // Unary operations.
   case kAbsF:
   case kAbsC:
     return "abs";
@@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) {
     return "binary_branch";
   case kUnary:
     return "unary";
+  // Binary operations.
   case kMulF:
   case kMulC:
   case kMulI:
@@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) {
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
+  // Leaf.
   case kTensor:
     if (tensorExps[e].tensor == syntheticTensor)
       llvm::dbgs() << "synthetic_";
@@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const {
   case kIndex:
     llvm::dbgs() << "index_" << tensorExps[e].index;
     break;
+  // Unary operations.
   case kAbsF:
+  case kAbsC:
   case kCeilF:
   case kFloorF:
   case kSqrtF:
@@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const {
   case kExpm1F:
   case kExpm1C:
   case kLog1pF:
+  case kLog1pC:
   case kSinF:
+  case kSinC:
   case kTanhF:
   case kTanhC:
   case kNegF:
+  case kNegC:
   case kNegI:
   case kTruncF:
   case kExtF:
@@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const {
   case kCastU:
   case kCastIdx:
   case kTruncI:
+  case kCIm:
+  case kCRe:
   case kBitCast:
+  case kBinaryBranch:
+  case kUnary:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e0);
     break;
-  default:
+  // Binary operations.
+  case kMulF:
+  case kMulC:
+  case kMulI:
+  case kDivF:
+  case kDivC:
+  case kDivS:
+  case kDivU:
+  case kAddF:
+  case kAddC:
+  case kAddI:
+  case kSubF:
+  case kSubC:
+  case kSubI:
+  case kAndI:
+  case kOrI:
+  case kXorI:
+  case kShrS:
+  case kShrU:
+  case kShlI:
+  case kBinary:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const {
 unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
+  // Leaf.
   case kTensor:
   case kInvariant:
   case kIndex: {
@@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
+  // Unary operations.
   case kAbsF:
   case kAbsC:
   case kCeilF:
-  case kCIm:
-  case kCRe:
   case kFloorF:
   case kSqrtF:
   case kSqrtC:
@@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   case kCastU:
   case kCastIdx:
   case kTruncI:
+  case kCIm:
+  case kCRe:
   case kBitCast:
     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
     // lattice set of the operand through the operator into a new set.
@@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
       unsigned rhs = addExp(kInvariant, absentVal);
       return takeDisj(kind, child0, buildLattices(rhs, i), unop);
     }
+  // Binary operations.
   case kMulF:
   case kMulC:
   case kMulI:
@@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                        Value v0, Value v1) {
   switch (tensorExps[e].kind) {
+  // Leaf.
   case kTensor:
   case kInvariant:
   case kIndex:
     llvm_unreachable("unexpected non-op");
-  // Unary ops.
+  // Unary operations.
   case kAbsF:
     return rewriter.create<math::AbsOp>(loc, v0);
   case kAbsC: {
-    auto type = v0.getType().template cast<ComplexType>();
-    auto eltType = type.getElementType().template cast<FloatType>();
+    auto type = v0.getType().cast<ComplexType>();
+    auto eltType = type.getElementType().cast<FloatType>();
     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
   }
   case kCeilF:
@@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
   case kTruncI:
     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
-  case kCIm:
+  case kCIm: {
+    auto type = v0.getType().cast<ComplexType>();
+    auto eltType = type.getElementType().cast<FloatType>();
+    return rewriter.create<complex::ImOp>(loc, eltType, v0);
+  }
   case kCRe: {
-    auto type = v0.getType().template cast<ComplexType>();
-    auto eltType = type.getElementType().template cast<FloatType>();
-    if (tensorExps[e].kind == kCIm)
-      return rewriter.create<complex::ImOp>(loc, eltType, v0);
-
+    auto type = v0.getType().cast<ComplexType>();
+    auto eltType = type.getElementType().cast<FloatType>();
     return rewriter.create<complex::ReOp>(loc, eltType, v0);
   }
   case kBitCast:
     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
-  // Binary ops.
+  // Binary operations.
   case kMulF:
     return rewriter.create<arith::MulFOp>(loc, v0, v1);
   case kMulC:
@@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
   case kShlI:
     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
-  // Semiring ops with custom logic.
-  case kBinaryBranch:
+  case kBinaryBranch: // semi-ring ops with custom logic.
     return insertYieldOp(rewriter, loc,
                          *tensorExps[e].op->getBlock()->getParent(), {v0});
   case kUnary:

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 4bdfa71d8bc49..f64251953c9f5 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -136,43 +136,78 @@ class MergerTestBase : public ::testing::Test {
   }
 
   /// Compares expressions for equality. Equality is defined recursively as:
-  /// - Two expressions can only be equal if they have the same Kind.
-  /// - Two binary expressions are equal if they have the same Kind and their
-  ///     children are equal.
-  /// - Expressions with Kind invariant or tensor are equal if they have the
-  ///     same expression id.
+  /// - Operations are equal if they have the same kind and children.
+  /// - Leaf tensors are equal if they refer to the same tensor.
   bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
     auto tensorExp = merger.exp(e);
     if (tensorExp.kind != pattern->kind)
       return false;
-    assert(tensorExp.kind != Kind::kInvariant &&
-           "Invariant comparison not yet supported");
     switch (tensorExp.kind) {
-    case Kind::kTensor:
+    // Leaf.
+    case kTensor:
       return tensorExp.tensor == pattern->tensorNum;
-    case Kind::kAbsF:
-    case Kind::kCeilF:
-    case Kind::kFloorF:
-    case Kind::kNegF:
-    case Kind::kNegI:
+    case kInvariant:
+    case kIndex:
+      llvm_unreachable("invariant not handled yet");
+    // Unary operations.
+    case kAbsF:
+    case kAbsC:
+    case kCeilF:
+    case kFloorF:
+    case kSqrtF:
+    case kSqrtC:
+    case kExpm1F:
+    case kExpm1C:
+    case kLog1pF:
+    case kLog1pC:
+    case kSinF:
+    case kSinC:
+    case kTanhF:
+    case kTanhC:
+    case kNegF:
+    case kNegC:
+    case kNegI:
+    case kTruncF:
+    case kExtF:
+    case kCastFS:
+    case kCastFU:
+    case kCastSF:
+    case kCastUF:
+    case kCastS:
+    case kCastU:
+    case kCastIdx:
+    case kTruncI:
+    case kCIm:
+    case kCRe:
+    case kBitCast:
+    case kBinaryBranch:
+    case kUnary:
       return compareExpression(tensorExp.children.e0, pattern->e0);
-    case Kind::kMulF:
-    case Kind::kMulI:
-    case Kind::kDivF:
-    case Kind::kDivS:
-    case Kind::kDivU:
-    case Kind::kAddF:
-    case Kind::kAddI:
-    case Kind::kSubF:
-    case Kind::kSubI:
-    case Kind::kAndI:
-    case Kind::kOrI:
-    case Kind::kXorI:
+    // Binary operations (including semi-ring kBinary; both children compared).
+    case kMulF:
+    case kMulC:
+    case kMulI:
+    case kDivF:
+    case kDivC:
+    case kDivS:
+    case kDivU:
+    case kAddF:
+    case kAddC:
+    case kAddI:
+    case kSubF:
+    case kSubC:
+    case kSubI:
+    case kAndI:
+    case kOrI:
+    case kXorI:
+    case kShrS:
+    case kShrU:
+    case kShlI:
+    case kBinary:
       return compareExpression(tensorExp.children.e0, pattern->e0) &&
              compareExpression(tensorExp.children.e1, pattern->e1);
-    default:
-      llvm_unreachable("Unhandled Kind");
     }
+    llvm_unreachable("unexpected kind");
   }
 
   unsigned numTensors;


        


More information about the Mlir-commits mailing list