[Mlir-commits] [mlir] 704c224 - [mlir][sparse] Clean up parser (#72571)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 17 10:38:26 PST 2023
Author: Yinying Li
Date: 2023-11-17T13:38:22-05:00
New Revision: 704c22473641e26d95435c55aa482fbf5abbbc2c
URL: https://github.com/llvm/llvm-project/commit/704c22473641e26d95435c55aa482fbf5abbbc2c
DIFF: https://github.com/llvm/llvm-project/commit/704c22473641e26d95435c55aa482fbf5abbbc2c.diff
LOG: [mlir][sparse] Clean up parser (#72571)
Remove unused functions in parser.
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 9757a599bd1eb60..95f8d7bf595c9ed 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -16,21 +16,6 @@ using namespace mlir::sparse_tensor::ir_detail;
// `DimLvlExpr` implementation.
//===----------------------------------------------------------------------===//
-Var DimLvlExpr::castAnyVar() const {
- assert(expr && "uninitialized DimLvlExpr");
- const auto var = dyn_castAnyVar();
- assert(var && "expected DimLvlExpr to be a Var");
- return *var;
-}
-
-std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
- if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
- return SymVar(s);
- if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
- return Var(getAllowedVarKind(), x);
- return std::nullopt;
-}
-
SymVar DimLvlExpr::castSymVar() const {
return SymVar(llvm::cast<AffineSymbolExpr>(expr));
}
@@ -51,30 +36,6 @@ std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
return std::nullopt;
}
-int64_t DimLvlExpr::castConstantValue() const {
- return llvm::cast<AffineConstantExpr>(expr).getValue();
-}
-
-std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
- const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
- return k ? std::make_optional(k.getValue()) : std::nullopt;
-}
-
-bool DimLvlExpr::hasConstantValue(int64_t val) const {
- const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
- return k && k.getValue() == val;
-}
-
-DimLvlExpr DimLvlExpr::getLHS() const {
- const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
- return DimLvlExpr(kind, binop ? binop.getLHS() : nullptr);
-}
-
-DimLvlExpr DimLvlExpr::getRHS() const {
- const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
- return DimLvlExpr(kind, binop ? binop.getRHS() : nullptr);
-}
-
std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
DimLvlExpr::unpackBinop() const {
const auto ak = getAffineKind();
@@ -84,115 +45,6 @@ DimLvlExpr::unpackBinop() const {
return {lhs, ak, rhs};
}
-void DimLvlExpr::dump() const {
- print(llvm::errs());
- llvm::errs() << "\n";
-}
-std::string DimLvlExpr::str() const {
- std::string str;
- llvm::raw_string_ostream os(str);
- print(os);
- return os.str();
-}
-void DimLvlExpr::print(AsmPrinter &printer) const {
- print(printer.getStream());
-}
-void DimLvlExpr::print(llvm::raw_ostream &os) const {
- if (!expr)
- os << "<<NULL AFFINE EXPR>>";
- else
- printWeak(os);
-}
-
-namespace {
-struct MatchNeg final : public std::pair<DimLvlExpr, int64_t> {
- using Base = std::pair<DimLvlExpr, int64_t>;
- using Base::Base;
- constexpr DimLvlExpr getLHS() const { return first; }
- constexpr int64_t getRHS() const { return second; }
-};
-} // namespace
-
-static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
- const auto [lhs, op, rhs] = expr.unpackBinop();
- if (op == AffineExprKind::Constant) {
- const auto val = expr.castConstantValue();
- if (val < 0)
- return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
- }
- if (op == AffineExprKind::Mul)
- if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0)
- return MatchNeg{lhs, *rval};
- return std::nullopt;
-}
-
-// A heavily revised version of `AsmPrinter::Impl::printAffineExprInternal`.
-void DimLvlExpr::printAffineExprInternal(
- llvm::raw_ostream &os, BindingStrength enclosingTightness) const {
- const char *binopSpelling = nullptr;
- switch (getAffineKind()) {
- case AffineExprKind::SymbolId:
- os << castSymVar();
- return;
- case AffineExprKind::DimId:
- os << castDimLvlVar();
- return;
- case AffineExprKind::Constant:
- os << castConstantValue();
- return;
- case AffineExprKind::Add:
- binopSpelling = " + "; // N.B., this is unused
- break;
- case AffineExprKind::Mul:
- binopSpelling = " * ";
- break;
- case AffineExprKind::FloorDiv:
- binopSpelling = " floordiv ";
- break;
- case AffineExprKind::CeilDiv:
- binopSpelling = " ceildiv ";
- break;
- case AffineExprKind::Mod:
- binopSpelling = " mod ";
- break;
- }
-
- if (enclosingTightness == BindingStrength::Strong)
- os << '(';
-
- const auto [lhs, op, rhs] = unpackBinop();
- if (op == AffineExprKind::Mul && rhs.hasConstantValue(-1)) {
- // Pretty print `(lhs * -1)` as "-lhs".
- os << '-';
- lhs.printStrong(os);
- } else if (op != AffineExprKind::Add) {
- // Default rule for tightly binding binary operators.
- // (Including `Mul` that didn't match the previous rule.)
- lhs.printStrong(os);
- os << binopSpelling;
- rhs.printStrong(os);
- } else {
- // Combination of all the special rules for addition/subtraction.
- lhs.printWeak(os);
- const auto rx = matchNeg(rhs);
- os << (rx ? " - " : " + ");
- const auto &rlhs = rx ? rx->getLHS() : rhs;
- const auto rrhs = rx ? rx->getRHS() : -1; // value irrelevant when `!rx`
- const bool nonunit = rrhs != -1; // value irrelevant when `!rx`
- const bool isStrong =
- rx && rlhs && (nonunit || rlhs.getAffineKind() == AffineExprKind::Add);
- if (rlhs)
- rlhs.printAffineExprInternal(os, BindingStrength{isStrong});
- if (rx && rlhs && nonunit)
- os << " * ";
- if (rx && (!rlhs || nonunit))
- os << -rrhs;
- }
-
- if (enclosingTightness == BindingStrength::Strong)
- os << ')';
-}
-
//===----------------------------------------------------------------------===//
// `DimSpec` implementation.
//===----------------------------------------------------------------------===//
@@ -206,31 +58,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && (!expr || ranks.isValid(expr));
}
-void DimSpec::dump() const {
- print(llvm::errs(), /*wantElision=*/false);
- llvm::errs() << "\n";
-}
-std::string DimSpec::str(bool wantElision) const {
- std::string str;
- llvm::raw_string_ostream os(str);
- print(os, wantElision);
- return os.str();
-}
-void DimSpec::print(AsmPrinter &printer, bool wantElision) const {
- print(printer.getStream(), wantElision);
-}
-void DimSpec::print(llvm::raw_ostream &os, bool wantElision) const {
- os << var;
- if (expr && (!wantElision || !elideExpr))
- os << " = " << expr;
- if (slice) {
- os << " : ";
- // Call `SparseTensorDimSliceAttr::print` directly, to avoid
- // printing the mnemonic.
- slice.print(os);
- }
-}
-
//===----------------------------------------------------------------------===//
// `LvlSpec` implementation.
//===----------------------------------------------------------------------===//
@@ -246,26 +73,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && ranks.isValid(expr);
}
-void LvlSpec::dump() const {
- print(llvm::errs(), /*wantElision=*/false);
- llvm::errs() << "\n";
-}
-std::string LvlSpec::str(bool wantElision) const {
- std::string str;
- llvm::raw_string_ostream os(str);
- print(os, wantElision);
- return os.str();
-}
-void LvlSpec::print(AsmPrinter &printer, bool wantElision) const {
- print(printer.getStream(), wantElision);
-}
-void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const {
- if (!wantElision || !elideVar)
- os << var << " = ";
- os << expr;
- os << ": " << toMLIRString(type);
-}
-
//===----------------------------------------------------------------------===//
// `DimLvlMap` implementation.
//===----------------------------------------------------------------------===//
@@ -334,51 +141,4 @@ AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
return map;
}
-void DimLvlMap::dump() const {
- print(llvm::errs(), /*wantElision=*/false);
- llvm::errs() << "\n";
-}
-std::string DimLvlMap::str(bool wantElision) const {
- std::string str;
- llvm::raw_string_ostream os(str);
- print(os, wantElision);
- return os.str();
-}
-void DimLvlMap::print(AsmPrinter &printer, bool wantElision) const {
- print(printer.getStream(), wantElision);
-}
-void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
- // Symbolic identifiers.
- // NOTE: Unlike `AffineMap` we place the SymVar bindings before the DimVar
- // bindings, since the SymVars may occur within DimExprs and thus this
- // ordering helps reduce potential user confusion about the scope of bidings
- // (since it means SymVars and DimVars both bind-forward in the usual way,
- // whereas only LvlVars have different binding rules).
- if (symRank != 0) {
- os << "[s0";
- for (unsigned i = 1; i < symRank; ++i)
- os << ", s" << i;
- os << ']';
- }
-
- // LvlVar forward-declarations.
- if (mustPrintLvlVars) {
- os << '{';
- llvm::interleaveComma(
- lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
- os << "} ";
- }
-
- // Dimension specifiers.
- os << '(';
- llvm::interleaveComma(
- dimSpecs, os, [&](DimSpec const &spec) { spec.print(os, wantElision); });
- os << ") -> (";
- // Level specifiers.
- wantElision = wantElision && !mustPrintLvlVars;
- llvm::interleaveComma(
- lvlSpecs, os, [&](LvlSpec const &spec) { spec.print(os, wantElision); });
- os << ')';
-}
-
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index b3200d0983eb790..8563d8f7e936ca4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -77,40 +77,19 @@ class DimLvlExpr {
//
// Getters for handling `AffineExpr` subclasses.
//
- Var castAnyVar() const;
- std::optional<Var> dyn_castAnyVar() const;
SymVar castSymVar() const;
std::optional<SymVar> dyn_castSymVar() const;
Var castDimLvlVar() const;
std::optional<Var> dyn_castDimLvlVar() const;
- int64_t castConstantValue() const;
- std::optional<int64_t> dyn_castConstantValue() const;
- bool hasConstantValue(int64_t val) const;
- DimLvlExpr getLHS() const;
- DimLvlExpr getRHS() const;
std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
/// Checks whether the variables bound/used by this spec are valid
/// with respect to the given ranks.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
- std::string str() const;
- void print(llvm::raw_ostream &os) const;
- void print(AsmPrinter &printer) const;
- void dump() const;
-
protected:
// Variant of `mlir::AsmPrinter::Impl::BindingStrength`
enum class BindingStrength : bool { Weak = false, Strong = true };
-
- void printAffineExprInternal(llvm::raw_ostream &os,
- BindingStrength enclosingTightness) const;
- void printStrong(llvm::raw_ostream &os) const {
- printAffineExprInternal(os, BindingStrength::Strong);
- }
- void printWeak(llvm::raw_ostream &os) const {
- printAffineExprInternal(os, BindingStrength::Weak);
- }
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);
@@ -208,11 +187,6 @@ class DimSpec final {
/// to be vacuously valid, and therefore calling `setExpr` invalidates
/// the result of this predicate.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
-
- std::string str(bool wantElision = true) const;
- void print(llvm::raw_ostream &os, bool wantElision = true) const;
- void print(AsmPrinter &printer, bool wantElision = true) const;
- void dump() const;
};
static_assert(IsZeroCostAbstraction<DimSpec>);
@@ -248,11 +222,6 @@ class LvlSpec final {
/// Checks whether the variables bound/used by this spec are valid
/// with respect to the given ranks.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
-
- std::string str(bool wantElision = true) const;
- void print(llvm::raw_ostream &os, bool wantElision = true) const;
- void print(AsmPrinter &printer, bool wantElision = true) const;
- void dump() const;
};
static_assert(IsZeroCostAbstraction<LvlSpec>);
@@ -282,11 +251,6 @@ class DimLvlMap final {
AffineMap getDimToLvlMap(MLIRContext *context) const;
AffineMap getLvlToDimMap(MLIRContext *context) const;
- std::string str(bool wantElision = true) const;
- void print(llvm::raw_ostream &os, bool wantElision = true) const;
- void print(AsmPrinter &printer, bool wantElision = true) const;
- void dump() const;
-
private:
/// Checks for integrity of variable-binding structure.
/// This is already called by the ctor.
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 966e32401c1f9e3..481275f052a3cee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -84,36 +84,6 @@ bool VarSet::contains(Var var) const {
return num < bits.size() && bits[num];
}
-bool VarSet::occursIn(VarSet const &other) const {
- for (const auto vk : everyVarKind)
- if (impl[vk].anyCommon(other.impl[vk]))
- return true;
- return false;
-}
-
-bool VarSet::occursIn(DimLvlExpr expr) const {
- if (!expr)
- return false;
- switch (expr.getAffineKind()) {
- case AffineExprKind::Constant:
- return false;
- case AffineExprKind::SymbolId:
- return contains(expr.castSymVar());
- case AffineExprKind::DimId:
- return contains(expr.castDimLvlVar());
- case AffineExprKind::Add:
- case AffineExprKind::Mul:
- case AffineExprKind::Mod:
- case AffineExprKind::FloorDiv:
- case AffineExprKind::CeilDiv: {
- const auto [lhs, op, rhs] = expr.unpackBinop();
- (void)op;
- return occursIn(lhs) || occursIn(rhs);
- }
- }
- llvm_unreachable("unknown AffineExprKind");
-}
-
void VarSet::add(Var var) {
// NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
impl[var.getKind()][var.getNum()] = true;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 81f480187c059e7..dce8b003b013bb9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -36,14 +36,6 @@ enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };
return 0 <= vk_ && vk_ <= 2;
}
-/// Swaps `Dimension` and `Level`, but leaves `Symbol` the same.
-constexpr VarKind flipVarKind(VarKind vk) {
- return VarKind{2 - llvm::to_underlying(vk)};
-}
-static_assert(flipVarKind(VarKind::Symbol) == VarKind::Symbol &&
- flipVarKind(VarKind::Dimension) == VarKind::Level &&
- flipVarKind(VarKind::Level) == VarKind::Dimension);
-
/// Gets the ASCII character used as the prefix when printing `Var`.
constexpr char toChar(VarKind vk) {
// If `isWF(vk)` then this computation's intermediate results are always
@@ -260,12 +252,10 @@ class VarSet final {
Ranks getRanks() const {
return Ranks(getSymRank(), getDimRank(), getLvlRank());
}
- /// For the `contains`/`occursIn` methods: if variables occurring in
+ /// For the `contains` method: if variables occurring in
/// the method parameter are OOB for the `VarSet`, then these methods will
/// always return false.
bool contains(Var var) const;
- bool occursIn(VarSet const &vars) const;
- bool occursIn(DimLvlExpr expr) const;
/// For the `add` methods: OOB parameters cause undefined behavior.
/// Currently the `add` methods will raise an assertion error.
@@ -319,9 +309,6 @@ class VarInfo final {
assert(hasNum());
return Var(kind, *num);
}
- constexpr std::optional<Var> tryGetVar() const {
- return num ? std::make_optional(Var(kind, *num)) : std::nullopt;
- }
};
//===----------------------------------------------------------------------===//
@@ -405,12 +392,6 @@ class VarEnv final {
/// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
/// failure if the variable is not bound.
Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
-
- /// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt
- /// if the variable is not bound.
- std::optional<Var> tryGetVar(VarInfo::ID id) const {
- return access(id).tryGetVar();
- }
};
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list