[Mlir-commits] [mlir] [mlir][affine] cleanup deprecated T.cast style functions (PR #71269)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 3 20:50:54 PDT 2023


https://github.com/lipracer created https://github.com/llvm/llvm-project/pull/71269

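(No description was provided with the pull request.) In summary, the patch below marks the member-style `AffineExpr::isa`/`dyn_cast`/`dyn_cast_or_null`/`cast` templates as deprecated, adds an `llvm::CastInfo` specialization so that `llvm::isa`/`llvm::cast`/`llvm::dyn_cast` work on `AffineExpr` and its subclasses, and converts existing call sites to the free-function forms. It also touches `AffineExprConstantFolder` in `mlir/lib/IR/AffineMap.cpp` (a `hasPoison()` accessor and `llvm::function_ref`-based folding callbacks).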

From 594c8af4225634a2bfefad20f167062e810a25c0 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 4 Nov 2023 11:43:55 +0800
Subject: [PATCH] [mlir][affine] cleanup deprecated T.cast style functions

---
 mlir/include/mlir/IR/AffineExpr.h             | 52 +++++++++++++++----
 mlir/include/mlir/IR/AffineMap.h              |  8 +--
 .../mlir/Interfaces/VectorInterfaces.td       |  8 +--
 .../Analysis/FlatLinearValueConstraints.cpp   |  2 +-
 .../Affine/Analysis/AffineStructures.cpp      |  4 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 24 ++++-----
 .../Affine/Transforms/DecomposeAffineOps.cpp  |  8 +--
 .../Affine/Transforms/ReifyValueBounds.cpp    |  4 +-
 .../Arith/Transforms/ReifyValueBounds.cpp     | 20 +++----
 .../GPU/TransformOps/GPUTransformOps.cpp      |  2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  2 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  2 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  8 +--
 .../SparseTensor/Transforms/CodegenUtils.cpp  |  6 +--
 .../Transforms/Sparsification.cpp             |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 20 +++----
 mlir/lib/IR/AffineMap.cpp                     | 27 +++++++---
 17 files changed, 122 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 69e02c94ef2708d..40e9d28ce5d3a01 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -82,13 +82,17 @@ class AffineExpr {
   bool operator!() const { return expr == nullptr; }
 
   template <typename U>
-  constexpr bool isa() const;
+  [[deprecated("Use llvm::isa<U>() instead")]] constexpr bool isa() const;
+
   template <typename U>
-  U dyn_cast() const;
+  [[deprecated("Use llvm::dyn_cast<U>() instead")]] U dyn_cast() const;
+
   template <typename U>
-  U dyn_cast_or_null() const;
+  [[deprecated("Use llvm::dyn_cast_or_null<U>() instead")]] U
+  dyn_cast_or_null() const;
+
   template <typename U>
-  U cast() const;
+  [[deprecated("Use llvm::cast<U>() instead")]] U cast() const;
 
   MLIRContext *getContext() const;
 
@@ -194,6 +198,8 @@ class AffineExpr {
         reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
   }
 
+  ImplType *getImpl() const { return expr; }
+
 protected:
   ImplType *expr{nullptr};
 };
@@ -281,18 +287,15 @@ constexpr bool AffineExpr::isa() const {
 }
 template <typename U>
 U AffineExpr::dyn_cast() const {
-  if (isa<U>())
-    return U(expr);
-  return U(nullptr);
+  return llvm::dyn_cast<U>(*this);
 }
 template <typename U>
 U AffineExpr::dyn_cast_or_null() const {
-  return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
+  return llvm::dyn_cast_or_null<U>(*this);
 }
 template <typename U>
 U AffineExpr::cast() const {
-  assert(isa<U>());
-  return U(expr);
+  return llvm::cast<U>(*this);
 }
 
 /// Simplify an affine expression by flattening and some amount of simple
@@ -390,6 +393,35 @@ struct DenseMapInfo<mlir::AffineExpr> {
   }
 };
 
+/// Add support for llvm style casts. We provide a cast between To and From if
+/// From is mlir::AffineExpr or derives from it.
+template <typename To, typename From>
+struct CastInfo<To, From,
+                std::enable_if_t<std::is_same_v<mlir::AffineExpr,
+                                                std::remove_const_t<From>> ||
+                                 std::is_base_of_v<mlir::AffineExpr, From>>>
+    : NullableValueCastFailed<To>,
+      DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+  static inline bool isPossible(mlir::AffineExpr expr) {
+    /// Return a constant true instead of a dynamic true when casting to self or
+    /// up the hierarchy.
+    if constexpr (std::is_base_of_v<To, From>) {
+      return true;
+    } else {
+      if constexpr (std::is_same_v<To, ::mlir::AffineBinaryOpExpr>)
+        return expr.getKind() <= ::mlir::AffineExprKind::LAST_AFFINE_BINARY_OP;
+      if constexpr (std::is_same_v<To, ::mlir::AffineDimExpr>)
+        return expr.getKind() == ::mlir::AffineExprKind::DimId;
+      if constexpr (std::is_same_v<To, ::mlir::AffineSymbolExpr>)
+        return expr.getKind() == ::mlir::AffineExprKind::SymbolId;
+      if constexpr (std::is_same_v<To, ::mlir::AffineConstantExpr>)
+        return expr.getKind() == ::mlir::AffineExprKind::Constant;
+    }
+  }
+  static inline To doCast(mlir::AffineExpr expr) { return To(expr.getImpl()); }
+};
+
 } // namespace llvm
 
 #endif // MLIR_IR_AFFINEEXPR_H
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..c1231102c3cbe76 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -636,9 +636,9 @@ SmallVector<T> applyPermutationMap(AffineMap map, llvm::ArrayRef<T> source) {
   SmallVector<T> result;
   result.reserve(map.getNumResults());
   for (AffineExpr expr : map.getResults()) {
-    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
       result.push_back(source[dimExpr.getPosition()]);
-    } else if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+    } else if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
       assert(constExpr.getValue() == 0 &&
              "Unexpected constant in projected permutation map");
       result.push_back(0);
@@ -657,9 +657,9 @@ static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
   for (const auto &exprs : exprsList) {
     for (auto expr : exprs) {
       expr.walk([&maxDim, &maxSym](AffineExpr e) {
-        if (auto d = e.dyn_cast<AffineDimExpr>())
+        if (auto d = dyn_cast<AffineDimExpr>(e))
           maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
-        if (auto s = e.dyn_cast<AffineSymbolExpr>())
+        if (auto s = dyn_cast<AffineSymbolExpr>(e))
           maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
       });
     }
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 026faf269f368de..66b1b0b70696e8e 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -120,8 +120,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         auto expr = $_op.getPermutationMap().getResult(idx);
-        return expr.template isa<::mlir::AffineConstantExpr>() &&
-               expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0;
+        return ::llvm::isa<::mlir::AffineConstantExpr>(expr) &&
+               ::llvm::dyn_cast<::mlir::AffineConstantExpr>(expr).getValue() == 0;
       }]
     >,
     InterfaceMethod<
@@ -278,9 +278,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
           AffineExpr dim = std::get<0>(vecDims);
           int64_t size = std::get<1>(vecDims);
           // Skip broadcast.
-          if (dim.isa<AffineConstantExpr>())
+          if (isa<AffineConstantExpr>(dim))
             continue;
-          dimSizes[dim.cast<AffineDimExpr>().getPosition()] = size;
+          dimSizes[cast<AffineDimExpr>(dim).getPosition()] = size;
         }
         return dimSizes;
       }]
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 382d05f3b2d4851..b838d461c398c83 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -305,7 +305,7 @@ static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
     // `var_n`), we can proceed.
     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
     // to dims themselves.
-    auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
+    auto dimExpr = dyn_cast<AffineDimExpr>(dividendExpr);
     if (!dimExpr)
       continue;
 
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 6ed3ba14fe15229..469298d3e8f43ff 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -167,10 +167,10 @@ FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
         // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
         // Make sure we skip those cases by checking that the lb result is not
         // just a constant.
-        !lbMap.getResult(0).isa<AffineConstantExpr>()) {
+        !isa<AffineConstantExpr>(lbMap.getResult(0))) {
       // Limited support: we expect the lb result to be just a loop dimension.
       // Not supported otherwise for now.
-      AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
+      AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
       if (!result)
         return failure();
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ba4285bd52394f3..5b8916b08fc5c4b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -573,9 +573,9 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
 
   // Fold dims and symbols to existing values.
   auto expr = map.getResult(0);
-  if (auto dim = expr.dyn_cast<AffineDimExpr>())
+  if (auto dim = dyn_cast<AffineDimExpr>(expr))
     return getOperand(dim.getPosition());
-  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
     return getOperand(map.getNumDims() + sym.getPosition());
 
   // Otherwise, default to folding the map.
@@ -627,7 +627,7 @@ static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
 /// being an affine dim expression or a constant.
 static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
                                    int64_t k) {
-  if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
     int64_t constVal = constExpr.getValue();
     return constVal >= 0 && constVal < k;
   }
@@ -715,7 +715,7 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
     constUpperBounds.push_back(getUpperBound(operand));
   }
 
-  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
     return constExpr.getValue();
 
   return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
@@ -739,7 +739,7 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
   }
 
   std::optional<int64_t> lowerBound;
-  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
     lowerBound = constExpr.getValue();
   } else {
     lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
@@ -775,7 +775,7 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
   // The `lhs` and `rhs` may be different post construction of simplified expr.
   lhs = binExpr.getLHS();
   rhs = binExpr.getRHS();
-  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
   if (!rhsConst)
     return;
 
@@ -879,7 +879,7 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
   lowerBounds.reserve(map.getNumResults());
   upperBounds.reserve(map.getNumResults());
   for (AffineExpr e : map.getResults()) {
-    if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
       lowerBounds.push_back(constExpr.getValue());
       upperBounds.push_back(constExpr.getValue());
     } else {
@@ -2066,7 +2066,7 @@ static void printBound(AffineMapAttr boundMap,
 
     // Print constant bound.
     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
-      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
         p << constExpr.getValue();
         return;
       }
@@ -3304,13 +3304,13 @@ struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
     // with the corresponding operand which is the result of another affine
     // min/max op. If So it can be merged into this affine op.
     for (AffineExpr expr : oldMap.getResults()) {
-      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
+      if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
         Value symValue = symOperands[symExpr.getPosition()];
         if (auto producerOp = symValue.getDefiningOp<T>()) {
           producerOps.push_back(producerOp);
           continue;
         }
-      } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
         Value dimValue = dimOperands[dimExpr.getPosition()];
         if (auto producerOp = dimValue.getDefiningOp<T>()) {
           producerOps.push_back(producerOp);
@@ -3760,7 +3760,7 @@ std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
   out.reserve(rangesValueMap.getNumResults());
   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
     auto expr = rangesValueMap.getResult(i);
-    auto cst = expr.dyn_cast<AffineConstantExpr>();
+    auto cst = dyn_cast<AffineConstantExpr>(expr);
     if (!cst)
       return std::nullopt;
     out.push_back(cst.getValue());
@@ -4188,7 +4188,7 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
     SmallVector<int64_t, 4> steps;
     auto stepsMap = stepsMapAttr.getValue();
     for (const auto &result : stepsMap.getResults()) {
-      auto constExpr = result.dyn_cast<AffineConstantExpr>();
+      auto constExpr = dyn_cast<AffineConstantExpr>(result);
       if (!constExpr)
         return parser.emitError(parser.getNameLoc(),
                                 "steps must be constant integers");
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index e87c5c030c5b9a2..e5501e848c1646a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -102,12 +102,12 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(op, "expected no dims");
 
   AffineExpr remainingExp = m.getResult(0);
-  auto binExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
+  auto binExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
   if (!binExpr)
     return rewriter.notifyMatchFailure(op, "terminal affine.apply");
 
-  if (!binExpr.getLHS().isa<AffineBinaryOpExpr>() &&
-      !binExpr.getRHS().isa<AffineBinaryOpExpr>())
+  if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) &&
+      !isa<AffineBinaryOpExpr>(binExpr.getRHS()))
     return rewriter.notifyMatchFailure(op, "terminal affine.apply");
 
   bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
@@ -123,7 +123,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
   MLIRContext *ctx = op->getContext();
   SmallVector<AffineExpr> subExpressions;
   while (true) {
-    auto currentBinExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
+    auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
     if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
       subExpressions.push_back(remainingExp);
       LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 4990229dfd3c876..37b36f76d4465df 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -70,9 +70,9 @@ OpFoldResult affine::materializeComputedBound(
         b.getIndexAttr(boundMap.getSingleConstantResult()));
   }
   // No affine.apply op is needed if the bound is a single SSA value.
-  if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+  if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
     return static_cast<OpFoldResult>(operands[expr.getPosition()]);
-  if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+  if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
     return static_cast<OpFoldResult>(
         operands[expr.getPosition() + boundMap.getNumDims()]);
   // General case: build affine.apply op.
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8eddd811dbea4d9..8d9fd1478aa9e61 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -24,34 +24,34 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
     switch (e.getKind()) {
     case AffineExprKind::Constant:
       return b.create<ConstantIndexOp>(loc,
-                                       e.cast<AffineConstantExpr>().getValue());
+                                       cast<AffineConstantExpr>(e).getValue());
     case AffineExprKind::DimId:
-      return operands[e.cast<AffineDimExpr>().getPosition()];
+      return operands[cast<AffineDimExpr>(e).getPosition()];
     case AffineExprKind::SymbolId:
-      return operands[e.cast<AffineSymbolExpr>().getPosition() +
+      return operands[cast<AffineSymbolExpr>(e).getPosition() +
                       map.getNumDims()];
     case AffineExprKind::Add: {
-      auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+      auto binaryExpr = cast<AffineBinaryOpExpr>(e);
       return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
                               buildExpr(binaryExpr.getRHS()));
     }
     case AffineExprKind::Mul: {
-      auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+      auto binaryExpr = cast<AffineBinaryOpExpr>(e);
       return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
                               buildExpr(binaryExpr.getRHS()));
     }
     case AffineExprKind::FloorDiv: {
-      auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+      auto binaryExpr = cast<AffineBinaryOpExpr>(e);
       return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
                                buildExpr(binaryExpr.getRHS()));
     }
     case AffineExprKind::CeilDiv: {
-      auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+      auto binaryExpr = cast<AffineBinaryOpExpr>(e);
       return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
                                    buildExpr(binaryExpr.getRHS()));
     }
     case AffineExprKind::Mod: {
-      auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+      auto binaryExpr = cast<AffineBinaryOpExpr>(e);
       return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
                                buildExpr(binaryExpr.getRHS()));
     }
@@ -106,9 +106,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
         b.getIndexAttr(boundMap.getSingleConstantResult()));
   }
   // No arith ops are needed if the bound is a single SSA value.
-  if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+  if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
     return static_cast<OpFoldResult>(operands[expr.getPosition()]);
-  if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+  if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
     return static_cast<OpFoldResult>(
         operands[expr.getPosition() + boundMap.getNumDims()]);
   // General case: build Arith ops.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index d4908fa7e89e736..a668e2436a8d24d 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -151,7 +151,7 @@ gpuMmaUnrollOrder(vector::ContractionOp contract) {
 
   llvm::SmallDenseSet<int64_t> dims;
   for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
-    dims.insert(expr.cast<AffineDimExpr>().getPosition());
+    dims.insert(cast<AffineDimExpr>(expr).getPosition());
   }
   // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
   for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5a593fbb2b6024d..d12ba8c4c59b33f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2178,7 +2178,7 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
     for (unsigned i = 0; i < sourceShape.size(); i++) {
       if (sourceType.isDynamicDim(i))
         continue;
-      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
         affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
     }
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5f566d8b10aef73..cae7b50b0fb3b47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -192,7 +192,7 @@ struct LinalgOpTilingInterface
     }
     for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
       unsigned dimPosition =
-          resultExpr.value().template cast<AffineDimExpr>().getPosition();
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
       iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
       iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
     }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index aedac67865aacec..ecf45af4c5d966f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -830,20 +830,20 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
     return false;
   std::map<unsigned, int64_t> coeffientMap;
   for (auto result : dimToLvl.getResults()) {
-    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
-      auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
+      auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
       if (result.getKind() == AffineExprKind::FloorDiv) {
         // Expect only one floordiv for each dimension.
         if (coeffientMap.find(pos) != coeffientMap.end())
           return false;
         coeffientMap[pos] =
-            binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue();
+            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue();
       } else if (result.getKind() == AffineExprKind::Mod) {
         // Expect floordiv before mod.
         if (coeffientMap.find(pos) == coeffientMap.end())
           return false;
         // Expect mod to have the same coefficient as floordiv.
-        if (binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue() !=
+        if (dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue() !=
             coeffientMap[pos]) {
           return false;
         }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index db969436a30712d..6d19700ee260bc4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -722,9 +722,9 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
       break;
     }
     case AffineExprKind::Mod: {
-      auto mod = exp.cast<AffineBinaryOpExpr>();
-      d = mod.getLHS().cast<AffineDimExpr>().getPosition();
-      cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
+      auto mod = cast<AffineBinaryOpExpr>(exp);
+      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
+      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
       break;
     }
     default:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ede4d38bb281277..4ea086f0dc6dd05 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1722,7 +1722,7 @@ static bool translateBitsToTidLvlPairs(
             AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
             // Skip simple affine expression and non-dense levels (which
             // have their own filter loop).
-            if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+            if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
               continue;
 
             // Constant affine expression are handled in genLoop
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 60416f550ee619d..95ecc7b79f220bf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -888,14 +888,14 @@ static LogicalResult verifyOutputShape(
                                      /*symCount=*/0, extents, ctx);
     // Compose the resMap with the extentsMap, which is a constant map.
     AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
-    assert(llvm::all_of(
-               expectedMap.getResults(),
-               [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
-           "expected constant extent along all dimensions.");
+    assert(
+        llvm::all_of(expectedMap.getResults(),
+                     [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
+        "expected constant extent along all dimensions.");
     // Extract the expected shape and build the type.
     auto expectedShape = llvm::to_vector<4>(
         llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
-          return e.cast<AffineConstantExpr>().getValue();
+          return cast<AffineConstantExpr>(e).getValue();
         }));
     auto expected =
         VectorType::get(expectedShape, resVectorType.getElementType(),
@@ -1076,7 +1076,7 @@ void ContractionOp::getIterationIndexMap(
     auto index = it.index();
     auto map = it.value();
     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
-      auto dim = map.getResult(i).cast<AffineDimExpr>();
+      auto dim = cast<AffineDimExpr>(map.getResult(i));
       iterationIndexMap[index][dim.getPosition()] = i;
     }
   }
@@ -3624,8 +3624,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
   SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
   for (auto expr : permutationMap.getResults()) {
-    auto dim = expr.dyn_cast<AffineDimExpr>();
-    auto zero = expr.dyn_cast<AffineConstantExpr>();
+    auto dim = dyn_cast<AffineDimExpr>(expr);
+    auto zero = dyn_cast<AffineConstantExpr>(expr);
     if (zero) {
       if (zero.getValue() != 0) {
         return emitOpError(
@@ -3726,7 +3726,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
              << AffineMapAttr::get(permutationMap)
              << " vs inBounds of size: " << inBounds.size();
     for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
-      if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
+      if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
           !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
         return op->emitOpError("requires broadcast dimensions to be in-bounds");
   }
@@ -3918,7 +3918,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
     }
     // Currently out-of-bounds, check whether we can statically determine it is
     // inBounds.
-    auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
+    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
     assert(dimExpr && "Broadcast dims must be in-bounds");
     auto inBounds =
         isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..9664467f4476389 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -46,6 +46,8 @@ class AffineExprConstantFolder {
     return nullptr;
   }
 
+  bool hasPoison() const { return hasPoison_; }
+
 private:
   std::optional<int64_t> constantFoldImpl(AffineExpr expr) {
     switch (expr.getKind()) {
@@ -57,13 +59,22 @@ class AffineExprConstantFolder {
           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
     case AffineExprKind::Mod:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+          expr,
+          [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            return mod(lhs, rhs);
+          });
     case AffineExprKind::FloorDiv:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+          expr,
+          [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            return floorDiv(lhs, rhs);
+          });
     case AffineExprKind::CeilDiv:
       return constantFoldBinExpr(
-          expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+          expr,
+          [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+            return ceilDiv(lhs, rhs);
+          });
     case AffineExprKind::Constant:
       return expr.cast<AffineConstantExpr>().getValue();
     case AffineExprKind::DimId:
@@ -82,8 +93,9 @@ class AffineExprConstantFolder {
   }
 
   // TODO: Change these to operate on APInts too.
-  std::optional<int64_t> constantFoldBinExpr(AffineExpr expr,
-                                             int64_t (*op)(int64_t, int64_t)) {
+  std::optional<int64_t> constantFoldBinExpr(
+      AffineExpr expr,
+      llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
     auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
     if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
       if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
@@ -95,6 +107,7 @@ class AffineExprConstantFolder {
   unsigned numDims;
   // The constant valued operands used to evaluate this AffineExpr.
   ArrayRef<Attribute> operandConsts;
+  bool hasPoison_{false};
 };
 
 } // namespace
@@ -547,7 +560,7 @@ bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   // number of result expressions is lower or equal than the number of input
   // expressions.
   for (auto expr : getResults()) {
-    if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
+    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
       if (seen[dim.getPosition()])
         return false;
       seen[dim.getPosition()] = true;
@@ -738,7 +751,7 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
   SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
   for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
     // Skip zeros from input map. 'exprs' is already initialized to zero.
-    if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(i))) {
       assert(constExpr.getValue() == 0 &&
              "Unexpected constant in projected permutation");
       (void)constExpr;



More information about the Mlir-commits mailing list