[Mlir-commits] [mlir] 78921a6 - [mlir][sparse] Add more helper methods for converting DimLvlExpr to Var

Aart Bik llvmlistbot at llvm.org
Tue Aug 22 20:01:15 PDT 2023


Author: wren romano
Date: 2023-08-22T20:00:59-07:00
New Revision: 78921a64f74facb4a2aa1552a89d2e7c579884d1

URL: https://github.com/llvm/llvm-project/commit/78921a64f74facb4a2aa1552a89d2e7c579884d1
DIFF: https://github.com/llvm/llvm-project/commit/78921a64f74facb4a2aa1552a89d2e7c579884d1.diff

LOG: [mlir][sparse] Add more helper methods for converting DimLvlExpr to Var

These new methods help clean up some code for doing LvlExpr-analysis during DimExpr-inference.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D157647

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 6efcd0215ec6da..792626b45283ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -16,19 +16,46 @@ using namespace mlir::sparse_tensor::ir_detail;
 // `DimLvlExpr` implementation.
 //===----------------------------------------------------------------------===//
 
+Var DimLvlExpr::castAnyVar() const {
+  assert(expr && "uninitialized DimLvlExpr");
+  const auto var = dyn_castAnyVar();
+  assert(var && "expected DimLvlExpr to be a Var");
+  return *var;
+}
+
+std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
+  if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
+    return SymVar(s);
+  if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
+    return Var(getAllowedVarKind(), x);
+  return std::nullopt;
+}
+
 SymVar DimLvlExpr::castSymVar() const {
   return SymVar(expr.cast<AffineSymbolExpr>());
 }
 
+std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
+  if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
+    return SymVar(s);
+  return std::nullopt;
+}
+
 Var DimLvlExpr::castDimLvlVar() const {
   return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>());
 }
 
+std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
+  if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
+    return Var(getAllowedVarKind(), x);
+  return std::nullopt;
+}
+
 int64_t DimLvlExpr::castConstantValue() const {
   return expr.cast<AffineConstantExpr>().getValue();
 }
 
-std::optional<int64_t> DimLvlExpr::tryGetConstantValue() const {
+std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
   const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
   return k ? std::make_optional(k.getValue()) : std::nullopt;
 }
@@ -98,7 +125,7 @@ static std::optional<MatchNeg> matchNeg(DimLvlExpr expr) {
       return MatchNeg{DimLvlExpr{expr.getExprKind(), AffineExpr()}, val};
   }
   if (op == AffineExprKind::Mul)
-    if (const auto rval = rhs.tryGetConstantValue(); rval && *rval < 0)
+    if (const auto rval = rhs.dyn_castConstantValue(); rval && *rval < 0)
       return MatchNeg{lhs, *rval};
   return std::nullopt;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 5552e2fe0fd13d..040d7ea919a642 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -105,10 +105,14 @@ class DimLvlExpr {
   // TODO(wrengr): Most if not all of these don't actually need to be
   // methods, they could be free-functions instead.
   //
+  Var castAnyVar() const;
+  std::optional<Var> dyn_castAnyVar() const;
   SymVar castSymVar() const;
+  std::optional<SymVar> dyn_castSymVar() const;
   Var castDimLvlVar() const;
+  std::optional<Var> dyn_castDimLvlVar() const;
   int64_t castConstantValue() const;
-  std::optional<int64_t> tryGetConstantValue() const;
+  std::optional<int64_t> dyn_castConstantValue() const;
   bool hasConstantValue(int64_t val) const;
   DimLvlExpr getLHS() const;
   DimLvlExpr getRHS() const;
@@ -155,6 +159,12 @@ class DimExpr final : public DimLvlExpr {
     return expr->getExprKind() == Kind;
   }
   constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
+
+  LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
+  std::optional<LvlVar> dyn_castLvlVar() const {
+    const auto var = dyn_castDimLvlVar();
+    return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
+  }
 };
 static_assert(IsZeroCostAbstraction<DimExpr>);
 
@@ -169,6 +179,12 @@ class LvlExpr final : public DimLvlExpr {
     return expr->getExprKind() == Kind;
   }
   constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
+
+  DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
+  std::optional<DimVar> dyn_castDimVar() const {
+    const auto var = dyn_castDimLvlVar();
+    return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
+  }
 };
 static_assert(IsZeroCostAbstraction<LvlExpr>);
 


        


More information about the Mlir-commits mailing list