[Mlir-commits] [mlir] [mlir][sparse] fixed naming consistency (PR #73053)

Aart Bik llvmlistbot at llvm.org
Tue Nov 21 15:41:58 PST 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/73053

 All DLT-related helper methods now end with the `DLT` suffix (e.g. `isDLTWithPos` becomes `isWithPosDLT`); also removed a stale TODO.

>From 362438f39f5f5fd892b03054853ac623ae5e2405 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 21 Nov 2023 15:38:30 -0800
Subject: [PATCH 1/2] [mlir][sparse] fixed naming consistency

All DLT-related methods now have the DLT suffix at the end of their names.
---
 mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h             | 4 ++--
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h  | 4 ++--
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp      | 4 ++--
 .../Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp   | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 9c277a0b23633d8..c00fb93f389acce 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -282,12 +282,12 @@ constexpr bool is2OutOf4DLT(DimLevelType dlt) {
 }
 
 /// Check if the `DimLevelType` needs positions array.
-constexpr bool isDLTWithPos(DimLevelType dlt) {
+constexpr bool isWithPosDLT(DimLevelType dlt) {
   return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt);
 }
 
 /// Check if the `DimLevelType` needs coordinates array.
-constexpr bool isDLTWithCrd(DimLevelType dlt) {
+constexpr bool isWithCrdDLT(DimLevelType dlt) {
   return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
          isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
 }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 6c427ae1b7aee69..220975edb61359c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -302,8 +302,8 @@ class SparseTensorType {
   bool is2OutOf4Lvl(Level l) const { return is2OutOf4DLT(getLvlType(l)); }
   bool isOrderedLvl(Level l) const { return isOrderedDLT(getLvlType(l)); }
   bool isUniqueLvl(Level l) const { return isUniqueDLT(getLvlType(l)); }
-  bool isWithPos(Level l) const { return isDLTWithPos(getLvlType(l)); }
-  bool isWithCrd(Level l) const { return isDLTWithCrd(getLvlType(l)); }
+  bool isWithPos(Level l) const { return isWithPosDLT(getLvlType(l)); }
+  bool isWithCrd(Level l) const { return isWithCrdDLT(getLvlType(l)); }
 
   /// Returns the coordinate-overhead bitwidth, defaulting to zero.
   unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f49d0893246943a..fb2e70482a1978b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -72,11 +72,11 @@ void StorageLayout::foreachField(
   // Per-level storage.
   for (Level l = 0; l < end; l++) {
     const auto dlt = lvlTypes[l];
-    if (isDLTWithPos(dlt)) {
+    if (isWithPosDLT(dlt)) {
       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
         return;
     }
-    if (isDLTWithCrd(dlt)) {
+    if (isWithCrdDLT(dlt)) {
       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
         return;
     }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1c0366c6476a385..33f1ebecfdcf21b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1341,7 +1341,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
         continue;
       }
 
-      if (isDLTWithPos(dlt)) {
+      if (isWithPosDLT(dlt)) {
         assert(isCompressedDLT(dlt) || isLooseCompressedDLT(dlt));
         if (isLooseCompressedDLT(dlt)) {
           memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
@@ -1356,7 +1356,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
         memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
         posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
       }
-      assert(isDLTWithCrd(dlt) && lvl <= trailCOOStart);
+      assert(isWithCrdDLT(dlt) && lvl <= trailCOOStart);
       // FIXME: This seems to be unnecessarily complex, can we simplify it?
       if (lvl == trailCOOStart) {
         Value cooSz = rewriter.create<arith::MulIOp>(

>From 9234eb2e37649c5ebbdb2c10dd020cde9c8c6766 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 21 Nov 2023 15:39:44 -0800
Subject: [PATCH 2/2] removed stale TODO

---
 mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index c00fb93f389acce..697bb0733953d64 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -311,10 +311,8 @@ constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
 }
 
 /// Convert a LevelFormat to its corresponding DimLevelType with the given
-/// properties. Returns std::nullopt when the properties are not applicable for
-/// the input level format.
-/// TODO: factor out a new LevelProperties type so we can add new properties
-/// without changing this function's signature
+/// properties. Returns std::nullopt when the properties are not applicable
+/// for the input level format.
 constexpr std::optional<DimLevelType>
 buildLevelType(LevelFormat lf, bool ordered, bool unique) {
   auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |



More information about the Mlir-commits mailing list