[Mlir-commits] [mlir] 5f5e019 - [mlir][sparse] add some APIs for merger to query the tensor id for output tensor and synthetic tensor.
Peiming Liu
llvmlistbot at llvm.org
Mon Oct 24 11:49:58 PDT 2022
Author: Peiming Liu
Date: 2022-10-24T18:49:53Z
New Revision: 5f5e0199c1a650c41e1ea6e0c18d3c0b29f45023
URL: https://github.com/llvm/llvm-project/commit/5f5e0199c1a650c41e1ea6e0c18d3c0b29f45023
DIFF: https://github.com/llvm/llvm-project/commit/5f5e0199c1a650c41e1ea6e0c18d3c0b29f45023.diff
LOG: [mlir][sparse] add some APIs for merger to query the tensor id for output tensor and synthetic tensor.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136630
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 52456a2b06d5b..7d1b770ca55f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -221,8 +221,9 @@ class Merger {
/// Returns true if Li and Lj only differ in dense.
bool onlyDenseDiff(unsigned i, unsigned j);
- /// Bit translation.
+ /// Bit translation (get tensor ID).
unsigned tensor(unsigned b) const { return b % numTensors; }
+ /// Bit translation (get loop index).
unsigned index(unsigned b) const { return b / numTensors; }
/// Returns true if bit corresponds to index of output tensor.
@@ -230,6 +231,12 @@ class Merger {
return tensor(b) == outTensor && index(b) == i;
}
+ /// Gets tensor ID for the output tensor.
+ unsigned getOutTensorID() const { return outTensor; }
+ /// Gets tensor ID for the synthetic tensor (used for all invariant tensor
+ /// expressions).
+ unsigned getSynTensorID() const { return syntheticTensor; }
+
/// Returns true if given tensor iterates *only* in the given tensor
/// expression. For the output tensor, this defines a "simply dynamic"
/// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 0e0882fa9fc18..0ce2740485956 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -309,6 +309,8 @@ class MergerTest3T1L : public MergerTestBase {
const unsigned l0 = 0;
MergerTest3T1L() : MergerTestBase(3, 1) {
+ EXPECT_TRUE(merger.getOutTensorID() == t2);
+
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -332,6 +334,8 @@ class MergerTest4T1L : public MergerTestBase {
const unsigned l0 = 0;
MergerTest4T1L() : MergerTestBase(4, 1) {
+ EXPECT_TRUE(merger.getOutTensorID() == t3);
+
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -363,6 +367,8 @@ class MergerTest3T1LD : public MergerTestBase {
const unsigned l0 = 0;
MergerTest3T1LD() : MergerTestBase(3, 1) {
+ EXPECT_TRUE(merger.getOutTensorID() == t2);
+
// Tensor 0: sparse input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -383,13 +389,15 @@ class MergerTest3T1LD : public MergerTestBase {
class MergerTest4T1LU : public MergerTestBase {
protected:
- // Our three tensors (two inputs, one output).
+ // Our four tensors (three inputs, one output).
const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
// Our single loop.
const unsigned l0 = 0;
MergerTest4T1LU() : MergerTestBase(4, 1) {
+ EXPECT_TRUE(merger.getOutTensorID() == t3);
+
// Tensor 0: undef input vector.
merger.addExp(Kind::kTensor, t0, -1u);
merger.setDimLevelType(t0, l0, DimLevelType::Undef);
@@ -421,6 +429,9 @@ class MergerTest3T1L_SO : public MergerTestBase {
const unsigned l0 = 0;
MergerTest3T1L_SO() : MergerTestBase(3, 1) {
+ EXPECT_TRUE(merger.getOutTensorID() == t2);
+ EXPECT_TRUE(merger.getSynTensorID() == t3);
+
merger.setHasSparseOut(true);
// Tensor 0: undef input vector.
More information about the Mlir-commits
mailing list