[Mlir-commits] [mlir] 988d4f5 - Revert "[mlir][sparse] add more unittest cases to sparse dialect merger"

Stella Stamenova llvmlistbot at llvm.org
Tue Jul 5 08:58:04 PDT 2022


Author: Stella Stamenova
Date: 2022-07-05T08:56:16-07:00
New Revision: 988d4f576fdfbf69c6c26279969138433098f0aa

URL: https://github.com/llvm/llvm-project/commit/988d4f576fdfbf69c6c26279969138433098f0aa
DIFF: https://github.com/llvm/llvm-project/commit/988d4f576fdfbf69c6c26279969138433098f0aa.diff

LOG: Revert "[mlir][sparse] add more unittest cases to sparse dialect merger"

This broke the windows mlir bot: https://lab.llvm.org/buildbot/#/builders/13/builds/22743

This reverts commit daeb2dcea09820d92f81db84623cf1c6df825e14 and 537db49596f65a05c0309cf3333fc44f1657e999.

Added: 
    

Modified: 
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index fad41b0178349..f64251953c9f5 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -1,5 +1,4 @@
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
-#include "llvm/Support/Compiler.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <memory>
@@ -9,68 +8,6 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-///
-/// Defines macros to iterate binary and the combination of binary operations.
-///
-
-#define FOREVERY_BINOP(DO)                                                     \
-  DO(mulf, Kind::kMulF)                                                        \
-  DO(mulc, Kind::kMulC)                                                        \
-  DO(muli, Kind::kMulI)                                                        \
-  DO(addf, Kind::kAddF)                                                        \
-  DO(addc, Kind::kAddC)                                                        \
-  DO(addi, Kind::kAddI)                                                        \
-  DO(subf, Kind::kSubF)                                                        \
-  DO(subc, Kind::kSubC)                                                        \
-  DO(subi, Kind::kSubI)                                                        \
-  DO(andi, Kind::kAndI)                                                        \
-  DO(xori, Kind::kXorI)                                                        \
-  DO(ori, Kind::kOrI)
-
-// TODO: Disjunctive binary operations that need special handling are not
-// included, e.g., Division are not tested (for now) as it need a constant
-// non-zero dividend.
-// ##__VA_ARGS__ handles cases when __VA_ARGS__ is empty.
-#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...)                                  \
-  TEST(addf, ##__VA_ARGS__)                                                    \
-  TEST(addc, ##__VA_ARGS__)                                                    \
-  TEST(addi, ##__VA_ARGS__)                                                    \
-  TEST(xori, ##__VA_ARGS__)                                                    \
-  TEST(ori, ##__VA_ARGS__)
-
-// TODO: Conjunctive binary operations that need special handling are not
-// included, e.g., substraction yields a different pattern as it is mapped to
-#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...)                                  \
-  TEST(mulf, ##__VA_ARGS__)                                                    \
-  TEST(mulc, ##__VA_ARGS__)                                                    \
-  TEST(muli, ##__VA_ARGS__)                                                    \
-  TEST(andi, ##__VA_ARGS__)
-
-#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, addf)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, addc)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, addi)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, xori)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)
-
-#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, muli)                                       \
-  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)
-
-#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
-  FOREVERY_COMMON_DISJ_BINOP(TEST, addf)                                       \
-  FOREVERY_COMMON_DISJ_BINOP(TEST, addc)                                       \
-  FOREVERY_COMMON_DISJ_BINOP(TEST, addi)                                       \
-  FOREVERY_COMMON_DISJ_BINOP(TEST, ori)                                        \
-  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)
-
-///
-/// Helper classes/functions for testing Merger.
-///
-
 /// Simple recursive data structure used to match expressions in Mergers.
 struct Pattern {
   Kind kind;
@@ -103,16 +40,17 @@ static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
   return std::make_shared<Pattern>(tensorNum);
 }
 
-#define IMPL_BINOP_PATTERN(OP, KIND)                                           \
-  LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern(           \
-      const std::shared_ptr<Pattern> &e0,                                      \
-      const std::shared_ptr<Pattern> &e1) {                                    \
-    return std::make_shared<Pattern>(KIND, e0, e1);                            \
-  }
-
-FOREVERY_BINOP(IMPL_BINOP_PATTERN)
+static std::shared_ptr<Pattern>
+addfPattern(const std::shared_ptr<Pattern> &e0,
+            const std::shared_ptr<Pattern> &e1) {
+  return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
+}
 
-#undef IMPL_BINOP_PATTERN
+static std::shared_ptr<Pattern>
+mulfPattern(const std::shared_ptr<Pattern> &e0,
+            const std::shared_ptr<Pattern> &e1) {
+  return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
+}
 
 class MergerTestBase : public ::testing::Test {
 protected:
@@ -128,14 +66,13 @@ class MergerTestBase : public ::testing::Test {
     return merger.addExp(Kind::kTensor, tensor);
   }
 
-#define IMPL_BINOP_EXPR(OP, KIND)                                              \
-  LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) {          \
-    return merger.addExp(KIND, e0, e1);                                        \
+  unsigned addf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kAddF, e0, e1);
   }
 
-  FOREVERY_BINOP(IMPL_BINOP_EXPR)
-
-#undef IMPL_BINOP_EXPR
+  unsigned mulf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kMulF, e0, e1);
+  }
 
   ///
   /// Comparison helpers.
@@ -150,14 +87,12 @@ class MergerTestBase : public ::testing::Test {
   /// constraints between lattice points. We generally know how contiguous
   /// groups of lattice points should be ordered with respect to other groups,
   /// but there is no required ordering within groups.
-  /// If simple is true, then compare the lat.simple field instead to test the
-  /// result after optimization
   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
                            const std::shared_ptr<Pattern> &pattern,
-                           const BitVector &bits, bool simple) {
+                           const BitVector &bits) {
     for (unsigned i = p; i < p + n; ++i) {
       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
-          compareBits(s, i, bits, simple))
+          compareBits(s, i, bits))
         return true;
     }
     return false;
@@ -166,15 +101,15 @@ class MergerTestBase : public ::testing::Test {
   /// Wrapper over latPointWithinRange for readability of tests.
   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
                                  const std::shared_ptr<Pattern> &pattern,
-                                 const BitVector &bits, bool simple = false) {
-    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
+                                 const BitVector &bits) {
+    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
   }
 
   /// Wrapper over expectLatPointWithinRange for a single lat point.
   void expectLatPoint(unsigned s, unsigned p,
                       const std::shared_ptr<Pattern> &pattern,
-                      const BitVector &bits, bool simple = false) {
-    EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
+                      const BitVector &bits) {
+    EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
   }
 
   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
@@ -191,11 +126,7 @@ class MergerTestBase : public ::testing::Test {
   }
 
   /// Returns true if the bits of lattice point p in set s match the given bits.
-  /// If simple is true, then compare the lat.simple field instead to test the
-  /// result after optimization
-  bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
-    if (simple)
-      return merger.lat(merger.set(s)[p]).simple == bits;
+  bool compareBits(unsigned s, unsigned p, const BitVector &bits) {
     return merger.lat(merger.set(s)[p]).bits == bits;
   }
 
@@ -284,10 +215,6 @@ class MergerTestBase : public ::testing::Test {
   Merger merger;
 };
 
-///
-/// Tests with all sparse inputs.
-///
-
 class MergerTest3T1L : public MergerTestBase {
 protected:
   // Our three tensors (two inputs, one output).
@@ -311,63 +238,9 @@ class MergerTest3T1L : public MergerTestBase {
   }
 };
 
-class MergerTest4T1L : public MergerTestBase {
-protected:
-  // Our four tensors (three inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
-  MergerTest4T1L() : MergerTestBase(4, 1) {
-    // Tensor 0: sparse input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
-
-    // Tensor 1: sparse input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
-
-    // Tensor 2: sparse input vector
-    merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kSparse);
-
-    // Tensor 3: dense output vector
-    merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDim(t3, l0, Dim::kDense);
-  }
-};
-
-///
-/// Tests with both sparse and dense input.
-///
-
-class MergerTest3T1LD : public MergerTestBase {
-protected:
-  // Our three tensors (two inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2;
-
-  // Our single loop.
-  const unsigned l0 = 0;
-
-  MergerTest3T1LD() : MergerTestBase(3, 1) {
-    // Tensor 0: sparse input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
-
-    // Tensor 1: dense input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kDense);
-
-    // Tensor 2: dense output vector.
-    merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
-  }
-};
-
 } // namespace
 
-/// Vector addition (disjunction) of 2 vectors. i.e.;
+/// Vector addition of 2 vectors, i.e.:
 ///   a(i) = b(i) + c(i)
 /// which should form the 3 lattice points
 /// {
@@ -375,254 +248,55 @@ class MergerTest3T1LD : public MergerTestBase {
 ///   lat( i_00 / tensor_0 )
 ///   lat( i_01 / tensor_1 )
 /// }
-/// and after optimization, the lattice points do not change (as there is no
-/// duplicated point and all input vectors are sparse vector).
+/// and after optimization, will reduce to the 2 lattice points
 /// {
 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
 ///   lat( i_00 / tensor_0 )
-///   lat( i_01 / tensor_1 )
 /// }
-#define IMPL_MERGER_TEST_DISJ(OP)                                              \
-  TEST_F(MergerTest3T1L, vector_##OP) {                                        \
-    auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
-    expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
-    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
-    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
-                                                                               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
-    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}),       \
-                              true);                                           \
-    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}),       \
-                              true);                                           \
-  }
-
-FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
-
-#undef IMPL_MERGER_TEST_DISJ
+TEST_F(MergerTest3T1L, VectorAdd2) {
+  // Construct expression.
+  auto e = addf(tensor(t0), tensor(t1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, l0);
+  expectNumLatPoints(s, 3);
+  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
+                            loopsToBits({{l0, t0}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
+                            loopsToBits({{l0, t1}}));
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 3);
+  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
+                            loopsToBits({{l0, t0}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
+                            loopsToBits({{l0, t1}}));
+}
 
-/// Vector multiplication (conjunction) of 2 vectors, i.e.;
+/// Vector multiplication of 2 vectors, i.e.:
 ///   a(i) = b(i) * c(i)
 /// which should form the single lattice point
 /// {
 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
 /// }
-#define IMPL_MERGER_TEST_CONJ(OP)                                              \
-  TEST_F(MergerTest3T1L, vector_##OP) {                                        \
-    auto e = OP##Expr(t0, t1);                                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
-    expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
-                                                                               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
-  }
-
-FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
-
-#undef IMPL_MERGER_TEST_CONJ
-
-/// Vector multiplication (conjunction) then addition (disjunction), i.e.;
-///   a(i) = b(i) * c(i) + d(i);
-/// which should form
-/// {
-///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
-///    lat( i_00 i_01 / tensor_0 * tensor_1
-///    lat( i_02 / tensor_2 )
-/// }
-#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
-  TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
-    auto em = CONJ##Expr(t0, t1);                                              \
-    auto e = DISJ##Expr(em, t2);                                               \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
-    expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
-                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
-                              loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
-                                                                               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
-                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
-                              loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
-  }
-
-FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
-
-#undef IMPL_MERGER_TEST_CONJ_DISJ
-
-/// Vector addition (disjunction) then addition (disjunction), i.e.;
-///   a(i) = b(i) + c(i) + d(i)
-/// which should form
-/// {
-///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
-///   lat( i_02 i_01 / tensor_2 + tensor_1 )
-///   lat( i_02 i_00 / tensor_2 + tensor_0 )
-///   lat( i_01 i_00 / tensor_1 + tensor_0 )
-///   lat( i_02 / tensor_2 )
-///   lat( i_01 / tensor_1 )
-///   lat( i_00 / tensor_0 )
-/// }
-#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
-  TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
-    auto em = DISJ1##Expr(t0, t1);                                             \
-    auto e = DISJ2##Expr(em, t2);                                              \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
-    expectNumLatPoints(s, 7);                                                  \
-    expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
-                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
-                              loopsToBits({{l0, t1}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
-                              loopsToBits({{l0, t0}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
-                              loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
-                                                                               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 7);                                                  \
-    expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
-                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
-                              loopsToBits({{l0, t1}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
-                              loopsToBits({{l0, t0}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
-                              loopsToBits({{l0, t0}, {l0, t1}}));              \
-    expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
-    expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
-  }
-
-FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
-
-#undef IMPL_MERGER_TEST_DISJ_DISJ
-
-/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
-///   a(i) = b(i) * c(i) * d(i);
-/// which should form
-/// {
-///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
-/// }
-#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
-  TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
-    auto em = CONJ1##Expr(t0, t1);                                             \
-    auto e = CONJ2##Expr(em, t2);                                              \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto p2 = tensorPattern(t2);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-    expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
-                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
-                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
-  }
-
-FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
-
-#undef IMPL_MERGER_TEST_CONJ_CONJ
-
-/// Vector addition (disjunction) of 2 vectors, i.e.;
-///   a(i) = b(i) + c(i)
-/// which should form the 3 lattice points
-/// {
-///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
-///   lat( i_00 / sparse_tensor_0 )
-///   lat( i_01 / dense_tensor_1 )
-/// }
-/// which should be optimized to
-/// {
-///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
-///   lat( i_01 / dense_tensor_0 ) (no sparse dimension)
-/// }
-///
-/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
-/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
-#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP)                                    \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
-    auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
-    expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
-    expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
-    expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
-                                                                               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 2);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
-    expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true);              \
-  }
-
-FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
-
-#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
-
-/// Vector multiplication (conjunction) of 2 vectors, i.e.:
-///   a(i) = b(i) * c(i)
-/// which should form the single lattice point
-/// {
-///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
-/// }
-/// it should be optimized to
-/// {
-///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
-/// }
-/// since i_01 is a dense dimension.
-#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP)                                    \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
-    auto e = OP##Expr(t0, t1);                                                 \
-    auto p0 = tensorPattern(t0);                                               \
-    auto p1 = tensorPattern(t1);                                               \
-    auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
-    expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
-                                                                               \
-    s = merger.optimizeSet(s);                                                 \
-    expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}),    \
-                   true);                                                      \
-  }
-
-FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
-
-#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
-
-// TODO: mult-dim tests
+TEST_F(MergerTest3T1L, VectorMul2) {
+  // Construct expression.
+  auto e = mulf(t0, t1);
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, l0);
+  expectNumLatPoints(s, 1);
+  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 1);
+  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+}


        


More information about the Mlir-commits mailing list