[Mlir-commits] [mlir] fd9b3e4 - [mlir][sparse] cleanup merger test, add header (#70279)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 09:07:02 PDT 2023
Author: Aart Bik
Date: 2023-10-26T09:06:57-07:00
New Revision: fd9b3e471e80d5c6aaf525f9fc3a04569eddb960
URL: https://github.com/llvm/llvm-project/commit/fd9b3e471e80d5c6aaf525f9fc3a04569eddb960
DIFF: https://github.com/llvm/llvm-project/commit/fd9b3e471e80d5c6aaf525f9fc3a04569eddb960.diff
LOG: [mlir][sparse] cleanup merger test, add header (#70279)
Added:
Modified:
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 561753f631c16f1..e28d88e046fcdc1 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -1,3 +1,11 @@
+//===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "llvm/Support/Compiler.h"
#include "gmock/gmock.h"
@@ -73,56 +81,43 @@ namespace {
/// Helper classes/functions for testing Merger.
///
-/// Simple recursive data structure used to match expressions in `Merger`.
-struct Pattern;
-/// Since the patterns we need are rather small and short-lived, we use
-/// `Pattern const&` for "pointers" to patterns, rather than using
-/// something more elaborate like `std::shared_ptr<Pattern> const&`.
-using PatternRef = const Pattern &;
-struct Pattern {
+/// Simple recursive data structure used to match expressions in `Merger`,
+/// which uses const references into the short-lived data structures.
+struct Match {
struct Children {
- Children(PatternRef e0, PatternRef e1) : e0(e0), e1(e1) {}
- PatternRef e0;
- PatternRef e1;
+ Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
+ const Match &e0;
+ const Match &e1;
};
- TensorExp::Kind kind;
+ Match() : kind(TensorExp::Kind::kSynZero) {}
+ Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
+ Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
+ : kind(kind), children(e0, e1) {
+ assert(kind >= TensorExp::Kind::kMulF);
+ }
+ TensorExp::Kind kind;
union {
- /// Expressions representing tensors simply have a tensor number.
TensorId tid;
-
- /// Tensor operations point to their children.
Children children;
};
-
- /// Constructors.
- /// Rather than using these, please use the readable builder
- /// functions below to make tests more readable.
- Pattern() : kind(TensorExp::Kind::kSynZero) {}
- Pattern(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
- Pattern(TensorExp::Kind kind, PatternRef e0, PatternRef e1)
- : kind(kind), children(e0, e1) {
- assert(kind >= TensorExp::Kind::kMulF);
- }
};
///
-/// Readable Pattern builder functions.
+/// Readable Match builder functions.
/// These should be preferred over the actual constructors.
///
-static Pattern tensorPattern(TensorId tid) { return Pattern(tid); }
-static Pattern synZeroPattern() { return Pattern(); }
+static Match tensorMatch(TensorId tid) { return Match(tid); }
+static Match synZeroMatch() { return Match(); }
#define IMPL_BINOP_PATTERN(OP, KIND) \
- LLVM_ATTRIBUTE_UNUSED static Pattern OP##Pattern(PatternRef e0, \
- PatternRef e1) { \
- return Pattern(KIND, e0, e1); \
+ LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \
+ const Match &e1) { \
+ return Match(KIND, e0, e1); \
}
-
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
-
#undef IMPL_BINOP_PATTERN
class MergerTestBase : public ::testing::Test {
@@ -150,9 +145,7 @@ class MergerTestBase : public ::testing::Test {
LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \
return merger.addExp(KIND, e0, e1); \
}
-
FOREVERY_BINOP(IMPL_BINOP_EXPR)
-
#undef IMPL_BINOP_EXPR
///
@@ -168,7 +161,7 @@ class MergerTestBase : public ::testing::Test {
/// ordering within groups. If `simple` is true, then compare the
/// `lat.simple` field instead to test the result after optimization.
bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
- PatternRef pattern, const BitVector &bits,
+ const Match &pattern, const BitVector &bits,
bool simple) {
for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
@@ -180,13 +173,13 @@ class MergerTestBase : public ::testing::Test {
/// Wrapper over latPointWithinRange for readability of tests.
void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
- PatternRef pattern, const BitVector &bits,
+ const Match &pattern, const BitVector &bits,
bool simple = false) {
EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
}
/// Wrapper over expectLatPointWithinRange for a single lat point.
- void expectLatPoint(LatSetId s, unsigned lo, PatternRef pattern,
+ void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
const BitVector &bits, bool simple = false) {
EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
}
@@ -216,7 +209,7 @@ class MergerTestBase : public ::testing::Test {
/// Compares expressions for equality. Equality is defined recursively as:
/// - Operations are equal if they have the same kind and children.
/// - Leaf tensors are equal if they refer to the same tensor.
- bool compareExpression(ExprId e, PatternRef pattern) {
+ bool compareExpression(ExprId e, const Match &pattern) {
const auto &tensorExp = merger.exp(e);
if (tensorExp.kind != pattern.kind)
return false;
@@ -424,21 +417,19 @@ class MergerTest3T1LSo : public MergerTestBase {
const auto t0 = tid(0); \
const auto t1 = tid(1); \
const auto t2 = tid(2); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
- PatternRef p2 = tensorPattern(t2); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
+ const Match &p2 = tensorMatch(t2); \
auto s = merger.buildLattices(e, l0); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t1}}), true); \
}
-
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
-
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
@@ -461,21 +452,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
const auto t1 = tid(1); \
const auto t2 = tid(2); \
const auto t3 = tid(3); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
- PatternRef p2 = tensorPattern(t2); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
+ const Match &p2 = tensorMatch(t2); \
auto s = merger.buildLattices(e, l0); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t3}}), true); \
}
-
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
-
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
/// Vector addition (disjunction) of 2 vectors. i.e.;
@@ -499,26 +488,24 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
const auto l0 = lid(0); \
const auto t0 = tid(0); \
const auto t1 = tid(1); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 3); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 3); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
- loopsToBits({{l0, t0}, {l0, t1}}), true); \
+ expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
+ true); \
expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true); \
expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true); \
}
-
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
-
#undef IMPL_MERGER_TEST_DISJ
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
@@ -533,22 +520,20 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
const auto l0 = lid(0); \
const auto t0 = tid(0); \
const auto t1 = tid(1); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
- loopsToBits({{l0, t0}, {l0, t1}}), true); \
+ expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
+ true); \
}
-
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
-
#undef IMPL_MERGER_TEST_CONJ
/// Vector multiplication (conjunction) then addition (disjunction), i.e.;
@@ -567,29 +552,27 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
const auto t0 = tid(0); \
const auto t1 = tid(1); \
const auto t2 = tid(2); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
- PatternRef p2 = tensorPattern(t2); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
+ const Match &p2 = tensorMatch(t2); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 3); \
- expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 3); \
- expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
}
-
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
-
#undef IMPL_MERGER_TEST_CONJ_DISJ
/// Vector addition (disjunction) then addition (disjunction), i.e.;
@@ -612,19 +595,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
const auto t0 = tid(0); \
const auto t1 = tid(1); \
const auto t2 = tid(2); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
- PatternRef p2 = tensorPattern(t2); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
+ const Match &p2 = tensorMatch(t2); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 7); \
- expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
loopsToBits({{l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
loopsToBits({{l0, t0}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
@@ -632,21 +615,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 7); \
- expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
loopsToBits({{l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
loopsToBits({{l0, t0}, {l0, t2}})); \
- expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
}
-
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
-
#undef IMPL_MERGER_TEST_DISJ_DISJ
/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
@@ -663,21 +644,19 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
const auto t0 = tid(0); \
const auto t1 = tid(1); \
const auto t2 = tid(2); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
- PatternRef p2 = tensorPattern(t2); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
+ const Match &p2 = tensorMatch(t2); \
auto s = merger.buildLattices(e, l0); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
}
-
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
-
#undef IMPL_MERGER_TEST_CONJ_CONJ
/// Vector addition (disjunction) of 2 vectors, i.e.;
@@ -702,25 +681,23 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
const auto l0 = lid(0); \
const auto t0 = tid(0); \
const auto t1 = tid(1); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 3); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 2); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
- loopsToBits({{l0, t0}, {l0, t1}}), true); \
+ expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
+ true); \
expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \
}
-
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
-
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
@@ -740,20 +717,20 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
const auto l0 = lid(0); \
const auto t0 = tid(0); \
const auto t1 = tid(1); \
- PatternRef p0 = tensorPattern(t0); \
- PatternRef p1 = tensorPattern(t1); \
+ const Match &p0 = tensorMatch(t0); \
+ const Match &p1 = tensorMatch(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Match(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, 0, OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), true); \
+ expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true); \
}
-
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
+#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
/// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
/// a(i) = b(i) + c(i)
@@ -775,20 +752,20 @@ TEST_F(MergerTest3T1L, vector_cmp) {
const auto l0 = lid(0);
const auto t0 = tid(0);
const auto t1 = tid(1);
- PatternRef zero = synZeroPattern();
- PatternRef p0 = tensorPattern(t0);
- PatternRef p1 = tensorPattern(t1);
+ const Match &zero = synZeroMatch();
+ const Match &p0 = tensorMatch(t0);
+ const Match &p1 = tensorMatch(t1);
auto s = merger.buildLattices(e, l0);
- expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
+ expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
loopsToBits({{l0, t0}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
loopsToBits({{l0, t1}}));
s = merger.optimizeSet(s);
- expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
+ expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
loopsToBits({{l0, t0}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
loopsToBits({{l0, t1}}));
}
@@ -813,19 +790,17 @@ TEST_F(MergerTest3T1LD, vector_cmp) {
const auto l0 = lid(0);
const auto t0 = tid(0);
const auto t1 = tid(1);
- PatternRef zero = synZeroPattern();
- PatternRef p0 = tensorPattern(t0);
- PatternRef p1 = tensorPattern(t1);
+ const Match &zero = synZeroMatch();
+ const Match &p0 = tensorMatch(t0);
+ const Match &p1 = tensorMatch(t1);
auto s = merger.buildLattices(e, l0);
- expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
+ expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
loopsToBits({{l0, t0}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
loopsToBits({{l0, t1}}));
s = merger.optimizeSet(s);
- expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
- expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
+ expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
+ expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
loopsToBits({{l0, t1}}));
}
-
-#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
More information about the Mlir-commits
mailing list