[Mlir-commits] [mlir] 78921a6 - [mlir][sparse] Add more helper methods for converting DimLvlExpr to Var
Aart Bik
llvmlistbot at llvm.org
Tue Aug 22 20:01:15 PDT 2023
Author: wren romano
Date: 2023-08-22T20:00:59-07:00
New Revision: 78921a64f74facb4a2aa1552a89d2e7c579884d1
URL: https://github.com/llvm/llvm-project/commit/78921a64f74facb4a2aa1552a89d2e7c579884d1
DIFF: https://github.com/llvm/llvm-project/commit/78921a64f74facb4a2aa1552a89d2e7c579884d1.diff
LOG: [mlir][sparse] Add more helper methods for converting DimLvlExpr to Var
These new methods help clean up some code for doing LvlExpr-analysis during DimExpr-inference.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D157647
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 6efcd0215ec6da..792626b45283ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -16,19 +16,46 @@ using namespace mlir::sparse_tensor::ir_detail;
// `DimLvlExpr` implementation.
//===----------------------------------------------------------------------===//
+Var DimLvlExpr::castAnyVar() const {
+ assert(expr && "uninitialized DimLvlExpr");
+ const auto var = dyn_castAnyVar();
+ assert(var && "expected DimLvlExpr to be a Var");
+ return *var;
+}
+
+std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
+ if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
+ return SymVar(s);
+ if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
+ return Var(getAllowedVarKind(), x);
+ return std::nullopt;
+}
+
SymVar DimLvlExpr::castSymVar() const {
return SymVar(expr.cast<AffineSymbolExpr>());
}
+std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
+ if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
+ return SymVar(s);
+ return std::nullopt;
+}
+
Var DimLvlExpr::castDimLvlVar() const {
return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>());
}
+std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
+ if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
+ return Var(getAllowedVarKind(), x);
+ return std::nullopt;
+}
+
int64_t DimLvlExpr::castConstantValue() const {
return expr.cast<AffineConstantExpr>().getValue();
}
-std::optional<int64_t> DimLvlExpr::tryGetConstantValue() const {
+std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
return k ? std::make_optional(k.getValue()) : std::nullopt;
}
@@ -98,7 +125,7 @@ static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
}
if (op == AffineExprKind::Mul)
- if (const auto rval = rhs.tryGetConstantValue(); rval && *rval < 0)
+ if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0)
return MatchNeg{lhs, *rval};
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 5552e2fe0fd13d..040d7ea919a642 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -105,10 +105,14 @@ class DimLvlExpr {
// TODO(wrengr): Most if not all of these don't actually need to be
// methods, they could be free-functions instead.
//
+ Var castAnyVar() const;
+ std::optional<Var> dyn_castAnyVar() const;
SymVar castSymVar() const;
+ std::optional<SymVar> dyn_castSymVar() const;
Var castDimLvlVar() const;
+ std::optional<Var> dyn_castDimLvlVar() const;
int64_t castConstantValue() const;
- std::optional<int64_t> tryGetConstantValue() const;
+ std::optional<int64_t> dyn_castConstantValue() const;
bool hasConstantValue(int64_t val) const;
DimLvlExpr getLHS() const;
DimLvlExpr getRHS() const;
@@ -155,6 +159,12 @@ class DimExpr final : public DimLvlExpr {
return expr->getExprKind() == Kind;
}
constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
+
+ LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
+ std::optional<LvlVar> dyn_castLvlVar() const {
+ const auto var = dyn_castDimLvlVar();
+ return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
+ }
};
static_assert(IsZeroCostAbstraction<DimExpr>);
@@ -169,6 +179,12 @@ class LvlExpr final : public DimLvlExpr {
return expr->getExprKind() == Kind;
}
constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
+
+ DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
+ std::optional<DimVar> dyn_castDimVar() const {
+ const auto var = dyn_castDimLvlVar();
+ return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
+ }
};
static_assert(IsZeroCostAbstraction<LvlExpr>);
More information about the Mlir-commits mailing list