[Mlir-commits] [mlir] d82e93e - [mlir][sparse] add merger support on Batch LevelType. (#83186)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 27 13:18:47 PST 2024
Author: Peiming Liu
Date: 2024-02-27T13:18:43-08:00
New Revision: d82e93e7f129d9e8b72570efdf4a15d6ec3d4336
URL: https://github.com/llvm/llvm-project/commit/d82e93e7f129d9e8b72570efdf4a15d6ec3d4336
DIFF: https://github.com/llvm/llvm-project/commit/d82e93e7f129d9e8b72570efdf4a15d6ec3d4336.diff
LOG: [mlir][sparse] add merger support on Batch LevelType. (#83186)
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 9e79b6aca1c9ba..5563cb907e9353 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -333,16 +333,28 @@ struct LevelType {
return lvlBits & static_cast<uint64_t>(p);
}
+ /// Check if the `LevelType` is considered to be sparse.
+ constexpr bool hasSparseSemantic() const {
+ return isa<LevelFormat::Compressed, LevelFormat::Singleton,
+ LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
+ }
+
+ /// Check if the `LevelType` is considered to be dense-like.
+ constexpr bool hasDenseSemantic() const {
+ return isa<LevelFormat::Dense, LevelFormat::Batch>();
+ }
+
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT() const {
- return isa<LevelFormat::Compressed>() ||
- isa<LevelFormat::LooseCompressed>();
+ assert(!isa<LevelFormat::Undef>());
+ return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
}
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT() const {
+ assert(!isa<LevelFormat::Undef>());
// All sparse levels has coordinate array.
- return !isa<LevelFormat::Dense, LevelFormat::Batch>();
+ return hasSparseSemantic();
}
std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 490ef3071af1b7..7f9820df984b29 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -509,8 +509,7 @@ class Merger {
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b);
- return isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || isNOutOfMLT(lt);
+ return lt.hasSparseSemantic();
}
return false;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 731cd79a1e3b4b..72b722c69ae34b 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Starts resetting from a dense level, so that the first bit (if kept)
// is not undefined level-type.
for (unsigned b = 0; b < be; b++) {
- if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
+ if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
offset = be - b - 1; // relative to the end
break;
}
@@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Slice on dense level has `locate` property as well, and can be optimized.
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b);
- if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
- !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
+ if (!lt.hasSparseSemantic()) {
if (reset)
simple.reset(b);
reset = true;
@@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b);
- if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- isNOutOfMLT(lt))
+ if (lt.hasSparseSemantic())
return true;
}
return hasSparseIdxReduction(bits);
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 62a19c084cac0f..943e7d5c120b87 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN
-class MergerTestBase : public ::testing::Test {
+// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
+class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected:
MergerTestBase(unsigned numTensors, unsigned numLoops)
: merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@@ -317,10 +318,14 @@ class MergerTest3T1L : public MergerTestBase {
// Tensor 1: sparse input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: dense output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
+ ::testing::Values(LevelFormat::Dense,
+ LevelFormat::Batch));
+
/// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1L : public MergerTestBase {
protected:
@@ -333,10 +338,14 @@ class MergerTest4T1L : public MergerTestBase {
// Tensor 2: sparse input vector
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
// Tensor 3: dense output vector
- merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
+ ::testing::Values(LevelFormat::Dense,
+ LevelFormat::Batch));
+
///
/// Tests with both sparse and dense input.
///
@@ -349,12 +358,16 @@ class MergerTest3T1LD : public MergerTestBase {
// Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: dense input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: dense output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
+ ::testing::Values(LevelFormat::Dense,
+ LevelFormat::Batch));
+
///
/// Tests with both undef and dense input.
///
@@ -367,14 +380,18 @@ class MergerTest4T1LU : public MergerTestBase {
// Tensor 0: undef input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: dense input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: undef input vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
// Tensor 3: dense output vector.
- merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
+INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
+ ::testing::Values(LevelFormat::Dense,
+ LevelFormat::Batch));
+
///
/// Tests with operation on sparse output.
///
@@ -395,6 +412,11 @@ class MergerTest3T1LSo : public MergerTestBase {
}
};
+// This testsuite does not use any dense-like format, just one of {Dense, Batch}
+// is enough.
+INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
+ ::testing::Values(LevelFormat::Dense));
+
} // namespace
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
@@ -409,7 +431,7 @@ class MergerTest3T1LSo : public MergerTestBase {
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
- TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
+ TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
- TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
+ TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
/// lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
- TEST_F(MergerTest3T1L, vector_##OP) { \
+ TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
- TEST_F(MergerTest3T1L, vector_##OP) { \
+ TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
/// lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
- TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
+ TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
const auto em = CONJ##Expr(tensor(0), tensor(1)); \
const auto e = DISJ##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
/// lat( i_00 / tensor_0 )
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
- TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
+ TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
const auto e = DISJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
- TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
+ TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
- TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
+ TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
- TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
+ TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
/// lat( i_00 / tensor_0 cmp 0 )
/// lat( i_01 / 0 cmp tensor_1 )
/// }
-TEST_F(MergerTest3T1L, vector_cmp) {
+TEST_P(MergerTest3T1L, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0);
const auto t0 = tid(0);
@@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
///
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
-TEST_F(MergerTest3T1LD, vector_cmp) {
+TEST_P(MergerTest3T1LD, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0);
const auto t0 = tid(0);
More information about the Mlir-commits mailing list