[Mlir-commits] [mlir] 01dffc5 - [mlir][sparse] Favors defined dimension when optimize lattice points.
Peiming Liu
llvmlistbot at llvm.org
Wed Oct 5 18:16:39 PDT 2022
Author: Peiming Liu
Date: 2022-10-06T01:16:30Z
New Revision: 01dffc5ae8a1450fd733a578f3363822385ebc18
URL: https://github.com/llvm/llvm-project/commit/01dffc5ae8a1450fd733a578f3363822385ebc18
DIFF: https://github.com/llvm/llvm-project/commit/01dffc5ae8a1450fd733a578f3363822385ebc18.diff
LOG: [mlir][sparse] Favors defined dimension when optimize lattice points.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D135337
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 3791971501b0a..60fab7b5d0070 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -262,14 +262,24 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
break;
}
}
- // Now apply the two basic rules.
+
BitVector simple = latPoints[p0].bits;
bool reset = isSingleton && hasAnySparse(simple);
+ unsigned offset = 0;
+ if (!reset)
+ // Starts resetting from a dense dimension, so that the first bit (if kept)
+ // is not of an undefined dimension type.
+ for (unsigned b = 0, be = simple.size(); b < be; b++)
+ if (simple[b] && isDimLevelType(b, DimLvlType::kDense))
+ offset = b;
+
+ // Now apply the two basic rules.
for (unsigned b = 0, be = simple.size(); b < be; b++) {
- if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
- !isDimLevelType(b, DimLvlType::kSingleton))) {
+ unsigned i = (offset + b) % be;
+ if (simple[i] && (!isDimLevelType(i, DimLvlType::kCompressed) &&
+ !isDimLevelType(i, DimLvlType::kSingleton))) {
if (reset)
- simple.reset(b);
+ simple.reset(i);
reset = true;
}
}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index c0e75dc3f0e78..9851f456d4b55 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -377,8 +377,64 @@ class MergerTest3T1LD : public MergerTestBase {
}
};
+///
+/// Tests with both undef and dense input.
+///
+class MergerTest3T1LU : public MergerTestBase {
+protected:
+ // Our three tensors (two inputs, one output).
+ const unsigned t0 = 0, t1 = 1, t2 = 2;
+
+ // Our single loop.
+ const unsigned l0 = 0;
+
+ MergerTest3T1LU() : MergerTestBase(3, 1) {
+ // Tensor 0: undef input vector.
+ merger.addExp(Kind::kTensor, t0, -1u);
+ merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+
+ // Tensor 1: dense input vector.
+ merger.addExp(Kind::kTensor, t1, -1u);
+ merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
+
+ // Tensor 2: dense output vector.
+ merger.addExp(Kind::kTensor, t2, -1u);
+ merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+ }
+};
} // namespace
+/// Vector multiplication (conjunction) of 2 vectors, i.e.;
+/// a(i) = b(i) * c(i)
+/// which should form the single lattice point
+/// {
+/// lat( i_00_U i_01_D / (tensor_0 * tensor_1) )
+/// }
+/// after optimization, the dense dimension should be kept, even though it
+/// appears after the undef dimension
+/// {
+/// lat( i_01_D / (tensor_0 * tensor_1) )
+/// }
+#define IMPL_MERGER_TEST_CONJ(OP) \
+ TEST_F(MergerTest3T1LU, vector_##OP) { \
+ auto e = OP##Expr(t0, t1); \
+ auto p0 = tensorPattern(t0); \
+ auto p1 = tensorPattern(t1); \
+ auto s = merger.buildLattices(e, l0); \
+ \
+ expectNumLatPoints(s, 1); \
+ expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
+ loopsToBits({{l0, t0}, {l0, t1}})); \
+ \
+ s = merger.optimizeSet(s); \
+ expectNumLatPoints(s, 1); \
+ expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t1}}), \
+ true); \
+ }
+FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
+
+#undef IMPL_MERGER_TEST_CONJ
+
/// Vector addition (disjunction) of 2 vectors. i.e.;
/// a(i) = b(i) + c(i)
/// which should form the 3 lattice points
More information about the Mlir-commits
mailing list