[Mlir-commits] [mlir] [mlir][sparse] rename files and unifies APIs (PR #88162)

Peiming Liu llvmlistbot at llvm.org
Tue Apr 9 10:52:30 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88162

>From 3dddeb5dc855ca3ac8e8ff52b76dc11646ca05a2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 9 Apr 2024 17:33:17 +0000
Subject: [PATCH 1/2] [mlir][sparse] rename files and unifies APIs

---
 .../SparseTensor/Transforms/CMakeLists.txt    |  2 +-
 .../Transforms/Utils/LoopEmitter.h            |  2 +-
 ...nsorLevel.cpp => SparseTensorIterator.cpp} | 64 ++++++++++++-------
 ...seTensorLevel.h => SparseTensorIterator.h} |  1 +
 4 files changed, 44 insertions(+), 25 deletions(-)
 rename mlir/lib/Dialect/SparseTensor/Transforms/Utils/{SparseTensorLevel.cpp => SparseTensorIterator.cpp} (96%)
 rename mlir/lib/Dialect/SparseTensor/Transforms/Utils/{SparseTensorLevel.h => SparseTensorIterator.h} (99%)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   Utils/IterationGraphSorter.cpp
   Utils/LoopEmitter.cpp
   Utils/SparseTensorDescriptor.cpp
-  Utils/SparseTensorLevel.cpp
+  Utils/SparseTensorIterator.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
 
 #include <vector>
 
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
 
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..671a624a2fc39c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
 #include "CodegenUtils.h"
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 
 namespace {
 
+template <bool hasPosBuffer>
 class SparseLevel : public SparseTensorLevel {
+  // It is either a array of size 2 or size 1 depending on whether the sparse
+  // level requires a position array.
+  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+                                     std::array<Value, 1>>;
+
 public:
   SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              Value crdBuffer)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+              BufferT buffers)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+  ValueRange getLvlBuffers() const override { return buffers; }
 
   Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                   Value iv) const override {
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(iv);
-    return genIndexLoad(b, l, crdBuffer, memCrd);
+    return genIndexLoad(b, l, getCrdBuf(), memCrd);
   }
 
 protected:
-  const Value crdBuffer;
+  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+  Value getPosBuf() const {
+    return buffers[0];
+  }
+
+  Value getCrdBuf() const {
+    if constexpr (hasPosBuffer)
+      return buffers[1];
+    else
+      return buffers[0];
+  }
+
+  const BufferT buffers;
 };
 
 class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
   }
 };
 
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
   CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                   Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
 
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
   LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                        Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
 
     p = MULI(p, C_IDX(2));
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
   SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
   }
 };
 
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
   NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 99%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..19c0dc942ca62f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
+  virtual ValueRange getLvlBuffers() const = 0;
 
   //
   // Level properties

>From b87c16f183237a3f439df7f84c45d1bfc631d110 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 9 Apr 2024 17:52:16 +0000
Subject: [PATCH 2/2] address comments

---
 .../Transforms/Utils/SparseTensorIterator.cpp             | 4 ++--
 .../SparseTensor/Transforms/Utils/SparseTensorIterator.h  | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 671a624a2fc39c..60dca3c55dec3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -1,4 +1,4 @@
-//===- SparseTensorLevel.cpp - Tensor management class -------------------===//
+//===- SparseTensorIterator.cpp -------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -48,7 +48,7 @@ namespace {
 
 template <bool hasPosBuffer>
 class SparseLevel : public SparseTensorLevel {
-  // It is either a array of size 2 or size 1 depending on whether the sparse
+  // It is either an array of size 2 or size 1 depending on whether the sparse
   // level requires a position array.
   using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                      std::array<Value, 1>>;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 19c0dc942ca62f..9d69a233555986 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -1,4 +1,4 @@
-//===- SparseTensorLevel.h --------------------------------------*- C++ -*-===//
+//===- SparseTensorIterator.h ---------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -322,4 +322,4 @@ std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
 } // namespace sparse_tensor
 } // namespace mlir
 
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_



More information about the Mlir-commits mailing list