[Mlir-commits] [mlir] 164b918 - [mlir][sparse] Removing shared_ptr from the MergerTest.cpp unit test

wren romano llvmlistbot at llvm.org
Tue Mar 28 11:25:34 PDT 2023


Author: wren romano
Date: 2023-03-28T11:25:20-07:00
New Revision: 164b918d73779f776ab3ac6a633d0832ca9e6622

URL: https://github.com/llvm/llvm-project/commit/164b918d73779f776ab3ac6a633d0832ca9e6622
DIFF: https://github.com/llvm/llvm-project/commit/164b918d73779f776ab3ac6a633d0832ca9e6622.diff

LOG: [mlir][sparse] Removing shared_ptr from the MergerTest.cpp unit test

This is a preliminary change to make way for converting the Merger's identifier types from mere typedefs to actual types (a conversion that would otherwise cause some issues in this test, which this patch fixes in advance).

Depends On D146676

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146561

Added: 
    

Modified: 
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 599e8abd52f3..2497659e68dd 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -80,27 +80,39 @@ namespace {
 /// Helper classes/functions for testing Merger.
 ///
 
-/// Simple recursive data structure used to match expressions in Mergers.
+/// Simple recursive data structure used to match expressions in `Merger`.
+struct Pattern;
+/// Since the patterns we need are rather small and short-lived, we use
+/// `Pattern const&` for "pointers" to patterns, rather than using
+/// something more elaborate like `std::shared_ptr<Pattern> const&`.
+/// (But since we use a typedef rather than spelling it out everywhere,
+/// that's easy enough to swap out if we need something more elaborate
+/// in the future.)
+using PatternRef = const Pattern &;
 struct Pattern {
+  struct Children {
+    Children(PatternRef e0, PatternRef e1) : e0(e0), e1(e1) {}
+    PatternRef e0;
+    PatternRef e1;
+  };
+
   TensorExp::Kind kind;
 
-  /// Expressions representing tensors simply have a tensor number.
-  unsigned tensorNum;
+  union {
+    /// Expressions representing tensors simply have a tensor number.
+    TensorId tid;
 
-  /// Tensor operations point to their children.
-  std::shared_ptr<Pattern> e0;
-  std::shared_ptr<Pattern> e1;
+    /// Tensor operations point to their children.
+    Children children;
+  };
 
   /// Constructors.
   /// Rather than using these, please use the readable helper constructor
   /// functions below to make tests more readable.
-  Pattern(unsigned tensorNum)
-      : kind(TensorExp::Kind::kTensor), tensorNum(tensorNum) {}
-  Pattern(TensorExp::Kind kind, const std::shared_ptr<Pattern> &e0,
-          const std::shared_ptr<Pattern> &e1)
-      : kind(kind), e0(e0), e1(e1) {
+  Pattern(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
+  Pattern(TensorExp::Kind kind, PatternRef e0, PatternRef e1)
+      : kind(kind), children(e0, e1) {
     assert(kind >= TensorExp::Kind::kMulF);
-    assert(e0 && e1);
   }
 };
 
@@ -109,15 +121,12 @@ struct Pattern {
 /// These should be preferred over the actual constructors.
 ///
 
-static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
-  return std::make_shared<Pattern>(tensorNum);
-}
+static Pattern tensorPattern(TensorId tid) { return Pattern(tid); }
 
 #define IMPL_BINOP_PATTERN(OP, KIND)                                           \
-  LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern(           \
-      const std::shared_ptr<Pattern> &e0,                                      \
-      const std::shared_ptr<Pattern> &e1) {                                    \
-    return std::make_shared<Pattern>(KIND, e0, e1);                            \
+  LLVM_ATTRIBUTE_UNUSED static Pattern OP##Pattern(PatternRef e0,              \
+                                                   PatternRef e1) {            \
+    return Pattern(KIND, e0, e1);                                              \
   }
 
 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
@@ -127,20 +136,32 @@ FOREVERY_BINOP(IMPL_BINOP_PATTERN)
 class MergerTestBase : public ::testing::Test {
 protected:
   MergerTestBase(unsigned numTensors, unsigned numLoops)
-      : numTensors(numTensors), numLoops(numLoops),
-        merger(numTensors, numLoops, /*numFilterLoops=*/0,
-               /*maxRank=*/numLoops) {}
+      : merger(numTensors, numLoops, /*numFilterLoops=*/0,
+               /*maxRank=*/numLoops) {
+    tensors.reserve(numTensors);
+    for (unsigned t = 0; t < numTensors; t++)
+      tensors.push_back(merger.addExp(TensorExp::Kind::kTensor, tid(t)));
+  }
 
   ///
   /// Expression construction helpers.
   ///
 
-  unsigned tensor(unsigned tensor) {
-    return merger.addExp(TensorExp::Kind::kTensor, tensor);
+  TensorId tid(unsigned t) const {
+    assert(t < merger.getNumTensors());
+    return t;
+  }
+  LoopId lid(unsigned i) const {
+    assert(i < merger.getNumLoops());
+    return i;
+  }
+  ExprId tensor(unsigned t) const {
+    assert(t < tensors.size());
+    return tensors[t];
   }
 
 #define IMPL_BINOP_EXPR(OP, KIND)                                              \
-  LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) {          \
+  LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) {                \
     return merger.addExp(KIND, e0, e1);                                        \
   }
 
@@ -152,83 +173,77 @@ class MergerTestBase : public ::testing::Test {
   /// Comparison helpers.
   ///
 
-  /// For readability of tests.
-  unsigned lat(unsigned lat) { return lat; }
-
-  /// Returns true if a lattice point with an expression matching the given
-  /// pattern and bits matching the given bits is present in lattice points
-  /// [p, p+n) of lattice set s. This is useful for testing partial ordering
-  /// constraints between lattice points. We generally know how contiguous
-  /// groups of lattice points should be ordered with respect to other groups,
-  /// but there is no required ordering within groups.
-  /// If simple is true, then compare the lat.simple field instead to test the
-  /// result after optimization
-  bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
-                           const std::shared_ptr<Pattern> &pattern,
-                           const BitVector &bits, bool simple) {
-    for (unsigned i = p; i < p + n; ++i) {
-      if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
-          compareBits(s, i, bits, simple))
+  /// Returns true if any lattice point with an expression matching
+  /// the given `pattern` and bits matching the given `bits` is present
+  /// in the `[lo, lo+n)` slice of the lattice set `s`.  This is useful
+  /// for testing partial ordering constraints between lattice points.
+  /// We generally know how contiguous groups of lattice points should
+  /// be ordered with respect to other groups, but there is no required
+  /// ordering within groups.  If `simple` is true, then compare the
+  /// `lat.simple` field instead to test the result after optimization.
+  bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
+                           PatternRef pattern, const BitVector &bits,
+                           bool simple) {
+    for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
+      if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
+          compareBits(s, k, bits, simple))
         return true;
     }
     return false;
   }
 
   /// Wrapper over latPointWithinRange for readability of tests.
-  void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
-                                 const std::shared_ptr<Pattern> &pattern,
-                                 const BitVector &bits, bool simple = false) {
-    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
+  void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
+                                 PatternRef pattern, const BitVector &bits,
+                                 bool simple = false) {
+    EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
   }
 
   /// Wrapper over expectLatPointWithinRange for a single lat point.
-  void expectLatPoint(unsigned s, unsigned p,
-                      const std::shared_ptr<Pattern> &pattern,
+  void expectLatPoint(LatSetId s, unsigned lo, PatternRef pattern,
                       const BitVector &bits, bool simple = false) {
-    EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
+    EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
   }
 
   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
   /// corresponding bits set.
-  BitVector
-  loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
-    BitVector testBits = BitVector(numTensors + 1, false);
-    for (auto l : loops) {
-      auto loop = std::get<0>(l);
-      auto tensor = std::get<1>(l);
+  BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
+    // NOTE: this `numTensors` includes both the output- and synthetic-tensors.
+    const auto numTensors = merger.getNumTensors();
+    BitVector testBits(numTensors * merger.getNumLoops(), false); // all loops
+    for (auto [loop, tensor] : loops)
       testBits.set(numTensors * loop + tensor);
-    }
     return testBits;
   }
 
-  /// Returns true if the bits of lattice point p in set s match the given bits.
-  /// If simple is true, then compare the lat.simple field instead to test the
-  /// result after optimization
-  bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
-    if (simple)
-      return merger.lat(merger.set(s)[p]).simple == bits;
-    return merger.lat(merger.set(s)[p]).bits == bits;
+  /// Returns true if the bits of the `k`th point in set `s` match
+  /// the given `bits`.  If `simple` is true, then compares the `lat.simple`
+  /// field instead, to test the result after optimization.
+  bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
+    const auto &point = merger.lat(merger.set(s)[k]);
+    return (simple ? point.simple : point.bits) == bits;
   }
 
   /// Check that there are n lattice points in set s.
-  void expectNumLatPoints(unsigned s, unsigned n) {
+  void expectNumLatPoints(LatSetId s, unsigned n) {
     EXPECT_THAT(merger.set(s).size(), n);
   }
 
   /// Compares expressions for equality. Equality is defined recursively as:
   /// - Operations are equal if they have the same kind and children.
   /// - Leaf tensors are equal if they refer to the same tensor.
-  bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
-    auto tensorExp = merger.exp(e);
-    if (tensorExp.kind != pattern->kind)
+  bool compareExpression(ExprId e, PatternRef pattern) {
+    const auto &tensorExp = merger.exp(e);
+    if (tensorExp.kind != pattern.kind)
       return false;
     switch (tensorExp.kind) {
     // Leaf.
     case TensorExp::Kind::kTensor:
-      return tensorExp.tensor == pattern->tensorNum;
+      return tensorExp.tensor == pattern.tid;
     case TensorExp::Kind::kInvariant:
-    case TensorExp::Kind::kLoopVar:
       llvm_unreachable("invariant not handled yet");
+    case TensorExp::Kind::kLoopVar:
+      llvm_unreachable("loop-variables not handled yet");
     // Unary operations.
     case TensorExp::Kind::kAbsF:
     case TensorExp::Kind::kAbsC:
@@ -264,7 +279,7 @@ class MergerTestBase : public ::testing::Test {
     case TensorExp::Kind::kSelect:
     case TensorExp::Kind::kBinaryBranch:
     case TensorExp::Kind::kUnary:
-      return compareExpression(tensorExp.children.e0, pattern->e0);
+      return compareExpression(tensorExp.children.e0, pattern.children.e0);
     // Binary operations.
     case TensorExp::Kind::kMulF:
     case TensorExp::Kind::kMulC:
@@ -287,72 +302,51 @@ class MergerTestBase : public ::testing::Test {
     case TensorExp::Kind::kShlI:
     case TensorExp::Kind::kBinary:
     case TensorExp::Kind::kReduce:
-      return compareExpression(tensorExp.children.e0, pattern->e0) &&
-             compareExpression(tensorExp.children.e1, pattern->e1);
+      return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
+             compareExpression(tensorExp.children.e1, pattern.children.e1);
     }
     llvm_unreachable("unexpected kind");
   }
 
-  unsigned numTensors;
-  unsigned numLoops;
+  // This field is public for convenience.
   Merger merger;
+
+private:
+  // This field is private to prevent mutation after the ctor.
+  SmallVector<ExprId> tensors;
 };
 
 ///
 /// Tests with all sparse inputs.
 ///
 
+/// Three tensors (two inputs, one output); and a single loop.
 class MergerTest3T1L : public MergerTestBase {
 protected:
-  // Our three tensors (two inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
   MergerTest3T1L() : MergerTestBase(3, 1) {
-    EXPECT_TRUE(merger.getOutTensorID() == t2);
-
+    EXPECT_TRUE(merger.getOutTensorID() == tid(2));
     // Tensor 0: sparse input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
-    merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
-
+    merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
     // Tensor 1: sparse input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
-    merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
-
+    merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed);
     // Tensor 2: dense output vector.
-    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
-    merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense);
   }
 };
 
+/// Four tensors (three inputs, one output); and a single loop.
 class MergerTest4T1L : public MergerTestBase {
 protected:
-  // Our four tensors (three inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
   MergerTest4T1L() : MergerTestBase(4, 1) {
-    EXPECT_TRUE(merger.getOutTensorID() == t3);
-
+    EXPECT_TRUE(merger.getOutTensorID() == tid(3));
     // Tensor 0: sparse input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
-    merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
-
+    merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
     // Tensor 1: sparse input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
-    merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
-
+    merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed);
     // Tensor 2: sparse input vector
-    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
-    merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
-
+    merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed);
     // Tensor 3: dense output vector
-    merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
-    merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense);
   }
 };
 
@@ -360,28 +354,17 @@ class MergerTest4T1L : public MergerTestBase {
 /// Tests with both sparse and dense input.
 ///
 
+/// Three tensors (two inputs, one output); and a single loop.
 class MergerTest3T1LD : public MergerTestBase {
 protected:
-  // Our three tensors (two inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
   MergerTest3T1LD() : MergerTestBase(3, 1) {
-    EXPECT_TRUE(merger.getOutTensorID() == t2);
-
+    EXPECT_TRUE(merger.getOutTensorID() == tid(2));
     // Tensor 0: sparse input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
-    merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
-
+    merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
     // Tensor 1: dense input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
-    merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
-
+    merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense);
     // Tensor 2: dense output vector.
-    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
-    merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense);
   }
 };
 
@@ -389,32 +372,19 @@ class MergerTest3T1LD : public MergerTestBase {
 /// Tests with both undef and dense input.
 ///
 
+/// Four tensors (three inputs, one output); and a single loop.
 class MergerTest4T1LU : public MergerTestBase {
 protected:
-  // Our three tensors (three inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
   MergerTest4T1LU() : MergerTestBase(4, 1) {
-    EXPECT_TRUE(merger.getOutTensorID() == t3);
-
+    EXPECT_TRUE(merger.getOutTensorID() == tid(3));
     // Tensor 0: undef input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
-    merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
-
+    merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef);
     // Tensor 1: dense input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
-    merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
-
+    merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense);
     // Tensor 2: undef input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
-    merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef);
-
+    merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Undef);
     // Tensor 3: dense output vector.
-    merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
-    merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense);
   }
 };
 
@@ -422,31 +392,19 @@ class MergerTest4T1LU : public MergerTestBase {
 /// Tests with operation on sparse output.
 ///
 
+/// Three tensors (two inputs, one output), a synthetic tensor, and one loop.
 class MergerTest3T1LSo : public MergerTestBase {
 protected:
-  // Our three tensors (two inputs, one output, one synthetic).
-  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
   MergerTest3T1LSo() : MergerTestBase(3, 1) {
-    EXPECT_TRUE(merger.getOutTensorID() == t2);
-    EXPECT_TRUE(merger.getSynTensorID() == t3);
-
+    EXPECT_TRUE(merger.getOutTensorID() == tid(2));
+    EXPECT_TRUE(merger.getSynTensorID() == tid(3));
     merger.setHasSparseOut(true);
-
     // Tensor 0: undef input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
-    merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
-
+    merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef);
     // Tensor 1: undef input vector.
-    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
-    merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef);
-
+    merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Undef);
     // Tensor 2: sparse output vector.
-    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
-    merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
+    merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed);
   }
 };
 
@@ -465,18 +423,22 @@ class MergerTest3T1LSo : public MergerTestBase {
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
   TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
-    auto em = CONJ1##Expr(t0, t1);                                             \
-    auto e = CONJ2##Expr(em, t2);                                              \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
+    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
+    const auto e = CONJ2##Expr(em, tensor(2));                                 \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const auto t2 = tid(2);                                                    \
+    PatternRef p0 = tensorPattern(t0);                                         \
+    PatternRef p1 = tensorPattern(t1);                                         \
+    PatternRef p2 = tensorPattern(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t1}}), true);                             \
   }
 
@@ -497,18 +459,23 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2)                    \
   TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
-    auto em = CONJ1##Expr(t0, t1);                                             \
-    auto e = CONJ2##Expr(em, t2);                                              \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
+    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
+    const auto e = CONJ2##Expr(em, tensor(2));                                 \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const auto t2 = tid(2);                                                    \
+    const auto t3 = tid(3);                                                    \
+    PatternRef p0 = tensorPattern(t0);                                         \
+    PatternRef p1 = tensorPattern(t1);                                         \
+    PatternRef p2 = tensorPattern(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t3}}), true);                             \
   }
 
@@ -533,25 +500,26 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
 /// }
 #define IMPL_MERGER_TEST_DISJ(OP)                                              \
   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
-    auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
+    const auto e = OP##Expr(tensor(0), tensor(1));                             \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    PatternRef p0 = tensorPattern(t0);                                         \
+    PatternRef p1 = tensorPattern(t1);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                 \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
-    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
-    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
+    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
+    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                                 \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
-    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}),       \
-                              true);                                           \
-    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}),       \
-                              true);                                           \
+    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true);     \
+    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true);     \
   }
 
 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
@@ -566,18 +534,21 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
 /// }
 #define IMPL_MERGER_TEST_CONJ(OP)                                              \
   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
-    auto e = OP##Expr(t0, t1);                                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
+    const auto e = OP##Expr(tensor(0), tensor(1));                             \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    PatternRef p0 = tensorPattern(t0);                                         \
+    PatternRef p1 = tensorPattern(t1);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                                 \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
   }
 
@@ -595,27 +566,31 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
 /// }
 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
   TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
-    auto em = CONJ##Expr(t0, t1);                                              \
-    auto e = DISJ##Expr(em, t2);                                               \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
+    const auto em = CONJ##Expr(tensor(0), tensor(1));                          \
+    const auto e = DISJ##Expr(em, tensor(2));                                  \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const auto t2 = tid(2);                                                    \
+    PatternRef p0 = tensorPattern(t0);                                         \
+    PatternRef p1 = tensorPattern(t1);                                         \
+    PatternRef p2 = tensorPattern(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                 \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
+    expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2),             \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
+    expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1),                  \
                                loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
+    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
                                                                                 \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
+    expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2),             \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
+    expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1),                  \
                                loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
+    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
   }
 
 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
@@ -636,39 +611,43 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
 /// }
 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
   TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
-    auto em = DISJ1##Expr(t0, t1);                                             \
-    auto e = DISJ2##Expr(em, t2);                                              \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
+    const auto em = DISJ1##Expr(tensor(0), tensor(1));                         \
+    const auto e = DISJ2##Expr(em, tensor(2));                                 \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const auto t2 = tid(2);                                                    \
+    const PatternRef p0 = tensorPattern(t0);                                   \
+    const PatternRef p1 = tensorPattern(t1);                                   \
+    const PatternRef p2 = tensorPattern(t2);                                   \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 7);                                                  \
-    expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2),                 \
                               loopsToBits({{l0, t1}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2),                 \
                               loopsToBits({{l0, t0}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
+    expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1),                 \
                               loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
+    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
+    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
+    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 7);                                                  \
-    expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2),                 \
                               loopsToBits({{l0, t1}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2),                 \
                               loopsToBits({{l0, t0}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
+    expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1),                 \
                               loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
+    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
+    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
+    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
   }
 
 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
@@ -683,18 +662,22 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
   TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
-    auto em = CONJ1##Expr(t0, t1);                                             \
-    auto e = CONJ2##Expr(em, t2);                                              \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
+    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
+    const auto e = CONJ2##Expr(em, tensor(2));                                 \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const auto t2 = tid(2);                                                    \
+    const PatternRef p0 = tensorPattern(t0);                                   \
+    const PatternRef p1 = tensorPattern(t1);                                   \
+    const PatternRef p2 = tensorPattern(t2);                                   \
     auto s = merger.buildLattices(e, l0);                                      \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
   }
 
@@ -720,22 +703,25 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP)                                    \
   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
-    auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
+    const auto e = OP##Expr(tensor(0), tensor(1));                             \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const PatternRef p0 = tensorPattern(t0);                                   \
+    const PatternRef p1 = tensorPattern(t1);                                   \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
-    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
-    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
+    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
+    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 2);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
-    expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true);              \
+    expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true);                   \
   }
 
 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
@@ -755,19 +741,21 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
 /// since i_01 is a dense dimension.
 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP)                                    \
   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
-    auto e = OP##Expr(t0, t1);                                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
+    const auto e = OP##Expr(tensor(0), tensor(1));                             \
+    const auto l0 = lid(0);                                                    \
+    const auto t0 = tid(0);                                                    \
+    const auto t1 = tid(1);                                                    \
+    const PatternRef p0 = tensorPattern(t0);                                   \
+    const PatternRef p1 = tensorPattern(t1);                                   \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}),    \
-                   true);                                                      \
+    expectLatPoint(s, 0, OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), true);  \
   }
 
 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)


        


More information about the Mlir-commits mailing list