[Mlir-commits] [mlir] [mlir][sparse] add merger support on Batch LevelType. (PR #83186)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 27 13:09:08 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/83186.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+15-3) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+1-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+3-5) 
- (modified) mlir/unittests/Dialect/SparseTensor/MergerTest.cpp (+40-18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 9e79b6aca1c9ba..5563cb907e9353 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -333,16 +333,28 @@ struct LevelType {
     return lvlBits & static_cast<uint64_t>(p);
   }
 
+  /// Check if the `LevelType` is considered to be sparse.
+  constexpr bool hasSparseSemantic() const {
+    return isa<LevelFormat::Compressed, LevelFormat::Singleton,
+               LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
+  }
+
+  /// Check if the `LevelType` is considered to be dense-like.
+  constexpr bool hasDenseSemantic() const {
+    return isa<LevelFormat::Dense, LevelFormat::Batch>();
+  }
+
   /// Check if the `LevelType` needs positions array.
   constexpr bool isWithPosLT() const {
-    return isa<LevelFormat::Compressed>() ||
-           isa<LevelFormat::LooseCompressed>();
+    assert(!isa<LevelFormat::Undef>());
+    return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
   }
 
   /// Check if the `LevelType` needs coordinates array.
   constexpr bool isWithCrdLT() const {
+    assert(!isa<LevelFormat::Undef>());
     // All sparse levels has coordinate array.
-    return !isa<LevelFormat::Dense, LevelFormat::Batch>();
+    return hasSparseSemantic();
   }
 
   std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 490ef3071af1b7..7f9820df984b29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -509,8 +509,7 @@ class Merger {
   bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto lt = getLoopDependentLevelType(b);
-      return isCompressedLT(lt) || isSingletonLT(lt) ||
-             isLooseCompressedLT(lt) || isNOutOfMLT(lt);
+      return lt.hasSparseSemantic();
     }
     return false;
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 731cd79a1e3b4b..72b722c69ae34b 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     // Starts resetting from a dense level, so that the first bit (if kept)
     // is not undefined level-type.
     for (unsigned b = 0; b < be; b++) {
-      if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
+      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
         offset = be - b - 1; // relative to the end
         break;
       }
@@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     // Slice on dense level has `locate` property as well, and can be optimized.
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto lt = getLvlType(b);
-      if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
-          !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
+      if (!lt.hasSparseSemantic()) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto lt = getLvlType(b);
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        isNOutOfMLT(lt))
+    if (lt.hasSparseSemantic())
       return true;
   }
   return hasSparseIdxReduction(bits);
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 62a19c084cac0f..943e7d5c120b87 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
 #undef IMPL_BINOP_PATTERN
 
-class MergerTestBase : public ::testing::Test {
+// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
+class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
 protected:
   MergerTestBase(unsigned numTensors, unsigned numLoops)
       : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@@ -317,10 +318,14 @@ class MergerTest3T1L : public MergerTestBase {
     // Tensor 1: sparse input vector.
     merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 /// Four tensors (three inputs, one output); and a single loop.
 class MergerTest4T1L : public MergerTestBase {
 protected:
@@ -333,10 +338,14 @@ class MergerTest4T1L : public MergerTestBase {
     // Tensor 2: sparse input vector
     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
     // Tensor 3: dense output vector
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with both sparse and dense input.
 ///
@@ -349,12 +358,16 @@ class MergerTest3T1LD : public MergerTestBase {
     // Tensor 0: sparse input vector.
     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with both undef and dense input.
 ///
@@ -367,14 +380,18 @@ class MergerTest4T1LU : public MergerTestBase {
     // Tensor 0: undef input vector.
     merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
     // Tensor 2: undef input vector.
     merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
     // Tensor 3: dense output vector.
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
   }
 };
 
+INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
+                         ::testing::Values(LevelFormat::Dense,
+                                           LevelFormat::Batch));
+
 ///
 /// Tests with operation on sparse output.
 ///
@@ -395,6 +412,11 @@ class MergerTest3T1LSo : public MergerTestBase {
   }
 };
 
+// This testsuite does not use any dense-like format, just one of {Dense, Batch}
+// is enough.
+INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
+                         ::testing::Values(LevelFormat::Dense));
+
 } // namespace
 
 /// Vector multiplication (conjunction) of 3 vectors, i.e.;
@@ -409,7 +431,7 @@ class MergerTest3T1LSo : public MergerTestBase {
 ///   lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
-  TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
+  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = CONJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
 ///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2)                    \
-  TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
+  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = CONJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
 ///   lat( i_01 / tensor_1 )
 /// }
 #define IMPL_MERGER_TEST_DISJ(OP, UNUSED)                                      \
-  TEST_F(MergerTest3T1L, vector_##OP) {                                        \
+  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
 /// }
 #define IMPL_MERGER_TEST_CONJ(OP, UNUSED)                                      \
-  TEST_F(MergerTest3T1L, vector_##OP) {                                        \
+  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
 ///    lat( i_02 / tensor_2 )
 /// }
 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
-  TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
+  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
     const auto em = CONJ##Expr(tensor(0), tensor(1));                          \
     const auto e = DISJ##Expr(em, tensor(2));                                  \
     const auto l0 = lid(0);                                                    \
@@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
 ///   lat( i_00 / tensor_0 )
 /// }
 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
-  TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
+  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
     const auto em = DISJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = DISJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
 ///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
 /// }
 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
-  TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
+  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
     const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
     const auto e = CONJ2##Expr(em, tensor(2));                                 \
     const auto l0 = lid(0);                                                    \
@@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED)                            \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
+  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
 /// }
 /// since i_01 is a dense dimension.
 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED)                            \
-  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
+  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
     const auto e = OP##Expr(tensor(0), tensor(1));                             \
     const auto l0 = lid(0);                                                    \
     const auto t0 = tid(0);                                                    \
@@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
 ///   lat( i_00 / tensor_0 cmp 0 )
 ///   lat( i_01 / 0 cmp tensor_1 )
 /// }
-TEST_F(MergerTest3T1L, vector_cmp) {
+TEST_P(MergerTest3T1L, vector_cmp) {
   const auto e = cmpiExpr(tensor(0), tensor(1));
   const auto l0 = lid(0);
   const auto t0 = tid(0);
@@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
 ///
 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
 /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
-TEST_F(MergerTest3T1LD, vector_cmp) {
+TEST_P(MergerTest3T1LD, vector_cmp) {
   const auto e = cmpiExpr(tensor(0), tensor(1));
   const auto l0 = lid(0);
   const auto t0 = tid(0);

``````````

</details>


https://github.com/llvm/llvm-project/pull/83186


More information about the Mlir-commits mailing list