[Mlir-commits] [mlir] 8fe6597 - [mlir][sparse] minor cleanup of Merger

Aart Bik llvmlistbot at llvm.org
Fri Jul 16 10:57:19 PDT 2021


Author: Aart Bik
Date: 2021-07-16T10:57:09-07:00
New Revision: 8fe65972cb9cbdf133131d30fb0f67ab9381ae1e

URL: https://github.com/llvm/llvm-project/commit/8fe65972cb9cbdf133131d30fb0f67ab9381ae1e
DIFF: https://github.com/llvm/llvm-project/commit/8fe65972cb9cbdf133131d30fb0f67ab9381ae1e.diff

LOG: [mlir][sparse] minor cleanup of Merger

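Removed inconsistent `Kind::`/`Dim::` name prefixes, replaced the
index-based `kOpSymbols` debug-string table with an exhaustive
`kindToOpSymbol` switch so that debug strings stay consistent with the
`Kind` enum, and added more assertions (non-zero divisors, invariant
shift amounts) to verify assumptions that may be lifted in the future.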

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D106108

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 65bbb8284b822..43f11b8603615 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -21,11 +21,11 @@ namespace sparse_tensor {
 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
     : kind(k), val(v) {
   switch (kind) {
-  case Kind::kTensor:
+  case kTensor:
     assert(x != -1u && y == -1u && !v);
     tensor = x;
     break;
-  case Kind::kInvariant:
+  case kInvariant:
     assert(x == -1u && y == -1u && v);
     break;
   case kAbsF:
@@ -99,10 +99,10 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
   for (unsigned p : latSets[s0])
     latSets[s].push_back(p);
   // Map binary 0-y to unary -y.
-  if (kind == Kind::kSubF)
-    s1 = mapSet(Kind::kNegF, s1);
-  else if (kind == Kind::kSubI)
-    s1 = mapSet(Kind::kNegI, s1);
+  if (kind == kSubF)
+    s1 = mapSet(kNegF, s1);
+  else if (kind == kSubI)
+    s1 = mapSet(kNegI, s1);
   // Followed by all in s1.
   for (unsigned p : latSets[s1])
     latSets[s].push_back(p);
@@ -110,7 +110,7 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
 }
 
 unsigned Merger::mapSet(Kind kind, unsigned s0) {
-  assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
+  assert(kAbsF <= kind && kind <= kNegI);
   unsigned s = addSet();
   for (unsigned p : latSets[s0]) {
     unsigned e = addExp(kind, latPoints[p].exp);
@@ -129,8 +129,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
     if (p0 != p1) {
       // Is this a straightforward copy?
       unsigned e = latPoints[p1].exp;
-      if (tensorExps[e].kind == Kind::kTensor &&
-          tensorExps[e].tensor == outTensor)
+      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
         continue;
       // Conjunction already covered?
       for (unsigned p2 : latSets[s]) {
@@ -162,9 +161,9 @@ llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
   }
   // Now apply the two basic rules.
   llvm::BitVector simple = latPoints[p0].bits;
-  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
+  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] && !isDim(b, Dim::kSparse)) {
+    if (simple[b] && !isDim(b, kSparse)) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -189,7 +188,7 @@ bool Merger::latGT(unsigned i, unsigned j) const {
 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   llvm::BitVector tmp = latPoints[j].bits;
   tmp ^= latPoints[i].bits;
-  return !hasAnyDimOf(tmp, Dim::kSparse);
+  return !hasAnyDimOf(tmp, kSparse);
 }
 
 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
@@ -201,23 +200,27 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
 
 bool Merger::isConjunction(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
-  case Kind::kTensor:
+  case kTensor:
     return tensorExps[e].tensor == t;
   case kAbsF:
   case kCeilF:
   case kFloorF:
   case kNegF:
   case kNegI:
-  case Kind::kDivF: // note: x / c only
-  case Kind::kDivS:
-  case Kind::kDivU:
-  case Kind::kShrS: // note: x >> inv only
-  case Kind::kShrU:
-  case Kind::kShlI:
     return isConjunction(t, tensorExps[e].children.e0);
-  case Kind::kMulF:
-  case Kind::kMulI:
-  case Kind::kAndI:
+  case kDivF: // note: x / c only
+  case kDivS:
+  case kDivU:
+    assert(!maybeZero(tensorExps[e].children.e1));
+    return isConjunction(t, tensorExps[e].children.e0);
+  case kShrS: // note: x >> inv only
+  case kShrU:
+  case kShlI:
+    assert(isInvariant(tensorExps[e].children.e1));
+    return isConjunction(t, tensorExps[e].children.e0);
+  case kMulF:
+  case kMulI:
+  case kAndI:
     return isConjunction(t, tensorExps[e].children.e0) ||
            isConjunction(t, tensorExps[e].children.e1);
   default:
@@ -231,20 +234,66 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
 // Print methods (for debugging).
 //
 
-static const char *kOpSymbols[] = {
-    "",  "",  "abs", "ceil", "floor", "-", "-", "*",   "*",  "/", "/",
-    "+", "+", "-",   "-",    "&",     "|", "^", "a>>", ">>", "<<"};
+static const char *kindToOpSymbol(Kind kind) {
+  switch (kind) {
+  case kTensor:
+    return "tensor";
+  case kInvariant:
+    return "invariant";
+  case kAbsF:
+    return "abs";
+  case kCeilF:
+    return "ceil";
+  case kFloorF:
+    return "floor";
+  case kNegF:
+    return "-";
+  case kNegI:
+    return "-";
+  case kMulF:
+    return "*";
+  case kMulI:
+    return "*";
+  case kDivF:
+    return "/";
+  case kDivS:
+    return "/";
+  case kDivU:
+    return "/";
+  case kAddF:
+    return "+";
+  case kAddI:
+    return "+";
+  case kSubF:
+    return "-";
+  case kSubI:
+    return "-";
+  case kAndI:
+    return "&";
+  case kOrI:
+    return "|";
+  case kXorI:
+    return "^";
+  case kShrS:
+    return "a>>";
+  case kShrU:
+    return ">>";
+  case kShlI:
+    return "<<";
+  }
+  llvm_unreachable("unexpected kind for symbol");
+}
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
-  case Kind::kTensor:
+  case kTensor:
     if (tensorExps[e].tensor == syntheticTensor)
       llvm::dbgs() << "synthetic_";
     else if (tensorExps[e].tensor == outTensor)
       llvm::dbgs() << "output_";
     llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
     break;
-  case Kind::kInvariant:
+  case kInvariant:
     llvm::dbgs() << "invariant";
     break;
   case kAbsF:
@@ -252,13 +301,13 @@ void Merger::dumpExp(unsigned e) const {
   case kFloorF:
   case kNegF:
   case kNegI:
-    llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
+    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e0);
     break;
   default:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
-    llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
+    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
   }
@@ -290,16 +339,16 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
       unsigned i = index(b);
       llvm::dbgs() << " i_" << t << "_" << i << "_";
       switch (dims[t][i]) {
-      case Dim::kSparse:
+      case kSparse:
         llvm::dbgs() << "S";
         break;
-      case Dim::kDense:
+      case kDense:
         llvm::dbgs() << "D";
         break;
-      case Dim::kSingle:
+      case kSingle:
         llvm::dbgs() << "T";
         break;
-      case Dim::kUndef:
+      case kUndef:
         llvm::dbgs() << "U";
         break;
       }
@@ -316,13 +365,13 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
 unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant: {
+  case kTensor:
+  case kInvariant: {
     // Either the index is really used in the tensor expression, or it is
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
-    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
+    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
@@ -338,9 +387,9 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     //  --+---+---+
     //    | 0 |-y |
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
-  case Kind::kMulF:
-  case Kind::kMulI:
-  case Kind::kAndI:
+  case kMulF:
+  case kMulI:
+  case kAndI:
     // A multiplicative operation only needs to be performed
     // for the conjunction of sparse iteration spaces.
     //
@@ -351,9 +400,9 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
-  case Kind::kDivF:
-  case Kind::kDivS:
-  case Kind::kDivU:
+  case kDivF:
+  case kDivS:
+  case kDivU:
     // A division is tricky, since 0/0, 0/c, c/0 all have
     // specific outcomes for floating-point and integers.
     // Thus, we need to traverse the full iteration space.
@@ -367,15 +416,16 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     //       during expression building, so that the conjunction
     //       rules applies (viz. x/c = x*(1/c) as far as lattice
     //       construction is concerned).
+    assert(!maybeZero(tensorExps[e].children.e1));
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
-  case Kind::kAddF:
-  case Kind::kAddI:
-  case Kind::kSubF:
-  case Kind::kSubI:
-  case Kind::kOrI:
-  case Kind::kXorI:
+  case kAddF:
+  case kAddI:
+  case kSubF:
+  case kSubI:
+  case kOrI:
+  case kXorI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -386,12 +436,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return takeDisj(kind, // take binary disjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
-  case Kind::kShrS:
-  case Kind::kShrU:
-  case Kind::kShlI:
+  case kShrS:
+  case kShrU:
+  case kShlI:
     // A shift operation by an invariant amount (viz. tensor expressions
     // can only occur at the left-hand-side of the operator) can be handled
     // with the conjuction rule.
+    assert(isInvariant(tensorExps[e].children.e1));
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
@@ -405,7 +456,7 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
 }
 
 bool Merger::maybeZero(unsigned e) const {
-  if (tensorExps[e].kind == Kind::kInvariant) {
+  if (tensorExps[e].kind == kInvariant) {
     if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
       return c.getValue() == 0;
     if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
@@ -415,7 +466,7 @@ bool Merger::maybeZero(unsigned e) const {
 }
 
 bool Merger::isInvariant(unsigned e) const {
-  return tensorExps[e].kind == Kind::kInvariant;
+  return tensorExps[e].kind == kInvariant;
 }
 
 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
@@ -427,30 +478,30 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand *t = op.getInputAndOutputOperands()[argN];
       if (!op.isScalar(t))
-        return addExp(Kind::kTensor, argN);
+        return addExp(kTensor, argN);
       v = t->get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addExp(Kind::kInvariant, v);
+    return addExp(kInvariant, v);
   }
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.region().front())
-    return addExp(Kind::kInvariant, v);
+    return addExp(kInvariant, v);
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
     if (x.hasValue()) {
       unsigned e = x.getValue();
       if (isa<AbsFOp>(def))
-        return addExp(Kind::kAbsF, e);
+        return addExp(kAbsF, e);
       if (isa<CeilFOp>(def))
-        return addExp(Kind::kCeilF, e);
+        return addExp(kCeilF, e);
       if (isa<FloorFOp>(def))
-        return addExp(Kind::kFloorF, e);
+        return addExp(kFloorF, e);
       if (isa<NegFOp>(def))
-        return addExp(Kind::kNegF, e);
+        return addExp(kNegF, e);
       // TODO: no negi in std?
     }
   }
@@ -463,35 +514,35 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       unsigned e0 = x.getValue();
       unsigned e1 = y.getValue();
       if (isa<MulFOp>(def))
-        return addExp(Kind::kMulF, e0, e1);
+        return addExp(kMulF, e0, e1);
       if (isa<MulIOp>(def))
-        return addExp(Kind::kMulI, e0, e1);
+        return addExp(kMulI, e0, e1);
       if (isa<DivFOp>(def) && !maybeZero(e1))
-        return addExp(Kind::kDivF, e0, e1);
+        return addExp(kDivF, e0, e1);
       if (isa<SignedDivIOp>(def) && !maybeZero(e1))
-        return addExp(Kind::kDivS, e0, e1);
+        return addExp(kDivS, e0, e1);
       if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
-        return addExp(Kind::kDivU, e0, e1);
+        return addExp(kDivU, e0, e1);
       if (isa<AddFOp>(def))
-        return addExp(Kind::kAddF, e0, e1);
+        return addExp(kAddF, e0, e1);
       if (isa<AddIOp>(def))
-        return addExp(Kind::kAddI, e0, e1);
+        return addExp(kAddI, e0, e1);
       if (isa<SubFOp>(def))
-        return addExp(Kind::kSubF, e0, e1);
+        return addExp(kSubF, e0, e1);
       if (isa<SubIOp>(def))
-        return addExp(Kind::kSubI, e0, e1);
+        return addExp(kSubI, e0, e1);
       if (isa<AndOp>(def))
-        return addExp(Kind::kAndI, e0, e1);
+        return addExp(kAndI, e0, e1);
       if (isa<OrOp>(def))
-        return addExp(Kind::kOrI, e0, e1);
+        return addExp(kOrI, e0, e1);
       if (isa<XOrOp>(def))
-        return addExp(Kind::kXorI, e0, e1);
+        return addExp(kXorI, e0, e1);
       if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
-        return addExp(Kind::kShrS, e0, e1);
+        return addExp(kShrS, e0, e1);
       if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
-        return addExp(Kind::kShrU, e0, e1);
+        return addExp(kShrU, e0, e1);
       if (isa<ShiftLeftOp>(def) && isInvariant(e1))
-        return addExp(Kind::kShlI, e0, e1);
+        return addExp(kShlI, e0, e1);
     }
   }
   // Cannot build.
@@ -501,8 +552,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
 Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                        Value v0, Value v1) {
   switch (tensorExps[e].kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
+  case kTensor:
+  case kInvariant:
     llvm_unreachable("unexpected non-op");
   case kAbsF:
     return rewriter.create<AbsFOp>(loc, v0);
@@ -515,35 +566,35 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
   case kNegI:
     assert(v1); // no negi in std
     return rewriter.create<SubIOp>(loc, v0, v1);
-  case Kind::kMulF:
+  case kMulF:
     return rewriter.create<MulFOp>(loc, v0, v1);
-  case Kind::kMulI:
+  case kMulI:
     return rewriter.create<MulIOp>(loc, v0, v1);
-  case Kind::kDivF:
+  case kDivF:
     return rewriter.create<DivFOp>(loc, v0, v1);
-  case Kind::kDivS:
+  case kDivS:
     return rewriter.create<SignedDivIOp>(loc, v0, v1);
-  case Kind::kDivU:
+  case kDivU:
     return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
-  case Kind::kAddF:
+  case kAddF:
     return rewriter.create<AddFOp>(loc, v0, v1);
-  case Kind::kAddI:
+  case kAddI:
     return rewriter.create<AddIOp>(loc, v0, v1);
-  case Kind::kSubF:
+  case kSubF:
     return rewriter.create<SubFOp>(loc, v0, v1);
-  case Kind::kSubI:
+  case kSubI:
     return rewriter.create<SubIOp>(loc, v0, v1);
-  case Kind::kAndI:
+  case kAndI:
     return rewriter.create<AndOp>(loc, v0, v1);
-  case Kind::kOrI:
+  case kOrI:
     return rewriter.create<OrOp>(loc, v0, v1);
-  case Kind::kXorI:
+  case kXorI:
     return rewriter.create<XOrOp>(loc, v0, v1);
-  case Kind::kShrS:
+  case kShrS:
     return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
-  case Kind::kShrU:
+  case kShrU:
     return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
-  case Kind::kShlI:
+  case kShlI:
     return rewriter.create<ShiftLeftOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");


        


More information about the Mlir-commits mailing list