[Mlir-commits] [mlir] [mlir][sparse] rename files and unifies APIs (PR #88162)
llvmlistbot at llvm.org
Tue Apr 9 10:35:02 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
Changes:
Full diff: https://github.com/llvm/llvm-project/pull/88162.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+1-1)
- (renamed) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+41-23)
- (renamed) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+1)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
- Utils/SparseTensorLevel.cpp
+ Utils/SparseTensorIterator.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
#include <vector>
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..3f2ee89c854d69 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
+template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
+ // It is either an array of size 2 or size 1, depending on whether the
+ // sparse level requires a position array.
+ using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+ std::array<Value, 1>>;
+
public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+ BufferT buffers)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+ ValueRange getLvlBuffers() const override { return buffers; }
Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value iv) const override {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(iv);
- return genIndexLoad(b, l, crdBuffer, memCrd);
+ return genIndexLoad(b, l, getCrdBuf(), memCrd);
}
protected:
- const Value crdBuffer;
+ template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+ Value getPosBuf() const {
+ return buffers[0];
+ }
+
+ Value getCrdBuf() const {
+ if constexpr (hasPosBuffer)
+ return buffers[1];
+ else
+ return buffers[0];
+ }
+
+ const BufferT buffers;
};
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
}
};
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
}
};
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 99%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..19c0dc942ca62f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
+ virtual ValueRange getLvlBuffers() const = 0;
//
// Level properties
``````````
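For readers skimming the diff, the change has two parts: a new pure-virtual `getLvlBuffers()` on `SparseTensorLevel`, so every level kind exposes its buffers through one accessor, and a `SparseLevel<hasPosBuffer>` template whose storage is an `std::array` sized by `std::conditional_t`, replacing the separate `posBuffer` members in the compressed level classes. Below is a minimal standalone sketch of that pattern, not the PR's code: `Buffer` stands in for `mlir::Value`, `std::vector<Buffer>` stands in for `ValueRange`, and the class names are hypothetical.

```cpp
// Standalone illustration only -- simplified stand-ins, not MLIR code.
#include <array>
#include <cstdio>
#include <type_traits>
#include <vector>

using Buffer = int; // stand-in for mlir::Value

struct LevelBase {
  virtual ~LevelBase() = default;
  // Unified API: every level kind reports whatever buffers it owns.
  virtual std::vector<Buffer> getLvlBuffers() const = 0;
};

// Dense/batch-like levels own no buffers at all.
struct DenseLikeLevel : LevelBase {
  std::vector<Buffer> getLvlBuffers() const override { return {}; }
};

template <bool hasPosBuffer>
struct SparseLikeLevel : LevelBase {
  // One slot for the coordinate buffer, plus one for positions if needed.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Buffer, 2>,
                                     std::array<Buffer, 1>>;

  explicit SparseLikeLevel(BufferT buffers) : buffers(buffers) {}

  std::vector<Buffer> getLvlBuffers() const override {
    return std::vector<Buffer>(buffers.begin(), buffers.end());
  }

  // Only callable when a position buffer exists (SFINAE on the flag).
  template <bool B = hasPosBuffer, typename = std::enable_if_t<B>>
  Buffer getPosBuf() const {
    return buffers[0];
  }

  Buffer getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1]; // positions occupy slot 0, coordinates slot 1
    else
      return buffers[0]; // coordinates are the only buffer
  }

  const BufferT buffers;
};

int main() {
  // Compressed-style level: has both a position and a coordinate buffer.
  SparseLikeLevel</*hasPosBuffer=*/true> compressed({/*pos=*/10, /*crd=*/11});
  // Singleton-style level: coordinate buffer only.
  SparseLikeLevel</*hasPosBuffer=*/false> singleton({/*crd=*/20});

  std::printf("compressed: pos=%d crd=%d (%zu buffers)\n",
              compressed.getPosBuf(), compressed.getCrdBuf(),
              compressed.getLvlBuffers().size());
  std::printf("singleton: crd=%d (%zu buffers)\n", singleton.getCrdBuf(),
              singleton.getLvlBuffers().size());
}
```

Collapsing the position and coordinate buffers into one conditionally sized array is what lets `getPosBuf()`, `getCrdBuf()`, and the unified `getLvlBuffers()` accessor share a single storage member across all sparse level kinds.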
https://github.com/llvm/llvm-project/pull/88162
More information about the Mlir-commits mailing list