[Mlir-commits] [mlir] fdbe931 - [mlir][sparse] Adding getters/setters to `DimLvlMap`

wren romano llvmlistbot at llvm.org
Tue Aug 1 12:55:52 PDT 2023


Author: wren romano
Date: 2023-08-01T12:55:45-07:00
New Revision: fdbe9312b1c626ea61a2456db94cd52109ff1a50

URL: https://github.com/llvm/llvm-project/commit/fdbe9312b1c626ea61a2456db94cd52109ff1a50
DIFF: https://github.com/llvm/llvm-project/commit/fdbe9312b1c626ea61a2456db94cd52109ff1a50.diff

LOG: [mlir][sparse] Adding getters/setters to `DimLvlMap`

Depends On D156768

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D156770

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index cbdca742ff7e4d..6efcd0215ec6da 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -262,10 +262,8 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
   // needs to happen before the code for setting every `LvlSpec::elideVar`,
   // since if the LvlVar is only used in elided DimExpr, then the
   // LvlVar should also be elided.
-  // NOTE: Whenever we set a new DimExpr, we must make sure to validate it
-  // against our ranks, to restore the invariant established by `isWF` above.
-  // TODO(wrengr): We might should adjust the `DimLvlExpr` ctor to take a
-  // `Ranks` argument and perform the validation then.
+  // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
+  // to ensure that we maintain the invariant established by `isWF` above.
 
   // Third, we set every `LvlSpec::elideVar` according to whether that
   // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
@@ -300,6 +298,22 @@ bool DimLvlMap::isWF() const {
   return true;
 }
 
+AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
+  SmallVector<AffineExpr> lvlAffines;
+  lvlAffines.reserve(getLvlRank());
+  for (const auto &lvlSpec : lvlSpecs)
+    lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
+  return AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
+}
+
+AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
+  SmallVector<AffineExpr> dimAffines;
+  dimAffines.reserve(getDimRank());
+  for (const auto &dimSpec : dimSpecs)
+    dimAffines.push_back(dimSpec.getExpr().getAffineExpr());
+  return AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
+}
+
 void DimLvlMap::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index b1d3e437621c55..c39cd9a3e96f75 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -290,16 +290,6 @@ static_assert(IsZeroCostAbstraction<LvlSpec>);
 
 //===----------------------------------------------------------------------===//
 class DimLvlMap final {
-  // TODO(wrengr): Need to define getters
-  unsigned symRank;
-  SmallVector<DimSpec> dimSpecs;
-  SmallVector<LvlSpec> lvlSpecs;
-  bool mustPrintLvlVars;
-
-  // Checks for integrity of variable-binding structure.
-  // This is already called by the ctor.
-  [[nodiscard]] bool isWF() const;
-
 public:
   DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
             ArrayRef<LvlSpec> lvlSpecs);
@@ -310,11 +300,41 @@ class DimLvlMap final {
   unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
   Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
 
-  DimLevelType getDimLevelType(unsigned i) { return lvlSpecs[i].getType(); }
+  ArrayRef<DimSpec> getDims() const { return dimSpecs; }
+  const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
+  SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
+    return getDim(dim).getSlice();
+  }
+
+  ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
+  const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
+  DimLevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
+
+  AffineMap getDimToLvlMap(MLIRContext *context) const;
+  AffineMap getLvlToDimMap(MLIRContext *context) const;
 
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
+
+private:
+  /// Checks for integrity of variable-binding structure.
+  /// This is already called by the ctor.
+  [[nodiscard]] bool isWF() const;
+
+  /// Helper function to call `DimSpec::setExpr` while asserting that
+  /// the invariant established by `DimLvlMap::isWF` is maintained.
+  /// This is used by the ctor.
+  void setDimExpr(Dimension dim, DimExpr expr) {
+    assert(expr && getRanks().isValid(expr));
+    dimSpecs[dim].setExpr(expr);
+  }
+
+  // All these fields are const-after-ctor.
+  unsigned symRank;
+  SmallVector<DimSpec> dimSpecs;
+  SmallVector<LvlSpec> lvlSpecs;
+  bool mustPrintLvlVars;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 63f55bd43e8db3..15dae63649ede3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -115,7 +115,7 @@ bool VarSet::occursIn(DimLvlExpr expr) const {
 }
 
 void VarSet::add(Var var) {
-  // NOTE: `SmallBitVactor::operator[]` will raise assertion errors for OOB.
+  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
   impl[var.getKind()][var.getNum()] = true;
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7dd15b96b30656..db31ae0f0433d2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -530,8 +530,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       RETURN_ON_FAIL(res);
       // Proof of concept result.
       // TODO: use DimLvlMap directly as storage representation
-      for (unsigned i = 0, e = res->getLvlRank(); i < e; i++)
-        lvlTypes.push_back(res->getDimLevelType(i));
+      for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
+        lvlTypes.push_back(res->getLvlType(lvl));
     }
 
     // Only the last item can omit the comma


        


More information about the Mlir-commits mailing list