[Mlir-commits] [mlir] [mlir][sparse] rename files and unify APIs (PR #88162)
Peiming Liu
llvmlistbot at llvm.org
Tue Apr 9 10:52:30 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88162
From 3dddeb5dc855ca3ac8e8ff52b76dc11646ca05a2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 9 Apr 2024 17:33:17 +0000
Subject: [PATCH 1/2] [mlir][sparse] rename files and unify APIs
---
.../SparseTensor/Transforms/CMakeLists.txt | 2 +-
.../Transforms/Utils/LoopEmitter.h | 2 +-
...nsorLevel.cpp => SparseTensorIterator.cpp} | 64 ++++++++++++-------
...seTensorLevel.h => SparseTensorIterator.h} | 1 +
4 files changed, 44 insertions(+), 25 deletions(-)
rename mlir/lib/Dialect/SparseTensor/Transforms/Utils/{SparseTensorLevel.cpp => SparseTensorIterator.cpp} (96%)
rename mlir/lib/Dialect/SparseTensor/Transforms/Utils/{SparseTensorLevel.h => SparseTensorIterator.h} (99%)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
- Utils/SparseTensorLevel.cpp
+ Utils/SparseTensorIterator.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
#include <vector>
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..671a624a2fc39c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
+template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
+ // It is either a array of size 2 or size 1 depending on whether the sparse
+ // level requires a position array.
+ using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+ std::array<Value, 1>>;
+
public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+ BufferT buffers)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+ ValueRange getLvlBuffers() const override { return buffers; }
Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value iv) const override {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(iv);
- return genIndexLoad(b, l, crdBuffer, memCrd);
+ return genIndexLoad(b, l, getCrdBuf(), memCrd);
}
protected:
- const Value crdBuffer;
+ template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+ Value getPosBuf() const {
+ return buffers[0];
+ }
+
+ Value getCrdBuf() const {
+ if constexpr (hasPosBuffer)
+ return buffers[1];
+ else
+ return buffers[0];
+ }
+
+ const BufferT buffers;
};
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
}
};
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
}
};
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 99%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..19c0dc942ca62f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
+ virtual ValueRange getLvlBuffers() const = 0;
//
// Level properties
From b87c16f183237a3f439df7f84c45d1bfc631d110 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 9 Apr 2024 17:52:16 +0000
Subject: [PATCH 2/2] address comments
---
.../Transforms/Utils/SparseTensorIterator.cpp | 4 ++--
.../SparseTensor/Transforms/Utils/SparseTensorIterator.h | 8 ++++----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 671a624a2fc39c..60dca3c55dec3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -1,4 +1,4 @@
-//===- SparseTensorLevel.cpp - Tensor management class -------------------===//
+//===- SparseTensorIterator.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -48,7 +48,7 @@ namespace {
template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
- // It is either a array of size 2 or size 1 depending on whether the sparse
+ // It is either an array of size 2 or size 1 depending on whether the sparse
// level requires a position array.
using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
std::array<Value, 1>>;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 19c0dc942ca62f..9d69a233555986 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -1,4 +1,4 @@
-//===- SparseTensorLevel.h --------------------------------------*- C++ -*-===//
+//===- SparseTensorIterator.h ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -322,4 +322,4 @@ std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
} // namespace sparse_tensor
} // namespace mlir
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
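For reference, the compile-time buffer selection introduced in the first patch can be illustrated in isolation. The following is a minimal standalone sketch, not part of the patch: it uses a placeholder Value typedef instead of mlir::Value and a hypothetical LevelBuffers class name, but follows the same std::conditional_t / enable_if / constexpr-if structure as the updated SparseLevel.

// Minimal standalone sketch of the buffer-selection pattern (placeholder
// Value type; hypothetical LevelBuffers name, not the actual MLIR class).
#include <array>
#include <cstdio>
#include <type_traits>

using Value = const char *; // stand-in for mlir::Value

template <bool hasPosBuffer>
class LevelBuffers {
  // Either {positions, coordinates} or just {coordinates}, fixed at
  // compile time.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  explicit LevelBuffers(BufferT buffers) : buffers(buffers) {}

  // Only callable when a position buffer exists; otherwise substitution
  // fails at the call site.
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  // The coordinate buffer sits after the position buffer when present,
  // otherwise it is the sole entry.
  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

private:
  const BufferT buffers;
};

int main() {
  LevelBuffers</*hasPosBuffer=*/true> compressed({"pos", "crd"});
  LevelBuffers</*hasPosBuffer=*/false> singleton({"crd"});
  std::printf("%s %s %s\n", compressed.getPosBuf(), compressed.getCrdBuf(),
              singleton.getCrdBuf());
  // singleton.getPosBuf() would fail to compile: no position buffer.
  return 0;
}

The enable_if default argument makes getPosBuf() ill-formed only when it is actually called on a coordinate-only instantiation, while getCrdBuf() works for both layouts, which is why the derived CompressedLevel/LooseCompressedLevel and SingletonLevel/NOutOfMLevel classes in the patch no longer need to store their own posBuffer members.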