[Mlir-commits] [mlir] 164b918 - [mlir][sparse] Removing shared_ptr from the MergerTest.cpp unit test
wren romano
llvmlistbot at llvm.org
Tue Mar 28 11:25:34 PDT 2023
Author: wren romano
Date: 2023-03-28T11:25:20-07:00
New Revision: 164b918d73779f776ab3ac6a633d0832ca9e6622
URL: https://github.com/llvm/llvm-project/commit/164b918d73779f776ab3ac6a633d0832ca9e6622
DIFF: https://github.com/llvm/llvm-project/commit/164b918d73779f776ab3ac6a633d0832ca9e6622.diff
LOG: [mlir][sparse] Removing shared_ptr from the MergerTest.cpp unit test
This is a preliminary change to make way for converting the Merger's identifier types from mere typedefs to actual types. That conversion would otherwise cause some issues in this unit test, which this patch fixes ahead of time.
Depends On D146676
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146561
Added:
Modified:
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 599e8abd52f3..2497659e68dd 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -80,27 +80,39 @@ namespace {
/// Helper classes/functions for testing Merger.
///
-/// Simple recursive data structure used to match expressions in Mergers.
+/// Simple recursive data structure used to match expressions in `Merger`.
+struct Pattern;
+/// Since the patterns we need are rather small and short-lived, we use
+/// `Pattern const&` for "pointers" to patterns, rather than using
+/// something more elaborate like `std::shared_ptr<Pattern> const&`.
+/// (But since we use a typedef rather than spelling it out everywhere,
+/// that's easy enough to swap out if we need something more elaborate
+/// in the future.)
+using PatternRef = const Pattern &;
struct Pattern {
+ struct Children {
+ Children(PatternRef e0, PatternRef e1) : e0(e0), e1(e1) {}
+ PatternRef e0;
+ PatternRef e1;
+ };
+
TensorExp::Kind kind;
- /// Expressions representing tensors simply have a tensor number.
- unsigned tensorNum;
+ union {
+ /// Expressions representing tensors simply have a tensor number.
+ TensorId tid;
- /// Tensor operations point to their children.
- std::shared_ptr<Pattern> e0;
- std::shared_ptr<Pattern> e1;
+ /// Tensor operations point to their children.
+ Children children;
+ };
/// Constructors.
/// Rather than using these, please use the readable helper constructor
/// functions below to make tests more readable.
- Pattern(unsigned tensorNum)
- : kind(TensorExp::Kind::kTensor), tensorNum(tensorNum) {}
- Pattern(TensorExp::Kind kind, const std::shared_ptr<Pattern> &e0,
- const std::shared_ptr<Pattern> &e1)
- : kind(kind), e0(e0), e1(e1) {
+ Pattern(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
+ Pattern(TensorExp::Kind kind, PatternRef e0, PatternRef e1)
+ : kind(kind), children(e0, e1) {
assert(kind >= TensorExp::Kind::kMulF);
- assert(e0 && e1);
}
};
@@ -109,15 +121,12 @@ struct Pattern {
/// These should be preferred over the actual constructors.
///
-static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
- return std::make_shared<Pattern>(tensorNum);
-}
+static Pattern tensorPattern(TensorId tid) { return Pattern(tid); }
#define IMPL_BINOP_PATTERN(OP, KIND) \
- LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern( \
- const std::shared_ptr<Pattern> &e0, \
- const std::shared_ptr<Pattern> &e1) { \
- return std::make_shared<Pattern>(KIND, e0, e1); \
+ LLVM_ATTRIBUTE_UNUSED static Pattern OP##Pattern(PatternRef e0, \
+ PatternRef e1) { \
+ return Pattern(KIND, e0, e1); \
}
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
@@ -127,20 +136,32 @@ FOREVERY_BINOP(IMPL_BINOP_PATTERN)
class MergerTestBase : public ::testing::Test {
protected:
MergerTestBase(unsigned numTensors, unsigned numLoops)
- : numTensors(numTensors), numLoops(numLoops),
- merger(numTensors, numLoops, /*numFilterLoops=*/0,
- /*maxRank=*/numLoops) {}
+ : merger(numTensors, numLoops, /*numFilterLoops=*/0,
+ /*maxRank=*/numLoops) {
+ tensors.reserve(numTensors);
+ for (unsigned t = 0; t < numTensors; t++)
+ tensors.push_back(merger.addExp(TensorExp::Kind::kTensor, tid(t)));
+ }
///
/// Expression construction helpers.
///
- unsigned tensor(unsigned tensor) {
- return merger.addExp(TensorExp::Kind::kTensor, tensor);
+ TensorId tid(unsigned t) const {
+ assert(t < merger.getNumTensors());
+ return t;
+ }
+ LoopId lid(unsigned i) const {
+ assert(i < merger.getNumLoops());
+ return i;
+ }
+ ExprId tensor(unsigned t) const {
+ assert(t < tensors.size());
+ return tensors[t];
}
#define IMPL_BINOP_EXPR(OP, KIND) \
- LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) { \
+ LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \
return merger.addExp(KIND, e0, e1); \
}
@@ -152,83 +173,77 @@ class MergerTestBase : public ::testing::Test {
/// Comparison helpers.
///
- /// For readability of tests.
- unsigned lat(unsigned lat) { return lat; }
-
- /// Returns true if a lattice point with an expression matching the given
- /// pattern and bits matching the given bits is present in lattice points
- /// [p, p+n) of lattice set s. This is useful for testing partial ordering
- /// constraints between lattice points. We generally know how contiguous
- /// groups of lattice points should be ordered with respect to other groups,
- /// but there is no required ordering within groups.
- /// If simple is true, then compare the lat.simple field instead to test the
- /// result after optimization
- bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
- const std::shared_ptr<Pattern> &pattern,
- const BitVector &bits, bool simple) {
- for (unsigned i = p; i < p + n; ++i) {
- if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
- compareBits(s, i, bits, simple))
+ /// Returns true if any lattice point with an expression matching
+ /// the given `pattern` and bits matching the given `bits` is present
+ /// in the `[lo, lo+n)` slice of the lattice set `s`. This is useful
+ /// for testing partial ordering constraints between lattice points.
+ /// We generally know how contiguous groups of lattice points should
+ /// be ordered with respect to other groups, but there is no required
+ /// ordering within groups. If `simple` is true, then compare the
+ /// `lat.simple` field instead to test the result after optimization.
+ bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
+ PatternRef pattern, const BitVector &bits,
+ bool simple) {
+ for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
+ if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
+ compareBits(s, k, bits, simple))
return true;
}
return false;
}
/// Wrapper over latPointWithinRange for readability of tests.
- void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
- const std::shared_ptr<Pattern> &pattern,
- const BitVector &bits, bool simple = false) {
- EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
+ void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
+ PatternRef pattern, const BitVector &bits,
+ bool simple = false) {
+ EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
}
/// Wrapper over expectLatPointWithinRange for a single lat point.
- void expectLatPoint(unsigned s, unsigned p,
- const std::shared_ptr<Pattern> &pattern,
+ void expectLatPoint(LatSetId s, unsigned lo, PatternRef pattern,
const BitVector &bits, bool simple = false) {
- EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
+ EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
}
/// Converts a vector of (loop, tensor) pairs to a bitvector with the
/// corresponding bits set.
- BitVector
- loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
- BitVector testBits = BitVector(numTensors + 1, false);
- for (auto l : loops) {
- auto loop = std::get<0>(l);
- auto tensor = std::get<1>(l);
+ BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
+ // NOTE: this `numTensors` includes both the output- and synthetic-tensors.
+ const auto numTensors = merger.getNumTensors();
+ BitVector testBits = BitVector(numTensors, false);
+ for (auto [loop, tensor] : loops)
testBits.set(numTensors * loop + tensor);
- }
return testBits;
}
- /// Returns true if the bits of lattice point p in set s match the given bits.
- /// If simple is true, then compare the lat.simple field instead to test the
- /// result after optimization
- bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
- if (simple)
- return merger.lat(merger.set(s)[p]).simple == bits;
- return merger.lat(merger.set(s)[p]).bits == bits;
+ /// Returns true if the bits of the `k`th point in set `s` match
+ /// the given `bits`. If `simple` is true, then compares the `lat.simple`
+ /// field instead, to test the result after optimization.
+ bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
+ const auto &point = merger.lat(merger.set(s)[k]);
+ return (simple ? point.simple : point.bits) == bits;
}
/// Check that there are n lattice points in set s.
- void expectNumLatPoints(unsigned s, unsigned n) {
+ void expectNumLatPoints(LatSetId s, unsigned n) {
EXPECT_THAT(merger.set(s).size(), n);
}
/// Compares expressions for equality. Equality is defined recursively as:
/// - Operations are equal if they have the same kind and children.
/// - Leaf tensors are equal if they refer to the same tensor.
- bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
- auto tensorExp = merger.exp(e);
- if (tensorExp.kind != pattern->kind)
+ bool compareExpression(ExprId e, PatternRef pattern) {
+ const auto &tensorExp = merger.exp(e);
+ if (tensorExp.kind != pattern.kind)
return false;
switch (tensorExp.kind) {
// Leaf.
case TensorExp::Kind::kTensor:
- return tensorExp.tensor == pattern->tensorNum;
+ return tensorExp.tensor == pattern.tid;
case TensorExp::Kind::kInvariant:
- case TensorExp::Kind::kLoopVar:
llvm_unreachable("invariant not handled yet");
+ case TensorExp::Kind::kLoopVar:
+ llvm_unreachable("loop-variables not handled yet");
// Unary operations.
case TensorExp::Kind::kAbsF:
case TensorExp::Kind::kAbsC:
@@ -264,7 +279,7 @@ class MergerTestBase : public ::testing::Test {
case TensorExp::Kind::kSelect:
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kUnary:
- return compareExpression(tensorExp.children.e0, pattern->e0);
+ return compareExpression(tensorExp.children.e0, pattern.children.e0);
// Binary operations.
case TensorExp::Kind::kMulF:
case TensorExp::Kind::kMulC:
@@ -287,72 +302,51 @@ class MergerTestBase : public ::testing::Test {
case TensorExp::Kind::kShlI:
case TensorExp::Kind::kBinary:
case TensorExp::Kind::kReduce:
- return compareExpression(tensorExp.children.e0, pattern->e0) &&
- compareExpression(tensorExp.children.e1, pattern->e1);
+ return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
+ compareExpression(tensorExp.children.e1, pattern.children.e1);
}
llvm_unreachable("unexpected kind");
}
- unsigned numTensors;
- unsigned numLoops;
+ // This field is public for convenience.
Merger merger;
+
+private:
+ // This field is private to prevent mutation after the ctor.
+ SmallVector<ExprId> tensors;
};
///
/// Tests with all sparse inputs.
///
+/// Three tensors (two inputs, one output); and a single loop.
class MergerTest3T1L : public MergerTestBase {
protected:
- // Our three tensors (two inputs, one output).
- const unsigned t0 = 0, t1 = 1, t2 = 2;
-
- // Our single loop.
- const unsigned l0 = 0;
-
MergerTest3T1L() : MergerTestBase(3, 1) {
- EXPECT_TRUE(merger.getOutTensorID() == t2);
-
+ EXPECT_TRUE(merger.getOutTensorID() == tid(2));
// Tensor 0: sparse input vector.
- merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
- merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
-
+ merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
- merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
- merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
-
+ merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed);
// Tensor 2: dense output vector.
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
- merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense);
}
};
+/// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1L : public MergerTestBase {
protected:
- // Our four tensors (three inputs, one output).
- const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
- // Our single loop.
- const unsigned l0 = 0;
-
MergerTest4T1L() : MergerTestBase(4, 1) {
- EXPECT_TRUE(merger.getOutTensorID() == t3);
-
+ EXPECT_TRUE(merger.getOutTensorID() == tid(3));
// Tensor 0: sparse input vector.
- merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
- merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
-
+ merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
- merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
- merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
-
+ merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Compressed);
// Tensor 2: sparse input vector
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
- merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
-
+ merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed);
// Tensor 3: dense output vector
- merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
- merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense);
}
};
@@ -360,28 +354,17 @@ class MergerTest4T1L : public MergerTestBase {
/// Tests with both sparse and dense input.
///
+/// Three tensors (two inputs, one output); and a single loop.
class MergerTest3T1LD : public MergerTestBase {
protected:
- // Our three tensors (two inputs, one output).
- const unsigned t0 = 0, t1 = 1, t2 = 2;
-
- // Our single loop.
- const unsigned l0 = 0;
-
MergerTest3T1LD() : MergerTestBase(3, 1) {
- EXPECT_TRUE(merger.getOutTensorID() == t2);
-
+ EXPECT_TRUE(merger.getOutTensorID() == tid(2));
// Tensor 0: sparse input vector.
- merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
- merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
-
+ merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Compressed);
// Tensor 1: dense input vector.
- merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
- merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
-
+ merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense);
// Tensor 2: dense output vector.
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
- merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Dense);
}
};
@@ -389,32 +372,19 @@ class MergerTest3T1LD : public MergerTestBase {
/// Tests with both undef and dense input.
///
+/// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1LU : public MergerTestBase {
protected:
- // Our three tensors (three inputs, one output).
- const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
- // Our single loop.
- const unsigned l0 = 0;
-
MergerTest4T1LU() : MergerTestBase(4, 1) {
- EXPECT_TRUE(merger.getOutTensorID() == t3);
-
+ EXPECT_TRUE(merger.getOutTensorID() == tid(3));
// Tensor 0: undef input vector.
- merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
- merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
-
+ merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef);
// Tensor 1: dense input vector.
- merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
- merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
-
+ merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Dense);
// Tensor 2: undef input vector.
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
- merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef);
-
+ merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Undef);
// Tensor 3: dense output vector.
- merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
- merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, DimLevelType::Dense);
}
};
@@ -422,31 +392,19 @@ class MergerTest4T1LU : public MergerTestBase {
/// Tests with operation on sparse output.
///
+/// Three tensors (two inputs, one output) plus one synthetic tensor; and a single loop.
class MergerTest3T1LSo : public MergerTestBase {
protected:
- // Our three tensors (two inputs, one output, one synthetic).
- const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
-
- // Our single loop.
- const unsigned l0 = 0;
-
MergerTest3T1LSo() : MergerTestBase(3, 1) {
- EXPECT_TRUE(merger.getOutTensorID() == t2);
- EXPECT_TRUE(merger.getSynTensorID() == t3);
-
+ EXPECT_TRUE(merger.getOutTensorID() == tid(2));
+ EXPECT_TRUE(merger.getSynTensorID() == tid(3));
merger.setHasSparseOut(true);
-
// Tensor 0: undef input vector.
- merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
- merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
-
+ merger.setLevelAndType(tid(0), lid(0), 0, DimLevelType::Undef);
// Tensor 1: undef input vector.
- merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
- merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef);
-
+ merger.setLevelAndType(tid(1), lid(0), 0, DimLevelType::Undef);
// Tensor 2: sparse output vector.
- merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
- merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
+ merger.setLevelAndType(tid(2), lid(0), 0, DimLevelType::Compressed);
}
};
@@ -465,18 +423,22 @@ class MergerTest3T1LSo : public MergerTestBase {
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
- auto em = CONJ1##Expr(t0, t1); \
- auto e = CONJ2##Expr(em, t2); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
- auto p2 = tensorPattern(t2); \
+ const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
+ const auto e = CONJ2##Expr(em, tensor(2)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const auto t2 = tid(2); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
+ const PatternRef p2 = tensorPattern(t2); \
auto s = merger.buildLattices(e, l0); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t1}}), true); \
}
@@ -497,18 +459,23 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
- auto em = CONJ1##Expr(t0, t1); \
- auto e = CONJ2##Expr(em, t2); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
- auto p2 = tensorPattern(t2); \
+ const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
+ const auto e = CONJ2##Expr(em, tensor(2)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const auto t2 = tid(2); \
+ const auto t3 = tid(3); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
+ const PatternRef p2 = tensorPattern(t2); \
auto s = merger.buildLattices(e, l0); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t3}}), true); \
}
@@ -533,25 +500,26 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
/// }
#define IMPL_MERGER_TEST_DISJ(OP) \
TEST_F(MergerTest3T1L, vector_##OP) { \
- auto e = OP##Expr(tensor(t0), tensor(t1)); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
+ const auto e = OP##Expr(tensor(0), tensor(1)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 3); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
- expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
+ expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
+ expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 3); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}}), true); \
- expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \
- true); \
- expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \
- true); \
+ expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true); \
+ expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true); \
}
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
@@ -566,18 +534,21 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
/// }
#define IMPL_MERGER_TEST_CONJ(OP) \
TEST_F(MergerTest3T1L, vector_##OP) { \
- auto e = OP##Expr(t0, t1); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
+ const auto e = OP##Expr(tensor(0), tensor(1)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}}), true); \
}
@@ -595,27 +566,31 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
- auto em = CONJ##Expr(t0, t1); \
- auto e = DISJ##Expr(em, t2); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
- auto p2 = tensorPattern(t2); \
+ const auto em = CONJ##Expr(tensor(0), tensor(1)); \
+ const auto e = DISJ##Expr(em, tensor(2)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const auto t2 = tid(2); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
+ const PatternRef p2 = tensorPattern(t2); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 3); \
- expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
+ expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 3); \
- expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
+ expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
}
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
@@ -636,39 +611,43 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
- auto em = DISJ1##Expr(t0, t1); \
- auto e = DISJ2##Expr(em, t2); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
- auto p2 = tensorPattern(t2); \
+ const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
+ const auto e = DISJ2##Expr(em, tensor(2)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const auto t2 = tid(2); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
+ const PatternRef p2 = tensorPattern(t2); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 7); \
- expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \
loopsToBits({{l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \
loopsToBits({{l0, t0}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
+ expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
+ expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
+ expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 7); \
- expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \
loopsToBits({{l0, t1}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
+ expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \
loopsToBits({{l0, t0}, {l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
+ expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
- expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
+ expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
+ expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
+ expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
}
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
@@ -683,18 +662,22 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
- auto em = CONJ1##Expr(t0, t1); \
- auto e = CONJ2##Expr(em, t2); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
- auto p2 = tensorPattern(t2); \
+ const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
+ const auto e = CONJ2##Expr(em, tensor(2)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const auto t2 = tid(2); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
+ const PatternRef p2 = tensorPattern(t2); \
auto s = merger.buildLattices(e, l0); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
+ expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
}
@@ -720,22 +703,25 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
- auto e = OP##Expr(tensor(t0), tensor(t1)); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
+ const auto e = OP##Expr(tensor(0), tensor(1)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 3); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
- expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
- expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
+ expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
+ expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 2); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}}), true); \
- expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \
+ expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \
}
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
@@ -755,19 +741,21 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
- auto e = OP##Expr(t0, t1); \
- auto p0 = tensorPattern(t0); \
- auto p1 = tensorPattern(t1); \
+ const auto e = OP##Expr(tensor(0), tensor(1)); \
+ const auto l0 = lid(0); \
+ const auto t0 = tid(0); \
+ const auto t1 = tid(1); \
+ const PatternRef p0 = tensorPattern(t0); \
+ const PatternRef p1 = tensorPattern(t1); \
auto s = merger.buildLattices(e, l0); \
\
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), \
loopsToBits({{l0, t0}, {l0, t1}})); \
\
s = merger.optimizeSet(s); \
expectNumLatPoints(s, 1); \
- expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \
- true); \
+ expectLatPoint(s, 0, OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), true); \
}
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
More information about the Mlir-commits
mailing list