[Mlir-commits] [mlir] fdbe931 - [mlir][sparse] Adding getters/setters to `DimLvlMap`
wren romano
llvmlistbot at llvm.org
Tue Aug 1 12:55:52 PDT 2023
Author: wren romano
Date: 2023-08-01T12:55:45-07:00
New Revision: fdbe9312b1c626ea61a2456db94cd52109ff1a50
URL: https://github.com/llvm/llvm-project/commit/fdbe9312b1c626ea61a2456db94cd52109ff1a50
DIFF: https://github.com/llvm/llvm-project/commit/fdbe9312b1c626ea61a2456db94cd52109ff1a50.diff
LOG: [mlir][sparse] Adding getters/setters to `DimLvlMap`
Depends On D156768
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D156770
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index cbdca742ff7e4d..6efcd0215ec6da 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -262,10 +262,8 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
// needs to happen before the code for setting every `LvlSpec::elideVar`,
// since if the LvlVar is only used in elided DimExpr, then the
// LvlVar should also be elided.
- // NOTE: Whenever we set a new DimExpr, we must make sure to validate it
- // against our ranks, to restore the invariant established by `isWF` above.
- // TODO(wrengr): We might should adjust the `DimLvlExpr` ctor to take a
- // `Ranks` argument and perform the validation then.
+ // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
+ // to ensure that we maintain the invariant established by `isWF` above.
// Third, we set every `LvlSpec::elideVar` according to whether that
// LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
@@ -300,6 +298,22 @@ bool DimLvlMap::isWF() const {
return true;
}
+/// Builds the dim-to-lvl map: an `AffineMap` whose domain has `getDimRank()`
+/// dims and `getSymRank()` symbols, and whose results are the levels'
+/// expressions, one per level, in level order.
+AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
+  SmallVector<AffineExpr> results;
+  results.reserve(getLvlRank());
+  for (const LvlSpec &spec : getLvls())
+    results.push_back(spec.getExpr().getAffineExpr());
+  const unsigned numDims = getDimRank();
+  const unsigned numSyms = getSymRank();
+  return AffineMap::get(numDims, numSyms, results, context);
+}
+
+/// Builds the lvl-to-dim map: an `AffineMap` whose domain has `getLvlRank()`
+/// dims and `getSymRank()` symbols, and whose results are the dimensions'
+/// expressions, one per dimension, in dimension order.
+///
+/// A `DimExpr` can be null (`setDimExpr` has to assert against that), and a
+/// null expression must not be passed as a result to `AffineMap::get`.
+/// Therefore, if any dimension has no expression, this returns a null
+/// `AffineMap` rather than crashing or building a map with too few results.
+/// NOTE(review): presumably dim exprs stay null until the ctor's inference
+/// step assigns them via `setDimExpr` -- confirm with the ctor.
+AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
+  SmallVector<AffineExpr> dimAffines;
+  dimAffines.reserve(getDimRank());
+  for (const auto &dimSpec : dimSpecs) {
+    auto expr = dimSpec.getExpr();
+    if (!expr)
+      return AffineMap();
+    dimAffines.push_back(expr.getAffineExpr());
+  }
+  return AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
+}
+
void DimLvlMap::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index b1d3e437621c55..c39cd9a3e96f75 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -290,16 +290,6 @@ static_assert(IsZeroCostAbstraction<LvlSpec>);
//===----------------------------------------------------------------------===//
class DimLvlMap final {
- // TODO(wrengr): Need to define getters
- unsigned symRank;
- SmallVector<DimSpec> dimSpecs;
- SmallVector<LvlSpec> lvlSpecs;
- bool mustPrintLvlVars;
-
- // Checks for integrity of variable-binding structure.
- // This is already called by the ctor.
- [[nodiscard]] bool isWF() const;
-
public:
DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
ArrayRef<LvlSpec> lvlSpecs);
@@ -310,11 +300,41 @@ class DimLvlMap final {
unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
- DimLevelType getDimLevelType(unsigned i) { return lvlSpecs[i].getType(); }
+ /// Returns the specifications of all dimensions, indexed by `Dimension`.
+ ArrayRef<DimSpec> getDims() const { return dimSpecs; }
+ /// Returns the specification of dimension `dim`.
+ const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
+ /// Returns the slice attribute stored in dimension `dim`'s specification.
+ SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
+ return getDim(dim).getSlice();
+ }
+
+ /// Returns the specifications of all levels, indexed by `Level`.
+ ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
+ /// Returns the specification of level `lvl`.
+ const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
+ /// Returns the `DimLevelType` of level `lvl`.
+ DimLevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
+
+ /// Builds an `AffineMap` with `getDimRank()` dims and `getSymRank()`
+ /// symbols whose results are the levels' expressions, in level order.
+ AffineMap getDimToLvlMap(MLIRContext *context) const;
+ /// Builds an `AffineMap` with `getLvlRank()` dims and `getSymRank()`
+ /// symbols whose results are the dimensions' expressions, in dimension
+ /// order.
+ AffineMap getLvlToDimMap(MLIRContext *context) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
+
+private:
+ /// Checks for integrity of variable-binding structure.
+ /// This is already called by the ctor.
+ [[nodiscard]] bool isWF() const;
+
+ /// Helper function to call `DimSpec::setExpr` while asserting that
+ /// the invariant established by `DimLvlMap::isWF` is maintained
+ /// (the assertion checks that `expr` is non-null and that
+ /// `getRanks().isValid(expr)` holds).
+ /// This is used by the ctor.
+ void setDimExpr(Dimension dim, DimExpr expr) {
+ assert(expr && getRanks().isValid(expr));
+ dimSpecs[dim].setExpr(expr);
+ }
+
+ // All these fields are const-after-ctor.
+ unsigned symRank;
+ SmallVector<DimSpec> dimSpecs;
+ SmallVector<LvlSpec> lvlSpecs;
+ bool mustPrintLvlVars;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 63f55bd43e8db3..15dae63649ede3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -115,7 +115,7 @@ bool VarSet::occursIn(DimLvlExpr expr) const {
}
void VarSet::add(Var var) {
- // NOTE: `SmallBitVactor::operator[]` will raise assertion errors for OOB.
+ // Sets the bit at index `var.getNum()` in the bit-vector for `var.getKind()`.
+ // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
impl[var.getKind()][var.getNum()] = true;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7dd15b96b30656..db31ae0f0433d2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -530,8 +530,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
RETURN_ON_FAIL(res);
// Proof of concept result.
// TODO: use DimLvlMap directly as storage representation
- for (unsigned i = 0, e = res->getLvlRank(); i < e; i++)
- lvlTypes.push_back(res->getDimLevelType(i));
+ for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
+ lvlTypes.push_back(res->getLvlType(lvl));
}
// Only the last item can omit the comma
More information about the Mlir-commits mailing list