[Mlir-commits] [mlir] 0d0cff3 - [mlir][sparse] Add Merger unit tests

Gus Smith llvmlistbot at llvm.org
Thu Jul 8 16:10:40 PDT 2021


Author: Gus Smith
Date: 2021-07-08T23:10:33Z
New Revision: 0d0cff3ace39378acfc66d6564dc99e19b8a561f

URL: https://github.com/llvm/llvm-project/commit/0d0cff3ace39378acfc66d6564dc99e19b8a561f
DIFF: https://github.com/llvm/llvm-project/commit/0d0cff3ace39378acfc66d6564dc99e19b8a561f.diff

LOG: [mlir][sparse] Add Merger unit tests

We opt to use unit tests rather than check tests because the lattice/merger code is a small C++ component with a well-defined API. Testing this API via check tests would be far less direct and readable. In addition, because check tests can only exercise the API indirectly, they may break due to unrelated changes (e.g., changes in Linalg).

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D104956

Added: 
    mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Modified: 
    mlir/unittests/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 22f1475a1393..6b441567b548 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -7,4 +7,5 @@ target_link_libraries(MLIRDialectTests
   MLIRDialect)
 
 add_subdirectory(Quant)
+add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)

diff  --git a/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
new file mode 100644
index 000000000000..f9594aab3bbc
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Unit test binary for the sparse tensor Merger (built from MergerTest.cpp).
+add_mlir_unittest(MLIRSparseTensorTests
+  MergerTest.cpp
+)
+# MergerTest.cpp includes mlir/Dialect/SparseTensor/Utils/Merger.h;
+# presumably MLIRSparseTensorUtils provides its implementation (confirm).
+target_link_libraries(MLIRSparseTensorTests
+  PRIVATE
+  MLIRSparseTensorUtils
+)

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
new file mode 100644
index 000000000000..ad647f0e7a29
--- /dev/null
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -0,0 +1,252 @@
+#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <cassert>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// Simple recursive data structure used to match expressions built in a
+/// Merger. A pattern is either a tensor leaf (kind == Kind::kTensor) or an
+/// operation with two child patterns.
+struct Pattern {
+  /// The expression kind this pattern matches.
+  Kind kind;
+
+  /// Expressions representing tensors simply have a tensor number. Only
+  /// meaningful when kind == Kind::kTensor.
+  unsigned tensorNum = 0;
+
+  /// Tensor operations point to their children. Both stay null for tensor
+  /// leaves.
+  std::shared_ptr<Pattern> e0;
+  std::shared_ptr<Pattern> e1;
+
+  /// Constructors.
+  /// Rather than using these, please use the readable helper constructor
+  /// functions below to make tests more readable.
+  explicit Pattern(unsigned tensorNum)
+      : kind(Kind::kTensor), tensorNum(tensorNum) {}
+  Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1)
+      : kind(kind), e0(std::move(e0)), e1(std::move(e1)) {
+    // Assumes every Kind at or after kMulF in the enum is an operation kind.
+    assert(kind >= Kind::kMulF);
+    // Check the members: the parameters have been moved from at this point.
+    assert(this->e0 && this->e1);
+  }
+};
+
+///
+/// Readable Pattern builder functions.
+/// These should be preferred over the actual constructors.
+///
+
+/// Returns a pattern matching a tensor leaf for tensor number `tensorNum`.
+static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
+  return std::make_shared<Pattern>(tensorNum);
+}
+
+/// Returns a pattern matching a Kind::kAddF expression with children e0, e1.
+/// The children are sink parameters and are moved, not copied.
+static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0,
+                                            std::shared_ptr<Pattern> e1) {
+  return std::make_shared<Pattern>(Kind::kAddF, std::move(e0), std::move(e1));
+}
+
+/// Returns a pattern matching a Kind::kMulF expression with children e0, e1.
+/// The children are sink parameters and are moved, not copied.
+static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0,
+                                            std::shared_ptr<Pattern> e1) {
+  return std::make_shared<Pattern>(Kind::kMulF, std::move(e0), std::move(e1));
+}
+
+/// Test fixture base that owns a Merger and provides helpers for building
+/// expressions in it and for checking the lattice sets it produces.
+class MergerTestBase : public ::testing::Test {
+protected:
+  MergerTestBase(unsigned numTensors, unsigned numLoops)
+      : numTensors(numTensors), numLoops(numLoops),
+        merger(numTensors, numLoops) {}
+
+  ///
+  /// Expression construction helpers. Each adds an expression to the merger
+  /// and returns the id that merger.addExp hands back.
+  ///
+
+  /// Adds a Kind::kTensor leaf for tensor number `tensor`.
+  unsigned tensor(unsigned tensor) {
+    return merger.addExp(Kind::kTensor, tensor);
+  }
+
+  /// Adds a Kind::kAddF expression over expression ids e0 and e1.
+  unsigned addf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kAddF, e0, e1);
+  }
+
+  /// Adds a Kind::kMulF expression over expression ids e0 and e1.
+  unsigned mulf(unsigned e0, unsigned e1) {
+    return merger.addExp(Kind::kMulF, e0, e1);
+  }
+
+  ///
+  /// Comparison helpers.
+  ///
+
+  /// Identity; lets tests write lat(i) to mark lattice point indices.
+  unsigned lat(unsigned lat) { return lat; }
+
+  /// Returns true if a lattice point with an expression matching the given
+  /// pattern and bits matching the given bits is present in lattice points
+  /// [p, p+n) of lattice set s. Indices past the end of the set are skipped,
+  /// so a set that is smaller than expected yields false instead of an
+  /// out-of-bounds access (EXPECT_* failures do not abort the test). This is
+  /// useful for testing partial ordering constraints between lattice points.
+  /// We generally know how contiguous groups of lattice points should be
+  /// ordered with respect to other groups, but there is no required ordering
+  /// within groups.
+  bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
+                           const std::shared_ptr<Pattern> &pattern,
+                           const llvm::BitVector &bits) {
+    const auto &points = merger.set(s);
+    for (unsigned i = p, end = p + n; i < end && i < points.size(); ++i) {
+      if (compareExpression(merger.lat(points[i]).exp, pattern) &&
+          compareBits(s, i, bits))
+        return true;
+    }
+    return false;
+  }
+
+  /// Wrapper over latPointWithinRange for readability of tests; reports a
+  /// non-fatal failure naming the searched range when no point matches.
+  void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
+                                 const std::shared_ptr<Pattern> &pattern,
+                                 const llvm::BitVector &bits) {
+    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits))
+        << "no lattice point in [" << p << ", " << p + n << ") of set " << s
+        << " matches the expected expression and bits";
+  }
+
+  /// Wrapper over expectLatPointWithinRange for a single lat point.
+  void expectLatPoint(unsigned s, unsigned p,
+                      const std::shared_ptr<Pattern> &pattern,
+                      const llvm::BitVector &bits) {
+    expectLatPointWithinRange(s, p, 1, pattern, bits);
+  }
+
+  /// Converts a vector of (loop, tensor) pairs to a bitvector with the
+  /// corresponding bits set, at bit index numTensors * loop + tensor.
+  /// NOTE(review): the result has only numTensors + 1 bits, so only loop 0
+  /// is reliably supported; confirm size and stride against the Merger's bit
+  /// layout before adding multi-loop fixtures.
+  llvm::BitVector
+  loopsToBits(const std::vector<std::tuple<unsigned, unsigned>> &loops) {
+    llvm::BitVector testBits(numTensors + 1, false);
+    for (const auto &l : loops) {
+      unsigned loopId = std::get<0>(l);
+      unsigned tensorId = std::get<1>(l);
+      testBits.set(numTensors * loopId + tensorId);
+    }
+    return testBits;
+  }
+
+  /// Returns true if the bits of lattice point p in set s match the given bits.
+  bool compareBits(unsigned s, unsigned p, const llvm::BitVector &bits) {
+    return merger.lat(merger.set(s)[p]).bits == bits;
+  }
+
+  /// Check that there are n lattice points in set s.
+  void expectNumLatPoints(unsigned s, unsigned n) {
+    EXPECT_EQ(merger.set(s).size(), n);
+  }
+
+  /// Compares expression e against pattern. Equality is defined recursively:
+  /// - Two expressions can only be equal if they have the same Kind.
+  /// - Tensor expressions are equal if they refer to the same tensor number.
+  /// - Zero expressions of the same Kind are always equal.
+  /// - Binary expressions are equal if their children are equal.
+  /// Invariant expressions are not supported yet (asserts).
+  bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
+    const auto &tensorExp = merger.exp(e);
+    if (tensorExp.kind != pattern->kind)
+      return false;
+    assert(tensorExp.kind != Kind::kInvariant &&
+           "Invariant comparison not yet supported");
+    switch (tensorExp.kind) {
+    case Kind::kTensor:
+      return tensorExp.tensor == pattern->tensorNum;
+    case Kind::kZero:
+      return true;
+    case Kind::kMulF:
+    case Kind::kMulI:
+    case Kind::kAddF:
+    case Kind::kAddI:
+    case Kind::kSubF:
+    case Kind::kSubI:
+      return compareExpression(tensorExp.children.e0, pattern->e0) &&
+             compareExpression(tensorExp.children.e1, pattern->e1);
+    default:
+      llvm_unreachable("Unhandled Kind");
+    }
+  }
+
+  unsigned numTensors;
+  unsigned numLoops;
+  Merger merger;
+};
+
+/// Fixture with three tensors and a single loop: t0 and t1 are sparse input
+/// vectors, t2 is the dense output vector.
+class MergerTest3T1L : public MergerTestBase {
+protected:
+  // Tensor ids: the two inputs, then the output.
+  const unsigned t0 = 0;
+  const unsigned t1 = 1;
+  const unsigned t2 = 2;
+
+  // Id of the only loop.
+  const unsigned l0 = 0;
+
+  MergerTest3T1L() : MergerTestBase(3, 1) {
+    // Register one tensor leaf per tensor, then record its Dim for loop l0.
+    // Entries are processed in order t0, t1, t2.
+    struct TensorSetup {
+      unsigned id;
+      Dim dim;
+    };
+    const TensorSetup setups[] = {
+        {t0, Dim::kSparse}, {t1, Dim::kSparse}, {t2, Dim::kDense}};
+    for (const TensorSetup &setup : setups) {
+      merger.addExp(Kind::kTensor, setup.id, -1u);
+      merger.setDim(setup.id, l0, setup.dim);
+    }
+  }
+};
+
+} // anonymous namespace
+
+/// Vector addition of 2 vectors, i.e.:
+///   a(i) = b(i) + c(i)
+/// which should form the 3 lattice points
+/// {
+///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
+///   lat( i_00 / tensor_0 )
+///   lat( i_01 / tensor_1 )
+/// }
+/// and after optimization the lattice points should not change: there are
+/// no duplicate points and both input vectors are sparse, so each of the
+/// 3 points is still needed.
+TEST_F(MergerTest3T1L, VectorAdd2) {
+  // Construct expression.
+  auto e = addf(tensor(t0), tensor(t1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, l0);
+  expectNumLatPoints(s, 3);
+  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
+                            loopsToBits({{l0, t0}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
+                            loopsToBits({{l0, t1}}));
+
+  // Optimize lattices and check that the set is unchanged.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 3);
+  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
+                            loopsToBits({{l0, t0}}));
+  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
+                            loopsToBits({{l0, t1}}));
+}
+
+/// Vector multiplication of 2 vectors, i.e.:
+///   a(i) = b(i) * c(i)
+/// which should form the single lattice point
+/// {
+///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
+/// }
+/// and optimization should leave that single point unchanged.
+TEST_F(MergerTest3T1L, VectorMul2) {
+  // Construct expression. mulf takes expression ids, so create tensor leaf
+  // expressions first rather than passing the tensor numbers t0/t1 directly.
+  auto e = mulf(tensor(t0), tensor(t1));
+
+  // Build lattices and check.
+  auto s = merger.buildLattices(e, l0);
+  expectNumLatPoints(s, 1);
+  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+
+  // Optimize lattices and check.
+  s = merger.optimizeSet(s);
+  expectNumLatPoints(s, 1);
+  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
+                 loopsToBits({{l0, t0}, {l0, t1}}));
+}


        


More information about the Mlir-commits mailing list