[Mlir-commits] [mlir] 0495912 - [mlir][sparse] Adding `LvlVar` forward-declarations to `DimLvlMap::print`

wren romano llvmlistbot at llvm.org
Tue Aug 1 12:52:25 PDT 2023


Author: wren romano
Date: 2023-08-01T12:52:17-07:00
New Revision: 0495912388fd174406c81421a91e1356b775dd07

URL: https://github.com/llvm/llvm-project/commit/0495912388fd174406c81421a91e1356b775dd07
DIFF: https://github.com/llvm/llvm-project/commit/0495912388fd174406c81421a91e1356b775dd07.diff

LOG: [mlir][sparse] Adding `LvlVar` forward-declarations to `DimLvlMap::print`

This commit makes `DimLvlMap::print` consistent with `DimLvlMapParser` with respect to both the LvlVar forward-declarations themselves and the all-or-none constraint on LvlVar bindings in LvlSpecs.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D156768

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 3de88aa085635b..cbdca742ff7e4d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -249,7 +249,8 @@ void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const {
 
 DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
                      ArrayRef<LvlSpec> lvlSpecs)
-    : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs) {
+    : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
+      mustPrintLvlVars(false) {
   // First, check integrity of the variable-binding structure.
   // NOTE: This establishes the invariant that calls to `VarSet::add`
   // below cannot cause OOB errors.
@@ -274,8 +275,14 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
   for (const auto &dimSpec : dimSpecs)
     if (!dimSpec.canElideExpr())
       usedVars.add(dimSpec.getExpr());
-  for (auto &lvlSpec : this->lvlSpecs)
-    lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar()));
+  for (auto &lvlSpec : this->lvlSpecs) {
+    // Is this LvlVar used in any overt expression?
+    const bool isUsed = usedVars.contains(lvlSpec.getBoundVar());
+    // This LvlVar can be elided iff it isn't overtly used.
+    lvlSpec.setElideVar(!isUsed);
+    // If any LvlVar cannot be elided, then must forward-declare all LvlVars.
+    mustPrintLvlVars = mustPrintLvlVars || isUsed;
+  }
 }
 
 bool DimLvlMap::isWF() const {
@@ -314,12 +321,21 @@ void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
     os << ']';
   }
 
+  // LvlVar forward-declarations.
+  if (mustPrintLvlVars) {
+    os << '{';
+    llvm::interleaveComma(
+        lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
+    os << '}';
+  }
+
   // Dimension specifiers.
   os << '(';
   llvm::interleaveComma(
       dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); });
   os << ") -> (";
   // Level specifiers.
+  wantElision = wantElision && !mustPrintLvlVars;
   llvm::interleaveComma(
       lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); });
   os << ')';

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 9d8a4454364acd..b1d3e437621c55 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -294,6 +294,7 @@ class DimLvlMap final {
   unsigned symRank;
   SmallVector<DimSpec> dimSpecs;
   SmallVector<LvlSpec> lvlSpecs;
+  bool mustPrintLvlVars;
 
   // Checks for integrity of variable-binding structure.
   // This is already called by the ctor.


        


More information about the Mlir-commits mailing list