[Mlir-commits] [mlir] [mlir][sparse] rename files and unifies APIs (PR #88162)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 9 10:35:02 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/88162.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+1-1) 
- (renamed) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+41-23) 
- (renamed) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+1) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   Utils/IterationGraphSorter.cpp
   Utils/LoopEmitter.cpp
   Utils/SparseTensorDescriptor.cpp
-  Utils/SparseTensorLevel.cpp
+  Utils/SparseTensorIterator.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
 
 #include <vector>
 
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
 
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..3f2ee89c854d69 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
 #include "CodegenUtils.h"
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 
 namespace {
 
+template <bool hasPosBuffer>
 class SparseLevel : public SparseTensorLevel {
+  // It is either an array of size 2 or size 1 depending on whether the sparse
+  // level requires a position array.
+  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+                                     std::array<Value, 1>>;
+
 public:
   SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              Value crdBuffer)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+              BufferT buffers)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+  ValueRange getLvlBuffers() const override { return buffers; }
 
   Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                   Value iv) const override {
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(iv);
-    return genIndexLoad(b, l, crdBuffer, memCrd);
+    return genIndexLoad(b, l, getCrdBuf(), memCrd);
   }
 
 protected:
-  const Value crdBuffer;
+  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+  Value getPosBuf() const {
+    return buffers[0];
+  }
+
+  Value getCrdBuf() const {
+    if constexpr (hasPosBuffer)
+      return buffers[1];
+    else
+      return buffers[0];
+  }
+
+  const BufferT buffers;
 };
 
 class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
   }
 };
 
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
   CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                   Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
 
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
   LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                        Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
 
     p = MULI(p, C_IDX(2));
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
   SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
   }
 };
 
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
   NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 99%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..19c0dc942ca62f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
+  virtual ValueRange getLvlBuffers() const = 0;
 
   //
   // Level properties

``````````

</details>


https://github.com/llvm/llvm-project/pull/88162


More information about the Mlir-commits mailing list