[Mlir-commits] [mlir] a454d92 - [mlir][sparse] rename files and unifies APIs (#88162)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 9 10:59:18 PDT 2024
Author: Peiming Liu
Date: 2024-04-09T10:59:15-07:00
New Revision: a454d92c5ac906d391b683661ac3d9a362ab0107
URL: https://github.com/llvm/llvm-project/commit/a454d92c5ac906d391b683661ac3d9a362ab0107
DIFF: https://github.com/llvm/llvm-project/commit/a454d92c5ac906d391b683661ac3d9a362ab0107.diff
LOG: [mlir][sparse] rename files and unifies APIs (#88162)
Added:
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
Removed:
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
- Utils/SparseTensorLevel.cpp
+ Utils/SparseTensorIterator.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
#include <vector>
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..60dca3c55dec3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -1,4 +1,4 @@
-//===- SparseTensorLevel.cpp - Tensor management class -------------------===//
+//===- SparseTensorIterator.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
+template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
+ // It is either an array of size 2 or size 1 depending on whether the sparse
+ // level requires a position array.
+ using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+ std::array<Value, 1>>;
+
public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+ BufferT buffers)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+ ValueRange getLvlBuffers() const override { return buffers; }
Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value iv) const override {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(iv);
- return genIndexLoad(b, l, crdBuffer, memCrd);
+ return genIndexLoad(b, l, getCrdBuf(), memCrd);
}
protected:
- const Value crdBuffer;
+ template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+ Value getPosBuf() const {
+ return buffers[0];
+ }
+
+ Value getCrdBuf() const {
+ if constexpr (hasPosBuffer)
+ return buffers[1];
+ else
+ return buffers[0];
+ }
+
+ const BufferT buffers;
};
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
}
};
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
}
};
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 97%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..9d69a233555986 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -1,4 +1,4 @@
-//===- SparseTensorLevel.h --------------------------------------*- C++ -*-===//
+//===- SparseTensorIterator.h ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -55,6 +55,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
+ virtual ValueRange getLvlBuffers() const = 0;
//
// Level properties
@@ -321,4 +322,4 @@ std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
} // namespace sparse_tensor
} // namespace mlir
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
More information about the Mlir-commits mailing list