[Mlir-commits] [mlir] 5f5e019 - [mlir][sparse] add some APIs for merger to query the tensor id for output tensor and synthetic tensor.

Peiming Liu llvmlistbot at llvm.org
Mon Oct 24 11:49:58 PDT 2022


Author: Peiming Liu
Date: 2022-10-24T18:49:53Z
New Revision: 5f5e0199c1a650c41e1ea6e0c18d3c0b29f45023

URL: https://github.com/llvm/llvm-project/commit/5f5e0199c1a650c41e1ea6e0c18d3c0b29f45023
DIFF: https://github.com/llvm/llvm-project/commit/5f5e0199c1a650c41e1ea6e0c18d3c0b29f45023.diff

LOG: [mlir][sparse] add some APIs for merger to query the tensor id for output tensor and synthetic tensor.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136630

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 52456a2b06d5b..7d1b770ca55f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -221,8 +221,9 @@ class Merger {
   /// Returns true if Li and Lj only differ in dense.
   bool onlyDenseDiff(unsigned i, unsigned j);
 
-  /// Bit translation.
+  /// Bit translation (get tensor ID).
   unsigned tensor(unsigned b) const { return b % numTensors; }
+  /// Bit translation (get loop index).
   unsigned index(unsigned b) const { return b / numTensors; }
 
   /// Returns true if bit corresponds to index of output tensor.
@@ -230,6 +231,12 @@ class Merger {
     return tensor(b) == outTensor && index(b) == i;
   }
 
+  /// Gets tensor ID for the output tensor.
+  unsigned getOutTensorID() const { return outTensor; }
+  /// Gets tensor ID for the synthetic tensor (used for all invariant tensor
+  /// expressions).
+  unsigned getSynTensorID() const { return syntheticTensor; }
+
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 0e0882fa9fc18..0ce2740485956 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -309,6 +309,8 @@ class MergerTest3T1L : public MergerTestBase {
   const unsigned l0 = 0;
 
   MergerTest3T1L() : MergerTestBase(3, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t2);
+
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -332,6 +334,8 @@ class MergerTest4T1L : public MergerTestBase {
   const unsigned l0 = 0;
 
   MergerTest4T1L() : MergerTestBase(4, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t3);
+
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -363,6 +367,8 @@ class MergerTest3T1LD : public MergerTestBase {
   const unsigned l0 = 0;
 
   MergerTest3T1LD() : MergerTestBase(3, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t2);
+
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -383,13 +389,15 @@ class MergerTest3T1LD : public MergerTestBase {
 
 class MergerTest4T1LU : public MergerTestBase {
 protected:
-  // Our three tensors (two inputs, one output).
+  // Our three tensors (three inputs, one output).
   const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
 
   // Our single loop.
   const unsigned l0 = 0;
 
   MergerTest4T1LU() : MergerTestBase(4, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t3);
+
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Undef);
@@ -421,6 +429,9 @@ class MergerTest3T1L_SO : public MergerTestBase {
   const unsigned l0 = 0;
 
   MergerTest3T1L_SO() : MergerTestBase(3, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t2);
+    EXPECT_TRUE(merger.getSynTensorID() == t3);
+
     merger.setHasSparseOut(true);
 
     // Tensor 0: undef input vector.


        


More information about the Mlir-commits mailing list