[Mlir-commits] [mlir] [mlir][sparse] Clean up parser (PR #72571)

Yinying Li llvmlistbot at llvm.org
Thu Nov 16 13:30:12 PST 2023


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/72571

>From 626c729c3f4102f2cd1e8052e4dc6def627783f9 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 16 Nov 2023 21:09:18 +0000
Subject: [PATCH 1/2] [mlir][sparse] Clean up parser

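Remove unused functions from the sparse tensor encoding parser: the
dump/str/print helpers of DimLvlExpr, DimSpec, LvlSpec and DimLvlMap,
the unused DimLvlExpr cast/constant/operand accessors, VarSet::occursIn,
flipVarKind, and the tryGetVar accessors of VarInfo and VarEnv.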
---
 .../SparseTensor/IR/Detail/DimLvlMap.cpp      | 236 ------------------
 .../SparseTensor/IR/Detail/DimLvlMap.h        |  36 ---
 .../Dialect/SparseTensor/IR/Detail/Var.cpp    |  30 ---
 mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h |  21 +-
 4 files changed, 1 insertion(+), 322 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 9757a599bd1eb60..022180e333530af 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -16,21 +16,6 @@ using namespace mlir::sparse_tensor::ir_detail;
 // `DimLvlExpr` implementation.
 //===----------------------------------------------------------------------===//
 
-Var DimLvlExpr::castAnyVar() const {
-  assert(expr && "uninitialized DimLvlExpr");
-  const auto var = dyn_castAnyVar();
-  assert(var && "expected DimLvlExpr to be a Var");
-  return *var;
-}
-
-std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
-  if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
-    return SymVar(s);
-  if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
-    return Var(getAllowedVarKind(), x);
-  return std::nullopt;
-}
-
 SymVar DimLvlExpr::castSymVar() const {
   return SymVar(llvm::cast<AffineSymbolExpr>(expr));
 }
@@ -51,30 +36,6 @@ std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
   return std::nullopt;
 }
 
-int64_t DimLvlExpr::castConstantValue() const {
-  return llvm::cast<AffineConstantExpr>(expr).getValue();
-}
-
-std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
-  const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
-  return k ? std::make_optional(k.getValue()) : std::nullopt;
-}
-
-bool DimLvlExpr::hasConstantValue(int64_t val) const {
-  const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
-  return k && k.getValue() == val;
-}
-
-DimLvlExpr DimLvlExpr::getLHS() const {
-  const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
-  return DimLvlExpr(kind, binop ? binop.getLHS() : nullptr);
-}
-
-DimLvlExpr DimLvlExpr::getRHS() const {
-  const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
-  return DimLvlExpr(kind, binop ? binop.getRHS() : nullptr);
-}
-
 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
 DimLvlExpr::unpackBinop() const {
   const auto ak = getAffineKind();
@@ -84,114 +45,6 @@ DimLvlExpr::unpackBinop() const {
   return {lhs, ak, rhs};
 }
 
-void DimLvlExpr::dump() const {
-  print(llvm::errs());
-  llvm::errs() << "\n";
-}
-std::string DimLvlExpr::str() const {
-  std::string str;
-  llvm::raw_string_ostream os(str);
-  print(os);
-  return os.str();
-}
-void DimLvlExpr::print(AsmPrinter &printer) const {
-  print(printer.getStream());
-}
-void DimLvlExpr::print(llvm::raw_ostream &os) const {
-  if (!expr)
-    os << "<<NULL AFFINE EXPR>>";
-  else
-    printWeak(os);
-}
-
-namespace {
-struct MatchNeg final : public std::pair<DimLvlExpr, int64_t> {
-  using Base = std::pair<DimLvlExpr, int64_t>;
-  using Base::Base;
-  constexpr DimLvlExpr getLHS() const { return first; }
-  constexpr int64_t getRHS() const { return second; }
-};
-} // namespace
-
-static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
-  const auto [lhs, op, rhs] = expr.unpackBinop();
-  if (op == AffineExprKind::Constant) {
-    const auto val = expr.castConstantValue();
-    if (val < 0)
-      return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
-  }
-  if (op == AffineExprKind::Mul)
-    if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0)
-      return MatchNeg{lhs, *rval};
-  return std::nullopt;
-}
-
-// A heavily revised version of `AsmPrinter::Impl::printAffineExprInternal`.
-void DimLvlExpr::printAffineExprInternal(
-    llvm::raw_ostream &os, BindingStrength enclosingTightness) const {
-  const char *binopSpelling = nullptr;
-  switch (getAffineKind()) {
-  case AffineExprKind::SymbolId:
-    os << castSymVar();
-    return;
-  case AffineExprKind::DimId:
-    os << castDimLvlVar();
-    return;
-  case AffineExprKind::Constant:
-    os << castConstantValue();
-    return;
-  case AffineExprKind::Add:
-    binopSpelling = " + "; // N.B., this is unused
-    break;
-  case AffineExprKind::Mul:
-    binopSpelling = " * ";
-    break;
-  case AffineExprKind::FloorDiv:
-    binopSpelling = " floordiv ";
-    break;
-  case AffineExprKind::CeilDiv:
-    binopSpelling = " ceildiv ";
-    break;
-  case AffineExprKind::Mod:
-    binopSpelling = " mod ";
-    break;
-  }
-
-  if (enclosingTightness == BindingStrength::Strong)
-    os << '(';
-
-  const auto [lhs, op, rhs] = unpackBinop();
-  if (op == AffineExprKind::Mul && rhs.hasConstantValue(-1)) {
-    // Pretty print `(lhs * -1)` as "-lhs".
-    os << '-';
-    lhs.printStrong(os);
-  } else if (op != AffineExprKind::Add) {
-    // Default rule for tightly binding binary operators.
-    // (Including `Mul` that didn't match the previous rule.)
-    lhs.printStrong(os);
-    os << binopSpelling;
-    rhs.printStrong(os);
-  } else {
-    // Combination of all the special rules for addition/subtraction.
-    lhs.printWeak(os);
-    const auto rx = matchNeg(rhs);
-    os << (rx ? " - " : " + ");
-    const auto &rlhs = rx ? rx->getLHS() : rhs;
-    const auto rrhs = rx ? rx->getRHS() : -1; // value irrelevant when `!rx`
-    const bool nonunit = rrhs != -1;          // value irrelevant when `!rx`
-    const bool isStrong =
-        rx && rlhs && (nonunit || rlhs.getAffineKind() == AffineExprKind::Add);
-    if (rlhs)
-      rlhs.printAffineExprInternal(os, BindingStrength{isStrong});
-    if (rx && rlhs && nonunit)
-      os << " * ";
-    if (rx && (!rlhs || nonunit))
-      os << -rrhs;
-  }
-
-  if (enclosingTightness == BindingStrength::Strong)
-    os << ')';
-}
 
 //===----------------------------------------------------------------------===//
 // `DimSpec` implementation.
@@ -206,30 +59,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && (!expr || ranks.isValid(expr));
 }
 
-void DimSpec::dump() const {
-  print(llvm::errs(), /*wantElision=*/false);
-  llvm::errs() << "\n";
-}
-std::string DimSpec::str(bool wantElision) const {
-  std::string str;
-  llvm::raw_string_ostream os(str);
-  print(os, wantElision);
-  return os.str();
-}
-void DimSpec::print(AsmPrinter &printer, bool wantElision) const {
-  print(printer.getStream(), wantElision);
-}
-void DimSpec::print(llvm::raw_ostream &os, bool wantElision) const {
-  os << var;
-  if (expr && (!wantElision || !elideExpr))
-    os << " = " << expr;
-  if (slice) {
-    os << " : ";
-    // Call `SparseTensorDimSliceAttr::print` directly, to avoid
-    // printing the mnemonic.
-    slice.print(os);
-  }
-}
 
 //===----------------------------------------------------------------------===//
 // `LvlSpec` implementation.
@@ -246,25 +75,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && ranks.isValid(expr);
 }
 
-void LvlSpec::dump() const {
-  print(llvm::errs(), /*wantElision=*/false);
-  llvm::errs() << "\n";
-}
-std::string LvlSpec::str(bool wantElision) const {
-  std::string str;
-  llvm::raw_string_ostream os(str);
-  print(os, wantElision);
-  return os.str();
-}
-void LvlSpec::print(AsmPrinter &printer, bool wantElision) const {
-  print(printer.getStream(), wantElision);
-}
-void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const {
-  if (!wantElision || !elideVar)
-    os << var << " = ";
-  os << expr;
-  os << ": " << toMLIRString(type);
-}
 
 //===----------------------------------------------------------------------===//
 // `DimLvlMap` implementation.
@@ -334,51 +144,5 @@ AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
   return map;
 }
 
-void DimLvlMap::dump() const {
-  print(llvm::errs(), /*wantElision=*/false);
-  llvm::errs() << "\n";
-}
-std::string DimLvlMap::str(bool wantElision) const {
-  std::string str;
-  llvm::raw_string_ostream os(str);
-  print(os, wantElision);
-  return os.str();
-}
-void DimLvlMap::print(AsmPrinter &printer, bool wantElision) const {
-  print(printer.getStream(), wantElision);
-}
-void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
-  // Symbolic identifiers.
-  // NOTE: Unlike `AffineMap` we place the SymVar bindings before the DimVar
-  // bindings, since the SymVars may occur within DimExprs and thus this
-  // ordering helps reduce potential user confusion about the scope of bidings
-  // (since it means SymVars and DimVars both bind-forward in the usual way,
-  // whereas only LvlVars have different binding rules).
-  if (symRank != 0) {
-    os << "[s0";
-    for (unsigned i = 1; i < symRank; ++i)
-      os << ", s" << i;
-    os << ']';
-  }
-
-  // LvlVar forward-declarations.
-  if (mustPrintLvlVars) {
-    os << '{';
-    llvm::interleaveComma(
-        lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
-    os << "} ";
-  }
-
-  // Dimension specifiers.
-  os << '(';
-  llvm::interleaveComma(
-      dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); });
-  os << ") -> (";
-  // Level specifiers.
-  wantElision = wantElision && !mustPrintLvlVars;
-  llvm::interleaveComma(
-      lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); });
-  os << ')';
-}
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index b3200d0983eb790..8563d8f7e936ca4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -77,40 +77,19 @@ class DimLvlExpr {
   //
   // Getters for handling `AffineExpr` subclasses.
   //
-  Var castAnyVar() const;
-  std::optional<Var> dyn_castAnyVar() const;
   SymVar castSymVar() const;
   std::optional<SymVar> dyn_castSymVar() const;
   Var castDimLvlVar() const;
   std::optional<Var> dyn_castDimLvlVar() const;
-  int64_t castConstantValue() const;
-  std::optional<int64_t> dyn_castConstantValue() const;
-  bool hasConstantValue(int64_t val) const;
-  DimLvlExpr getLHS() const;
-  DimLvlExpr getRHS() const;
   std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
 
   /// Checks whether the variables bound/used by this spec are valid
   /// with respect to the given ranks.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
 
-  std::string str() const;
-  void print(llvm::raw_ostream &os) const;
-  void print(AsmPrinter &printer) const;
-  void dump() const;
-
 protected:
   // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
   enum class BindingStrength : bool { Weak = false, Strong = true };
-
-  void printAffineExprInternal(llvm::raw_ostream &os,
-                               BindingStrength enclosingTightness) const;
-  void printStrong(llvm::raw_ostream &os) const {
-    printAffineExprInternal(os, BindingStrength::Strong);
-  }
-  void printWeak(llvm::raw_ostream &os) const {
-    printAffineExprInternal(os, BindingStrength::Weak);
-  }
 };
 static_assert(IsZeroCostAbstraction<DimLvlExpr>);
 
@@ -208,11 +187,6 @@ class DimSpec final {
   /// to be vacuously valid, and therefore calling `setExpr` invalidates
   /// the result of this predicate.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
-
-  std::string str(bool wantElision = true) const;
-  void print(llvm::raw_ostream &os, bool wantElision = true) const;
-  void print(AsmPrinter &printer, bool wantElision = true) const;
-  void dump() const;
 };
 
 static_assert(IsZeroCostAbstraction<DimSpec>);
@@ -248,11 +222,6 @@ class LvlSpec final {
   /// Checks whether the variables bound/used by this spec are valid
   /// with respect to the given ranks.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
-
-  std::string str(bool wantElision = true) const;
-  void print(llvm::raw_ostream &os, bool wantElision = true) const;
-  void print(AsmPrinter &printer, bool wantElision = true) const;
-  void dump() const;
 };
 
 static_assert(IsZeroCostAbstraction<LvlSpec>);
@@ -282,11 +251,6 @@ class DimLvlMap final {
   AffineMap getDimToLvlMap(MLIRContext *context) const;
   AffineMap getLvlToDimMap(MLIRContext *context) const;
 
-  std::string str(bool wantElision = true) const;
-  void print(llvm::raw_ostream &os, bool wantElision = true) const;
-  void print(AsmPrinter &printer, bool wantElision = true) const;
-  void dump() const;
-
 private:
   /// Checks for integrity of variable-binding structure.
   /// This is already called by the ctor.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 966e32401c1f9e3..481275f052a3cee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -84,36 +84,6 @@ bool VarSet::contains(Var var) const {
   return num < bits.size() && bits[num];
 }
 
-bool VarSet::occursIn(VarSet const &other) const {
-  for (const auto vk : everyVarKind)
-    if (impl[vk].anyCommon(other.impl[vk]))
-      return true;
-  return false;
-}
-
-bool VarSet::occursIn(DimLvlExpr expr) const {
-  if (!expr)
-    return false;
-  switch (expr.getAffineKind()) {
-  case AffineExprKind::Constant:
-    return false;
-  case AffineExprKind::SymbolId:
-    return contains(expr.castSymVar());
-  case AffineExprKind::DimId:
-    return contains(expr.castDimLvlVar());
-  case AffineExprKind::Add:
-  case AffineExprKind::Mul:
-  case AffineExprKind::Mod:
-  case AffineExprKind::FloorDiv:
-  case AffineExprKind::CeilDiv: {
-    const auto [lhs, op, rhs] = expr.unpackBinop();
-    (void)op;
-    return occursIn(lhs) || occursIn(rhs);
-  }
-  }
-  llvm_unreachable("unknown AffineExprKind");
-}
-
 void VarSet::add(Var var) {
   // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
   impl[var.getKind()][var.getNum()] = true;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 81f480187c059e7..dce8b003b013bb9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -36,14 +36,6 @@ enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };
   return 0 <= vk_ && vk_ <= 2;
 }
 
-/// Swaps `Dimension` and `Level`, but leaves `Symbol` the same.
-constexpr VarKind flipVarKind(VarKind vk) {
-  return VarKind{2 - llvm::to_underlying(vk)};
-}
-static_assert(flipVarKind(VarKind::Symbol) == VarKind::Symbol &&
-              flipVarKind(VarKind::Dimension) == VarKind::Level &&
-              flipVarKind(VarKind::Level) == VarKind::Dimension);
-
 /// Gets the ASCII character used as the prefix when printing `Var`.
 constexpr char toChar(VarKind vk) {
   // If `isWF(vk)` then this computation's intermediate results are always
@@ -260,12 +252,10 @@ class VarSet final {
   Ranks getRanks() const {
     return Ranks(getSymRank(), getDimRank(), getLvlRank());
   }
-  /// For the `contains`/`occursIn` methods: if variables occurring in
+  /// For the `contains` method: if variables occurring in
   /// the method parameter are OOB for the `VarSet`, then these methods will
   /// always return false.
   bool contains(Var var) const;
-  bool occursIn(VarSet const &vars) const;
-  bool occursIn(DimLvlExpr expr) const;
 
   /// For the `add` methods: OOB parameters cause undefined behavior.
   /// Currently the `add` methods will raise an assertion error.
@@ -319,9 +309,6 @@ class VarInfo final {
     assert(hasNum());
     return Var(kind, *num);
   }
-  constexpr std::optional<Var> tryGetVar() const {
-    return num ? std::make_optional(Var(kind, *num)) : std::nullopt;
-  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -405,12 +392,6 @@ class VarEnv final {
   /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
   /// failure if the variable is not bound.
   Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
-
-  /// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt
-  /// if the variable is not bound.
-  std::optional<Var> tryGetVar(VarInfo::ID id) const {
-    return access(id).tryGetVar();
-  }
 };
 
 //===----------------------------------------------------------------------===//

>From 675469e8031b966ca77579e3632849a854340b86 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 16 Nov 2023 21:29:49 +0000
Subject: [PATCH 2/2] Fix formatting: remove extra blank lines

---
 mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 022180e333530af..95f8d7bf595c9ed 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -45,7 +45,6 @@ DimLvlExpr::unpackBinop() const {
   return {lhs, ak, rhs};
 }
 
-
 //===----------------------------------------------------------------------===//
 // `DimSpec` implementation.
 //===----------------------------------------------------------------------===//
@@ -59,7 +58,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && (!expr || ranks.isValid(expr));
 }
 
-
 //===----------------------------------------------------------------------===//
 // `LvlSpec` implementation.
 //===----------------------------------------------------------------------===//
@@ -75,7 +73,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && ranks.isValid(expr);
 }
 
-
 //===----------------------------------------------------------------------===//
 // `DimLvlMap` implementation.
 //===----------------------------------------------------------------------===//
@@ -144,5 +141,4 @@ AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
   return map;
 }
 
-
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list