[Mlir-commits] [mlir] [mlir][sparse] cleanup merger test, add header (PR #70279)

Aart Bik llvmlistbot at llvm.org
Wed Oct 25 18:50:49 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/70279

None

>From 463f67eb022ffaca2aca66676f7ac31ae4dbd6cb Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Wed, 25 Oct 2023 18:49:06 -0700
Subject: [PATCH] [mlir][sparse] cleanup merger test, add header

---
 .../Dialect/SparseTensor/MergerTest.cpp       | 227 ++++++++----------
 1 file changed, 101 insertions(+), 126 deletions(-)

diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 561753f631c16f1..e28d88e046fcdc1 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -1,3 +1,11 @@
+//===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "llvm/Support/Compiler.h"
 #include "gmock/gmock.h"
@@ -73,56 +81,43 @@ namespace {
 /// Helper classes/functions for testing Merger.
 ///
 
-/// Simple recursive data structure used to match expressions in `Merger`.
-struct Pattern;
-/// Since the patterns we need are rather small and short-lived, we use
-/// `Pattern const&` for "pointers" to patterns, rather than using
-/// something more elaborate like `std::shared_ptr<Pattern> const&`.
-using PatternRef = const Pattern &;
-struct Pattern {
+/// Simple recursive data structure used to match expressions in `Merger`.
+/// Children are held by const reference, so they must outlive the parent.
+struct Match {
   struct Children {
-    Children(PatternRef e0, PatternRef e1) : e0(e0), e1(e1) {}
-    PatternRef e0;
-    PatternRef e1;
+    Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
+    const Match &e0;
+    const Match &e1;
   };
 
-  TensorExp::Kind kind;
+  Match() : kind(TensorExp::Kind::kSynZero) {}
+  Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
+  Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
+      : kind(kind), children(e0, e1) {
+    assert(kind >= TensorExp::Kind::kMulF);
+  }
 
+  TensorExp::Kind kind;
   union {
-    /// Expressions representing tensors simply have a tensor number.
     TensorId tid;
-
-    /// Tensor operations point to their children.
     Children children;
   };
-
-  /// Constructors.
-  /// Rather than using these, please use the readable builder
-  /// functions below to make tests more readable.
-  Pattern() : kind(TensorExp::Kind::kSynZero) {}
-  Pattern(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
-  Pattern(TensorExp::Kind kind, PatternRef e0, PatternRef e1)
-      : kind(kind), children(e0, e1) {
-    assert(kind >= TensorExp::Kind::kMulF);
-  }
 };
 
 ///
-/// Readable Pattern builder functions.
+/// Readable Match builder functions.
 /// These should be preferred over the actual constructors.
 ///
 
-static Pattern tensorPattern(TensorId tid) { return Pattern(tid); }
-static Pattern synZeroPattern() { return Pattern(); }
+static Match tensorMatch(TensorId tid) { return Match(tid); }
+static Match synZeroMatch() { return Match(); }
 
 #define IMPL_BINOP_PATTERN(OP, KIND)                                           \
-  LLVM_ATTRIBUTE_UNUSED static Pattern OP##Pattern(PatternRef e0,              \
-                                                   PatternRef e1) {            \
-    return Pattern(KIND, e0, e1);                                              \
+  LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0,                \
+                                               const Match &e1) {              \
+    return Match(KIND, e0, e1);                                                \
   }
-
 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
-
 #undef IMPL_BINOP_PATTERN
 
 class MergerTestBase : public ::testing::Test {
@@ -150,9 +145,7 @@ class MergerTestBase : public ::testing::Test {
   LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) {                \
     return merger.addExp(KIND, e0, e1);                                        \
   }
-
   FOREVERY_BINOP(IMPL_BINOP_EXPR)
-
 #undef IMPL_BINOP_EXPR
 
   ///
@@ -168,7 +161,7 @@ class MergerTestBase : public ::testing::Test {
   /// ordering within groups.  If `simple` is true, then compare the
   /// `lat.simple` field instead to test the result after optimization.
   bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
-                           PatternRef pattern, const BitVector &bits,
+                           const Match &pattern, const BitVector &bits,
                            bool simple) {
     for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
       if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
@@ -180,13 +173,13 @@ class MergerTestBase : public ::testing::Test {
 
   /// Wrapper over latPointWithinRange for readability of tests.
   void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
-                                 PatternRef pattern, const BitVector &bits,
+                                 const Match &pattern, const BitVector &bits,
                                  bool simple = false) {
     EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
   }
 
   /// Wrapper over expectLatPointWithinRange for a single lat point.
-  void expectLatPoint(LatSetId s, unsigned lo, PatternRef pattern,
+  void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
                       const BitVector &bits, bool simple = false) {
     EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
   }
@@ -216,7 +209,7 @@ class MergerTestBase : public ::testing::Test {
   /// Compares expressions for equality. Equality is defined recursively as:
   /// - Operations are equal if they have the same kind and children.
   /// - Leaf tensors are equal if they refer to the same tensor.
-  bool compareExpression(ExprId e, PatternRef pattern) {
+  bool compareExpression(ExprId e, const Match &pattern) {
     const auto &tensorExp = merger.exp(e);
     if (tensorExp.kind != pattern.kind)
       return false;
@@ -424,21 +417,19 @@ class MergerTest3T1LSo : public MergerTestBase {
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
     const auto t2 = tid(2);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
-    PatternRef p2 = tensorPattern(t2);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
+    const Match &p2 = tensorMatch(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t1}}), true);                             \
   }
-
 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
-
 #undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
 
 /// Vector multiplication (conjunction) of 2 vectors, i.e.;
@@ -461,21 +452,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
     const auto t1 = tid(1);                                                    \
     const auto t2 = tid(2);                                                    \
     const auto t3 = tid(3);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
-    PatternRef p2 = tensorPattern(t2);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
+    const Match &p2 = tensorMatch(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t3}}), true);                             \
   }
-
 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
-
 #undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
 
 /// Vector addition (disjunction) of 2 vectors. i.e.;
@@ -499,26 +488,24 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
+    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
     expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
     expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
-                   loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
+    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
+                   true);                                                      \
     expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true);     \
     expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true);     \
   }
-
 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
-
 #undef IMPL_MERGER_TEST_DISJ
 
 /// Vector multiplication (conjunction) of 2 vectors, i.e.;
@@ -533,22 +520,20 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
+    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
-                   loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
+    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
+                   true);                                                      \
   }
-
 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
-
 #undef IMPL_MERGER_TEST_CONJ
 
 /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
@@ -567,29 +552,27 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
     const auto t2 = tid(2);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
-    PatternRef p2 = tensorPattern(t2);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
+    const Match &p2 = tensorMatch(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2),             \
+    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1),                  \
+    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
                               loopsToBits({{l0, t0}, {l0, t1}}));              \
     expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2),             \
+    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1),                  \
+    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
                               loopsToBits({{l0, t0}, {l0, t1}}));              \
     expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
   }
-
 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
-
 #undef IMPL_MERGER_TEST_CONJ_DISJ
 
 /// Vector addition (disjunction) then addition (disjunction), i.e.;
@@ -612,19 +595,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
     const auto t2 = tid(2);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
-    PatternRef p2 = tensorPattern(t2);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
+    const Match &p2 = tensorMatch(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 7);                                                  \
-    expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2),                 \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
                               loopsToBits({{l0, t1}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2),                 \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
                               loopsToBits({{l0, t0}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1),                 \
+    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
                               loopsToBits({{l0, t0}, {l0, t1}}));              \
     expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
     expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
@@ -632,21 +615,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 7);                                                  \
-    expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
-    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2),                 \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
                               loopsToBits({{l0, t1}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2),                 \
+    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
                               loopsToBits({{l0, t0}, {l0, t2}}));              \
-    expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1),                 \
+    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
                               loopsToBits({{l0, t0}, {l0, t1}}));              \
     expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
     expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
     expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
   }
-
 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
-
 #undef IMPL_MERGER_TEST_DISJ_DISJ
 
 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
@@ -663,21 +644,19 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
     const auto t2 = tid(2);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
-    PatternRef p2 = tensorPattern(t2);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
+    const Match &p2 = tensorMatch(t2);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),           \
+    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
   }
-
 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
-
 #undef IMPL_MERGER_TEST_CONJ_CONJ
 
 /// Vector addition (disjunction) of 2 vectors, i.e.;
@@ -702,25 +681,23 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 3);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
+    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
     expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
     expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 2);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
-                   loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
+    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
+                   true);                                                      \
     expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true);                   \
   }
-
 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
-
-#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
+#undef IMPL_MERGER_TEST_OPTIMIZED_DISJ
 
 /// Vector multiplication (conjunction) of 2 vectors, i.e.:
@@ -740,20 +717,20 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
     const auto t1 = tid(1);                                                    \
-    PatternRef p0 = tensorPattern(t0);                                         \
-    PatternRef p1 = tensorPattern(t1);                                         \
+    const Match &p0 = tensorMatch(t0);                                         \
+    const Match &p1 = tensorMatch(t1);                                         \
     auto s = merger.buildLattices(e, l0);                                      \
                                                                                \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1),                                  \
+    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                                \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, 0, OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), true);  \
+    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true);    \
   }
-
 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
+#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
 
 /// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
 ///   a(i) = b(i) + c(i)
@@ -775,20 +752,20 @@ TEST_F(MergerTest3T1L, vector_cmp) {
   const auto l0 = lid(0);
   const auto t0 = tid(0);
   const auto t1 = tid(1);
-  PatternRef zero = synZeroPattern();
-  PatternRef p0 = tensorPattern(t0);
-  PatternRef p1 = tensorPattern(t1);
+  const Match &zero = synZeroMatch();
+  const Match &p0 = tensorMatch(t0);
+  const Match &p1 = tensorMatch(t1);
   auto s = merger.buildLattices(e, l0);
-  expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
+  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                             loopsToBits({{l0, t0}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                             loopsToBits({{l0, t1}}));
   s = merger.optimizeSet(s);
-  expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
+  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                             loopsToBits({{l0, t0}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                             loopsToBits({{l0, t1}}));
 }
 
@@ -813,19 +790,17 @@ TEST_F(MergerTest3T1LD, vector_cmp) {
   const auto l0 = lid(0);
   const auto t0 = tid(0);
   const auto t1 = tid(1);
-  PatternRef zero = synZeroPattern();
-  PatternRef p0 = tensorPattern(t0);
-  PatternRef p1 = tensorPattern(t1);
+  const Match &zero = synZeroMatch();
+  const Match &p0 = tensorMatch(t0);
+  const Match &p1 = tensorMatch(t1);
   auto s = merger.buildLattices(e, l0);
-  expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
+  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                             loopsToBits({{l0, t0}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                             loopsToBits({{l0, t1}}));
   s = merger.optimizeSet(s);
-  expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
-  expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                             loopsToBits({{l0, t1}}));
 }
-
-#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ



More information about the Mlir-commits mailing list