[Mlir-commits] [mlir] d30dccd - [mlir][sparse] Favors synthetic tensor over other undefined tensors

Peiming Liu llvmlistbot at llvm.org
Thu Oct 6 13:51:46 PDT 2022


Author: Peiming Liu
Date: 2022-10-06T20:51:38Z
New Revision: d30dccd21f2c3bc1ed6cd054c131436a1af548e1

URL: https://github.com/llvm/llvm-project/commit/d30dccd21f2c3bc1ed6cd054c131436a1af548e1
DIFF: https://github.com/llvm/llvm-project/commit/d30dccd21f2c3bc1ed6cd054c131436a1af548e1.diff

LOG: [mlir][sparse] Favors synthetic tensor over other undefined tensors

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135385

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 60fab7b5d0070..187a6c0b188b2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -265,21 +265,26 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
 
   BitVector simple = latPoints[p0].bits;
   bool reset = isSingleton && hasAnySparse(simple);
-  unsigned offset = 0;
+  unsigned be = simple.size();
+  unsigned offset = 0; // relative to the end
   if (!reset)
     // Starts resetting from a dense dimension, so that the first bit (if kept)
     // is not undefined dimension type.
-    for (unsigned b = 0, be = simple.size(); b < be; b++)
-      if (simple[b] && isDimLevelType(b, DimLvlType::kDense))
-        offset = b;
+    for (unsigned b = 0; b < be; b++) {
+      if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) {
+        offset = be - b - 1; // relative to the end
+        break;
+      }
+    }
 
-  // Now apply the two basic rules.
-  for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    unsigned i = (offset + b) % be;
-    if (simple[i] && (!isDimLevelType(i, DimLvlType::kCompressed) &&
-                      !isDimLevelType(i, DimLvlType::kSingleton))) {
+  // Now apply the two basic rules. Bits are visited in reverse so that, with no
+  // dense bit, the rightmost eligible bit (maybe a synthetic tensor) is kept.
+  for (unsigned b = be - 1 - offset, i = 0; i < be;
+       b = b == 0 ? be - 1 : b - 1, i++) {
+    if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
+                      !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
-        simple.reset(i);
+        simple.reset(b);
       reset = true;
     }
   }

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 9851f456d4b55..4e95b8f6b2ebb 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -380,15 +380,16 @@ class MergerTest3T1LD : public MergerTestBase {
 ///
 /// Tests with both undef and dense input.
 ///
-class MergerTest3T1LU : public MergerTestBase {
+
+class MergerTest4T1LU : public MergerTestBase {
 protected:
   // Our three tensors (two inputs, one output).
-  const unsigned t0 = 0, t1 = 1, t2 = 2;
+  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
 
   // Our single loop.
   const unsigned l0 = 0;
 
-  MergerTest3T1LU() : MergerTestBase(3, 1) {
+  MergerTest4T1LU() : MergerTestBase(4, 1) {
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
@@ -397,43 +398,110 @@ class MergerTest3T1LU : public MergerTestBase {
     merger.addExp(Kind::kTensor, t1, -1u);
     merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
 
-    // Tensor 2: dense output vector.
+    // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kUndef));
+
+    // Tensor 3: dense output vector.
+    merger.addExp(Kind::kTensor, t3, -1u);
+    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
+  }
+};
+
+///
+/// Tests with operations on a sparse output.
+///
+
+class MergerTest3T1L_SO : public MergerTestBase {
+protected:
+  // Our three tensors (two inputs, one output) plus one synthetic tensor (t3).
+  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
+
+  // Our single loop.
+  const unsigned l0 = 0;
+
+  MergerTest3T1L_SO() : MergerTestBase(3, 1) {
+    merger.setHasSparseOut(true);
+
+    // Tensor 0: undef input vector.
+    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+
+    // Tensor 1: undef input vector.
+    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kUndef));
+
+    // Tensor 2: sparse output vector.
+    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
   }
 };
+
 } // namespace
 
-/// Vector multiplication (conjunction) of 2 vectors, i.e.;
-///   a(i) = b(i) * c(i)
+/// Vector multiplication (conjunction) of 3 vectors, i.e.;
+///   a(i) = b(i) * c(i) * d(i)
 /// which should form the single lattice point
 /// {
-///   lat( i_00_U i_01_D / (tensor_0 * tensor_1) )
+///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor_2) )
 /// }
 /// after optimization, the dense dimesion should be kept, despite it appears
-/// after the undef dimension
+/// in the middle
 /// {
-///   lat( i_01_D / (tensor_0 * tensor_1) )
+///   lat( i_01_D / (tensor_0 * tensor_1 * tensor_2) )
 /// }
-#define IMPL_MERGER_TEST_CONJ(OP)                                              \
-  TEST_F(MergerTest3T1LU, vector_##OP) {                                       \
-    auto e = OP##Expr(t0, t1);                                                 \
+#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
+  TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
+    auto em = CONJ1##Expr(t0, t1);                                             \
+    auto e = CONJ2##Expr(em, t2);                                              \
     auto p0 = tensorPattern(t0);                                               \
     auto p1 = tensorPattern(t1);                                               \
+    auto p2 = tensorPattern(t2);                                               \
     auto s = merger.buildLattices(e, l0);                                      \
-                                                                               \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
-                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
-                                                                               \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
     s = merger.optimizeSet(s);                                                 \
     expectNumLatPoints(s, 1);                                                  \
-    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t1}}),    \
-                   true);                                                      \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+                   loopsToBits({{l0, t1}}), true);                             \
   }
-FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
 
-#undef IMPL_MERGER_TEST_CONJ
+FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
+
+#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
+
+/// Vector multiplication (conjunction) of 3 vectors, i.e.;
+///   o(i) = b(i) * c(i) * o(i)
+/// which should form the single lattice point (note how a synthetic tensor
+/// i_03_U is created for the sparse output)
+/// {
+///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
+/// }
+/// after optimization, the synthetic tensor should be preserved.
+/// {
+///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
+/// }
+#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2)                    \
+  TEST_F(MergerTest3T1L_SO, vector_##CONJ1##_##CONJ2) {                        \
+    auto em = CONJ1##Expr(t0, t1);                                             \
+    auto e = CONJ2##Expr(em, t2);                                              \
+    auto p0 = tensorPattern(t0);                                               \
+    auto p1 = tensorPattern(t1);                                               \
+    auto p2 = tensorPattern(t2);                                               \
+    auto s = merger.buildLattices(e, l0);                                      \
+    expectNumLatPoints(s, 1);                                                  \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}}));               \
+    s = merger.optimizeSet(s);                                                 \
+    expectNumLatPoints(s, 1);                                                  \
+    expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
+                   loopsToBits({{l0, t3}}), true);                             \
+  }
+
+FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
+
+#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
 
 /// Vector addition (disjunction) of 2 vectors. i.e.;
 ///   a(i) = b(i) + c(i)


        


More information about the Mlir-commits mailing list