[Mlir-commits] [mlir] [mlir][affine] cleanup deprecated T.cast style functions (PR #71269)
llvmlistbot at llvm.org
Wed Nov 8 06:23:26 PST 2023
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/71269
From 7aa03f05a88698eeef0f87909e112b5f9601eb09 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 4 Nov 2023 11:43:55 +0800
Subject: [PATCH] [mlir] cleanup AffineExpr's deprecated T.cast style functions
For details, see the documentation: https://mlir.llvm.org/deprecation/

Not all changes were made manually; most were made with a clang tool I wrote.
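As a rough before/after sketch of the mechanical rewrite (illustrative only; the helper `dimPositionOrZero` and its variables are not taken from any file in this diff):

    #include "mlir/IR/AffineExpr.h"

    // Sketch: return the dim position if `expr` is an AffineDimExpr, else 0.
    unsigned dimPositionOrZero(mlir::AffineExpr expr) {
      // Before (deprecated member-style casts):
      //   if (expr.isa<mlir::AffineDimExpr>())
      //     return expr.cast<mlir::AffineDimExpr>().getPosition();
      // After (free-function casts, as applied throughout this patch):
      if (auto dim = llvm::dyn_cast<mlir::AffineDimExpr>(expr))
        return dim.getPosition();
      return 0;
    }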
---
mlir/include/mlir/IR/AffineExpr.h | 52 +++++-
mlir/include/mlir/IR/AffineExprVisitor.h | 32 ++--
mlir/include/mlir/IR/AffineMap.h | 8 +-
.../mlir/Interfaces/VectorInterfaces.td | 8 +-
.../Analysis/FlatLinearValueConstraints.cpp | 2 +-
mlir/lib/CAPI/IR/AffineExpr.cpp | 18 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 2 +-
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 2 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 9 +-
.../Conversion/VectorToSCF/VectorToSCF.cpp | 4 +-
.../Affine/Analysis/AffineStructures.cpp | 4 +-
.../Dialect/Affine/Analysis/LoopAnalysis.cpp | 8 +-
mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 14 +-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 40 ++--
.../Affine/Transforms/DecomposeAffineOps.cpp | 8 +-
.../Affine/Transforms/ReifyValueBounds.cpp | 4 +-
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 2 +-
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 27 ++-
.../Arith/Transforms/ReifyValueBounds.cpp | 20 +-
.../GPU/TransformOps/GPUTransformOps.cpp | 2 +-
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 23 ++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
.../Linalg/IR/ValueBoundsOpInterfaceImpl.cpp | 8 +-
.../Transforms/BubbleUpExtractSlice.cpp | 2 +-
.../Linalg/Transforms/ConstantFold.cpp | 2 +-
.../Transforms/DataLayoutPropagation.cpp | 16 +-
.../Linalg/Transforms/DecomposeLinalgOps.cpp | 2 +-
.../Linalg/Transforms/DropUnitDims.cpp | 6 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 24 ++-
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 6 +-
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 4 +-
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 2 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 3 +-
.../Linalg/Transforms/Vectorization.cpp | 2 +-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 2 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
.../Mesh/Interfaces/ShardingInterface.cpp | 12 +-
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 4 +-
.../SCF/Transforms/LoopSpecialization.cpp | 4 +-
.../SparseTensor/IR/Detail/DimLvlMap.cpp | 24 +--
.../SparseTensor/IR/SparseTensorDialect.cpp | 27 +--
.../SparseTensor/Transforms/CodegenUtils.cpp | 26 +--
.../SparseTensor/Transforms/LoopEmitter.cpp | 10 +-
.../Transforms/SparseBufferRewriting.cpp | 4 +-
.../Transforms/SparseReinterpretMap.cpp | 2 +-
.../Transforms/Sparsification.cpp | 48 ++---
.../IR/TensorInferTypeOpInterfaceImpl.cpp | 24 +--
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 6 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 20 +-
.../Vector/Transforms/LowerVectorTransfer.cpp | 6 +-
.../Vector/Transforms/VectorDistribute.cpp | 12 +-
.../Vector/Transforms/VectorUnroll.cpp | 4 +-
mlir/lib/IR/AffineExpr.cpp | 176 +++++++++---------
mlir/lib/IR/AffineMap.cpp | 67 ++++---
mlir/lib/IR/AsmPrinter.cpp | 16 +-
mlir/lib/IR/BuiltinTypes.cpp | 12 +-
.../mlir-linalg-ods-yaml-gen.cpp | 2 +-
58 files changed, 458 insertions(+), 422 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 69e02c94ef2708d..40e9d28ce5d3a01 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -82,13 +82,17 @@ class AffineExpr {
bool operator!() const { return expr == nullptr; }
template <typename U>
- constexpr bool isa() const;
+ [[deprecated("Use llvm::isa<U>() instead")]] constexpr bool isa() const;
+
template <typename U>
- U dyn_cast() const;
+ [[deprecated("Use llvm::dyn_cast<U>() instead")]] U dyn_cast() const;
+
template <typename U>
- U dyn_cast_or_null() const;
+ [[deprecated("Use llvm::dyn_cast_or_null<U>() instead")]] U
+ dyn_cast_or_null() const;
+
template <typename U>
- U cast() const;
+ [[deprecated("Use llvm::cast<U>() instead")]] U cast() const;
MLIRContext *getContext() const;
@@ -194,6 +198,8 @@ class AffineExpr {
reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
+ ImplType *getImpl() const { return expr; }
+
protected:
ImplType *expr{nullptr};
};
@@ -281,18 +287,15 @@ constexpr bool AffineExpr::isa() const {
}
template <typename U>
U AffineExpr::dyn_cast() const {
- if (isa<U>())
- return U(expr);
- return U(nullptr);
+ return llvm::dyn_cast<U>(*this);
}
template <typename U>
U AffineExpr::dyn_cast_or_null() const {
- return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
+ return llvm::dyn_cast_or_null<U>(*this);
}
template <typename U>
U AffineExpr::cast() const {
- assert(isa<U>());
- return U(expr);
+ return llvm::cast<U>(*this);
}
/// Simplify an affine expression by flattening and some amount of simple
@@ -390,6 +393,35 @@ struct DenseMapInfo<mlir::AffineExpr> {
}
};
+/// Add support for llvm style casts. We provide a cast between To and From if
+/// From is mlir::AffineExpr or derives from it.
+template <typename To, typename From>
+struct CastInfo<To, From,
+ std::enable_if_t<std::is_same_v<mlir::AffineExpr,
+ std::remove_const_t<From>> ||
+ std::is_base_of_v<mlir::AffineExpr, From>>>
+ : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static inline bool isPossible(mlir::AffineExpr expr) {
+ /// Return a constant true instead of a dynamic true when casting to self or
+ /// up the hierarchy.
+ if constexpr (std::is_base_of_v<To, From>) {
+ return true;
+ } else {
+ if constexpr (std::is_same_v<To, ::mlir::AffineBinaryOpExpr>)
+ return expr.getKind() <= ::mlir::AffineExprKind::LAST_AFFINE_BINARY_OP;
+ if constexpr (std::is_same_v<To, ::mlir::AffineDimExpr>)
+ return expr.getKind() == ::mlir::AffineExprKind::DimId;
+ if constexpr (std::is_same_v<To, ::mlir::AffineSymbolExpr>)
+ return expr.getKind() == ::mlir::AffineExprKind::SymbolId;
+ if constexpr (std::is_same_v<To, ::mlir::AffineConstantExpr>)
+ return expr.getKind() == ::mlir::AffineExprKind::Constant;
+ }
+ }
+ static inline To doCast(mlir::AffineExpr expr) { return To(expr.getImpl()); }
+};
+
} // namespace llvm
#endif // MLIR_IR_AFFINEEXPR_H
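The CastInfo specialization added above is what lets the free llvm::isa / llvm::cast / llvm::dyn_cast / llvm::dyn_cast_or_null functions operate directly on AffineExpr and its subclasses. A minimal usage sketch, assuming this patch is applied (the function `classifyExpr` and its variables are illustrative, not part of the change):

    #include "mlir/IR/AffineExpr.h"

    // Sketch: dispatch on the dynamic kind of an AffineExpr.
    void classifyExpr(mlir::AffineExpr e) {
      if (auto bin = llvm::dyn_cast<mlir::AffineBinaryOpExpr>(e)) {
        // Add/Mul/Mod/FloorDiv/CeilDiv expressions land here.
        (void)bin.getLHS();
        (void)bin.getRHS();
      } else if (auto sym = llvm::dyn_cast<mlir::AffineSymbolExpr>(e)) {
        (void)sym.getPosition();
      } else if (auto cst = llvm::dyn_cast<mlir::AffineConstantExpr>(e)) {
        (void)cst.getValue();
      }
    }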
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index f6216614c2238e1..382db22dce463e5 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -77,39 +77,39 @@ class AffineExprVisitor {
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr.getKind()) {
case AffineExprKind::Add: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
- expr.cast<AffineConstantExpr>());
+ cast<AffineConstantExpr>(expr));
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
- expr.cast<AffineDimExpr>());
+ cast<AffineDimExpr>(expr));
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
- expr.cast<AffineSymbolExpr>());
+ cast<AffineSymbolExpr>(expr));
}
}
@@ -119,34 +119,34 @@ class AffineExprVisitor {
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr.getKind()) {
case AffineExprKind::Add: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
return static_cast<SubClass *>(this)->visitConstantExpr(
- expr.cast<AffineConstantExpr>());
+ cast<AffineConstantExpr>(expr));
case AffineExprKind::DimId:
return static_cast<SubClass *>(this)->visitDimExpr(
- expr.cast<AffineDimExpr>());
+ cast<AffineDimExpr>(expr));
case AffineExprKind::SymbolId:
return static_cast<SubClass *>(this)->visitSymbolExpr(
- expr.cast<AffineSymbolExpr>());
+ cast<AffineSymbolExpr>(expr));
}
llvm_unreachable("Unknown AffineExpr");
}
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index f691a3daf8889c5..713aef767edf669 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -648,9 +648,9 @@ SmallVector<T> applyPermutationMap(AffineMap map, llvm::ArrayRef<T> source) {
SmallVector<T> result;
result.reserve(map.getNumResults());
for (AffineExpr expr : map.getResults()) {
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
result.push_back(source[dimExpr.getPosition()]);
- } else if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ } else if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
assert(constExpr.getValue() == 0 &&
"Unexpected constant in projected permutation map");
result.push_back(0);
@@ -669,9 +669,9 @@ static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
for (const auto &exprs : exprsList) {
for (auto expr : exprs) {
expr.walk([&maxDim, &maxSym](AffineExpr e) {
- if (auto d = e.dyn_cast<AffineDimExpr>())
+ if (auto d = dyn_cast<AffineDimExpr>(e))
maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
- if (auto s = e.dyn_cast<AffineSymbolExpr>())
+ if (auto s = dyn_cast<AffineSymbolExpr>(e))
maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
});
}
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 026faf269f368de..66b1b0b70696e8e 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -120,8 +120,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto expr = $_op.getPermutationMap().getResult(idx);
- return expr.template isa<::mlir::AffineConstantExpr>() &&
- expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0;
+ return ::llvm::isa<::mlir::AffineConstantExpr>(expr) &&
+ ::llvm::dyn_cast<::mlir::AffineConstantExpr>(expr).getValue() == 0;
}]
>,
InterfaceMethod<
@@ -278,9 +278,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
AffineExpr dim = std::get<0>(vecDims);
int64_t size = std::get<1>(vecDims);
// Skip broadcast.
- if (dim.isa<AffineConstantExpr>())
+ if (isa<AffineConstantExpr>(dim))
continue;
- dimSizes[dim.cast<AffineDimExpr>().getPosition()] = size;
+ dimSizes[cast<AffineDimExpr>(dim).getPosition()] = size;
}
return dimSizes;
}]
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 382d05f3b2d4851..b838d461c398c83 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -305,7 +305,7 @@ static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
// `var_n`), we can proceed.
// TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
// to dims themselves.
- auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
+ auto dimExpr = dyn_cast<AffineDimExpr>(dividendExpr);
if (!dimExpr)
continue;
diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 5b25ab5337e2f77..6e3328b65cb08d3 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -66,7 +66,7 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
//===----------------------------------------------------------------------===//
bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).isa<AffineDimExpr>();
+ return isa<AffineDimExpr>(unwrap(affineExpr));
}
MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) {
@@ -74,7 +74,7 @@ MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) {
}
intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).cast<AffineDimExpr>().getPosition();
+ return cast<AffineDimExpr>(unwrap(affineExpr)).getPosition();
}
//===----------------------------------------------------------------------===//
@@ -82,7 +82,7 @@ intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) {
//===----------------------------------------------------------------------===//
bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).isa<AffineSymbolExpr>();
+ return isa<AffineSymbolExpr>(unwrap(affineExpr));
}
MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) {
@@ -90,7 +90,7 @@ MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) {
}
intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).cast<AffineSymbolExpr>().getPosition();
+ return cast<AffineSymbolExpr>(unwrap(affineExpr)).getPosition();
}
//===----------------------------------------------------------------------===//
@@ -98,7 +98,7 @@ intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) {
//===----------------------------------------------------------------------===//
bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).isa<AffineConstantExpr>();
+ return isa<AffineConstantExpr>(unwrap(affineExpr));
}
MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) {
@@ -106,7 +106,7 @@ MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) {
}
int64_t mlirAffineConstantExprGetValue(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).cast<AffineConstantExpr>().getValue();
+ return cast<AffineConstantExpr>(unwrap(affineExpr)).getValue();
}
//===----------------------------------------------------------------------===//
@@ -181,13 +181,13 @@ MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs,
//===----------------------------------------------------------------------===//
bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) {
- return unwrap(affineExpr).isa<AffineBinaryOpExpr>();
+ return isa<AffineBinaryOpExpr>(unwrap(affineExpr));
}
MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) {
- return wrap(unwrap(affineExpr).cast<AffineBinaryOpExpr>().getLHS());
+ return wrap(cast<AffineBinaryOpExpr>(unwrap(affineExpr)).getLHS());
}
MlirAffineExpr mlirAffineBinaryOpExprGetRHS(MlirAffineExpr affineExpr) {
- return wrap(unwrap(affineExpr).cast<AffineBinaryOpExpr>().getRHS());
+ return wrap(cast<AffineBinaryOpExpr>(unwrap(affineExpr)).getRHS());
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91b1210efec23e0..4ae6e865f2a49e0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1390,7 +1390,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
for (const auto &en :
llvm::enumerate(transposeOp.getPermutation().getResults())) {
int targetPos = en.index();
- int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
+ int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
targetMemRef.setSize(rewriter, loc, targetPos,
viewMemRef.size(rewriter, loc, sourcePos));
targetMemRef.setStride(rewriter, loc, targetPos,
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 94d875d678df293..11b4cbb2506705b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -318,7 +318,7 @@ static Value deriveStaticUpperBound(Value upperBound,
if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) {
for (const AffineExpr &result : minOp.getMap().getResults()) {
- if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(result)) {
return rewriter.create<arith::ConstantIndexOp>(minOp.getLoc(),
constExpr.getValue());
}
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index f0412648608a6e4..1126c2c20758c7a 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -61,7 +61,7 @@ static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
Location loc = xferOp.getLoc();
unsigned offsetsIdx = 0;
for (auto expr : xferOp.getPermutationMap().getResults()) {
- if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
+ if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
Value prevIdx = indices[dim.getPosition()];
SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
dims.push_back(prevIdx);
@@ -549,8 +549,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
bool isTranspose = isTransposeMatrixLoadMap(map);
// Handle broadcast by setting the stride to 0.
- if (auto cstExpr =
- map.getResult(isTranspose).dyn_cast<AffineConstantExpr>()) {
+ if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
assert(cstExpr.getValue() == 0);
stride = 0;
}
@@ -682,8 +681,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineExpr dN = map.getResult(1);
// Find the position of these expressions in the input.
- auto exprM = dM.dyn_cast<AffineDimExpr>();
- auto exprN = dN.dyn_cast<AffineDimExpr>();
+ auto exprM = dyn_cast<AffineDimExpr>(dM);
+ auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 5fffd9091d2286d..a262cf488ed2951 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -62,7 +62,7 @@ static std::optional<int64_t> unpackedDim(OpTy xferOp) {
// TODO: support 0-d corner case.
assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
auto map = xferOp.getPermutationMap();
- if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
+ if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
return expr.getPosition();
}
assert(xferOp.isBroadcastDim(0) &&
@@ -1290,7 +1290,7 @@ get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
memrefIndices.append(indices.begin(), indices.end());
assert(map.getNumResults() == 1 &&
"Expected 1 permutation map result for 1D transfer");
- if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
+ if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
Location loc = xferOp.getLoc();
auto dim = expr.getPosition();
AffineExpr d0, d1;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 6ed3ba14fe15229..469298d3e8f43ff 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -167,10 +167,10 @@ FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
// iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
// Make sure we skip those cases by checking that the lb result is not
// just a constant.
- !lbMap.getResult(0).isa<AffineConstantExpr>()) {
+ !isa<AffineConstantExpr>(lbMap.getResult(0))) {
// Limited support: we expect the lb result to be just a loop dimension.
// Not supported otherwise for now.
- AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
+ AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
if (!result)
return failure();
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index d56db64eac08261..e645afe7cd3e8fa 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -95,7 +95,7 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
// Take the min if all trip counts are constant.
std::optional<uint64_t> tripCount;
for (auto resultExpr : map.getResults()) {
- if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
if (tripCount.has_value())
tripCount =
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
@@ -124,7 +124,7 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
std::optional<uint64_t> gcd;
for (auto resultExpr : map.getResults()) {
uint64_t thisGcd;
- if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
uint64_t tripCount = constExpr.getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).
if (tripCount == 0)
@@ -235,9 +235,9 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
SmallVector<Value, 4> exprOperands;
auto resultExpr = accessMap.getResult(i);
resultExpr.walk([&](AffineExpr expr) {
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
- else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+ else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
});
// Check access invariance of each operand in 'exprOperands'.
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index ce3ff0a095770c1..eda314a994a4d1e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -635,12 +635,12 @@ std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
// iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
// Make sure we skip those cases by checking that the lb result is not
// just a constant.
- lbMap.getResult(0).isa<AffineConstantExpr>())
+ isa<AffineConstantExpr>(lbMap.getResult(0)))
return std::nullopt;
// Limited support: we expect the lb result to be just a loop dimension for
// now.
- AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
+ AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
if (!result)
return std::nullopt;
@@ -668,10 +668,10 @@ std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
AffineExpr dstLbResult = dstLbMap.getResult(0);
AffineExpr srcUbResult = srcUbMap.getResult(0);
AffineExpr dstUbResult = dstUbMap.getResult(0);
- if (!srcLbResult.isa<AffineConstantExpr>() ||
- !srcUbResult.isa<AffineConstantExpr>() ||
- !dstLbResult.isa<AffineConstantExpr>() ||
- !dstUbResult.isa<AffineConstantExpr>())
+ if (!isa<AffineConstantExpr>(srcLbResult) ||
+ !isa<AffineConstantExpr>(srcUbResult) ||
+ !isa<AffineConstantExpr>(dstLbResult) ||
+ !isa<AffineConstantExpr>(dstUbResult))
return std::nullopt;
// Check if src and dst loop bounds are the same. If not, we can guarantee
@@ -1460,7 +1460,7 @@ static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
AffineExpr ubExpr(ubMap.getResult(0));
auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
lbMap.getNumSymbols());
- auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+ auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
if (!cExpr)
return std::nullopt;
return cExpr.getValue();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ba4285bd52394f3..05496e70716a2a1 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -573,9 +573,9 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
// Fold dims and symbols to existing values.
auto expr = map.getResult(0);
- if (auto dim = expr.dyn_cast<AffineDimExpr>())
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
return getOperand(dim.getPosition());
- if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+ if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
return getOperand(map.getNumDims() + sym.getPosition());
// Otherwise, default to folding the map.
@@ -597,7 +597,7 @@ static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
// well for dim/sym expressions, but in that case, getLargestKnownDivisor
// can't be part of the IR library but of the `Analysis` library. The IR
// library can only really depend on simple O(1) checks.
- auto dimExpr = e.dyn_cast<AffineDimExpr>();
+ auto dimExpr = dyn_cast<AffineDimExpr>(e);
// If it's not a dim expr, `div` is the best we have.
if (!dimExpr)
return div;
@@ -627,11 +627,11 @@ static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
/// being an affine dim expression or a constant.
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
int64_t k) {
- if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
int64_t constVal = constExpr.getValue();
return constVal >= 0 && constVal < k;
}
- auto dimExpr = e.dyn_cast<AffineDimExpr>();
+ auto dimExpr = dyn_cast<AffineDimExpr>(e);
if (!dimExpr)
return false;
Value operand = operands[dimExpr.getPosition()];
@@ -655,7 +655,7 @@ static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
/// expression is in that form.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
AffineExpr "ientTimesDiv, AffineExpr &rem) {
- auto bin = e.dyn_cast<AffineBinaryOpExpr>();
+ auto bin = dyn_cast<AffineBinaryOpExpr>(e);
if (!bin || bin.getKind() != AffineExprKind::Add)
return false;
@@ -715,7 +715,7 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
constUpperBounds.push_back(getUpperBound(operand));
}
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
return constExpr.getValue();
return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
@@ -739,7 +739,7 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
}
std::optional<int64_t> lowerBound;
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
lowerBound = constExpr.getValue();
} else {
lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
@@ -754,7 +754,7 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
unsigned numSymbols,
ArrayRef<Value> operands) {
// We do this only for certain floordiv/mod expressions.
- auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+ auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binExpr)
return;
@@ -765,7 +765,7 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
- binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+ binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
expr.getKind() != AffineExprKind::CeilDiv &&
expr.getKind() != AffineExprKind::Mod)) {
@@ -775,7 +775,7 @@ static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
// The `lhs` and `rhs` may be different post construction of simplified expr.
lhs = binExpr.getLHS();
rhs = binExpr.getRHS();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (!rhsConst)
return;
@@ -879,7 +879,7 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
lowerBounds.reserve(map.getNumResults());
upperBounds.reserve(map.getNumResults());
for (AffineExpr e : map.getResults()) {
- if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
lowerBounds.push_back(constExpr.getValue());
upperBounds.push_back(constExpr.getValue());
} else {
@@ -1335,9 +1335,9 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
mapOrSet->walkExprs([&](AffineExpr expr) {
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
usedDims[dimExpr.getPosition()] = true;
- else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+ else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
usedSyms[symExpr.getPosition()] = true;
});
@@ -2066,7 +2066,7 @@ static void printBound(AffineMapAttr boundMap,
// Print constant bound.
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
p << constExpr.getValue();
return;
}
@@ -2075,7 +2075,7 @@ static void printBound(AffineMapAttr boundMap,
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
- if (expr.dyn_cast<AffineSymbolExpr>()) {
+ if (dyn_cast<AffineSymbolExpr>(expr)) {
p.printOperand(*boundOperands.begin());
return;
}
@@ -3304,13 +3304,13 @@ struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
// with the corresponding operand which is the result of another affine
// min/max op. If So it can be merged into this affine op.
for (AffineExpr expr : oldMap.getResults()) {
- if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
+ if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
Value symValue = symOperands[symExpr.getPosition()];
if (auto producerOp = symValue.getDefiningOp<T>()) {
producerOps.push_back(producerOp);
continue;
}
- } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
Value dimValue = dimOperands[dimExpr.getPosition()];
if (auto producerOp = dimValue.getDefiningOp<T>()) {
producerOps.push_back(producerOp);
@@ -3760,7 +3760,7 @@ std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
out.reserve(rangesValueMap.getNumResults());
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
auto expr = rangesValueMap.getResult(i);
- auto cst = expr.dyn_cast<AffineConstantExpr>();
+ auto cst = dyn_cast<AffineConstantExpr>(expr);
if (!cst)
return std::nullopt;
out.push_back(cst.getValue());
@@ -4188,7 +4188,7 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
SmallVector<int64_t, 4> steps;
auto stepsMap = stepsMapAttr.getValue();
for (const auto &result : stepsMap.getResults()) {
- auto constExpr = result.dyn_cast<AffineConstantExpr>();
+ auto constExpr = dyn_cast<AffineConstantExpr>(result);
if (!constExpr)
return parser.emitError(parser.getNameLoc(),
"steps must be constant integers");
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index e87c5c030c5b9a2..e5501e848c1646a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -102,12 +102,12 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(op, "expected no dims");
AffineExpr remainingExp = m.getResult(0);
- auto binExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
+ auto binExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
if (!binExpr)
return rewriter.notifyMatchFailure(op, "terminal affine.apply");
- if (!binExpr.getLHS().isa<AffineBinaryOpExpr>() &&
- !binExpr.getRHS().isa<AffineBinaryOpExpr>())
+ if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) &&
+ !isa<AffineBinaryOpExpr>(binExpr.getRHS()))
return rewriter.notifyMatchFailure(op, "terminal affine.apply");
bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
@@ -123,7 +123,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
MLIRContext *ctx = op->getContext();
SmallVector<AffineExpr> subExpressions;
while (true) {
- auto currentBinExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
+ auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
subExpressions.push_back(remainingExp);
LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 4990229dfd3c876..37b36f76d4465df 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -70,9 +70,9 @@ OpFoldResult affine::materializeComputedBound(
b.getIndexAttr(boundMap.getSingleConstantResult()));
}
// No affine.apply op is needed if the bound is a single SSA value.
- if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+ if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
return static_cast<OpFoldResult>(operands[expr.getPosition()]);
- if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+ if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
return static_cast<OpFoldResult>(
operands[expr.getPosition() + boundMap.getNumDims()]);
// General case: build affine.apply op.
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index fb8a0a7c330cf22..f2f67d1a03e22b5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2069,7 +2069,7 @@ static LogicalResult generateCopy(
// Set copy start location for this dimension in the lower memory space
// memref.
- if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
+ if (auto caf = dyn_cast<AffineConstantExpr>(offset)) {
auto indexVal = caf.getValue();
if (indexVal == 0) {
memIndices.push_back(zeroIndex);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index a6df512897eccaf..50a052fb8b74e70 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -75,7 +75,7 @@ class AffineApplyExpander
/// negative = a < 0 in
/// select negative, remainder + b, remainder.
Value visitModExpr(AffineBinaryOpExpr expr) {
- if (auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "modulo by non-positive value is not supported");
return nullptr;
@@ -115,7 +115,7 @@ class AffineApplyExpander
/// IR because arith.floordivsi is more general than affine floordiv in that
/// it supports negative RHS.
Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
- if (auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "division by non-positive value is not supported");
return nullptr;
@@ -154,7 +154,7 @@ class AffineApplyExpander
/// Note: not using arith.ceildivsi for the same reason as explained in the
/// visitFloorDivExpr comment.
Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
- if (auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "division by non-positive value is not supported");
return nullptr;
@@ -464,15 +464,15 @@ AffineExpr mlir::affine::substWithMin(AffineExpr e, AffineExpr dim,
bool positivePath) {
if (e == dim)
return positivePath ? min : max;
- if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto bin = dyn_cast<AffineBinaryOpExpr>(e)) {
AffineExpr lhs = bin.getLHS();
AffineExpr rhs = bin.getRHS();
if (bin.getKind() == mlir::AffineExprKind::Add)
return substWithMin(lhs, dim, min, max, positivePath) +
substWithMin(rhs, dim, min, max, positivePath);
- auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
- auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
+ auto c1 = dyn_cast<AffineConstantExpr>(bin.getLHS());
+ auto c2 = dyn_cast<AffineConstantExpr>(bin.getRHS());
if (c1 && c1.getValue() < 0)
return getAffineBinaryOpExpr(
bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
@@ -497,8 +497,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
bool isAlreadyNormalized =
llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
int64_t step = std::get<0>(tuple);
- auto lbExpr =
- std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
+ auto lbExpr = dyn_cast<AffineConstantExpr>(std::get<1>(tuple));
return lbExpr && lbExpr.getValue() == 0 && step == 1;
});
if (isAlreadyNormalized)
@@ -1474,8 +1473,8 @@ static LogicalResult getTileSizePos(
unsigned pos = 0;
for (AffineExpr expr : map.getResults()) {
if (expr.getKind() == AffineExprKind::FloorDiv) {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
- if (binaryExpr.getRHS().isa<AffineConstantExpr>())
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ if (isa<AffineConstantExpr>(binaryExpr.getRHS()))
floordivExprs.emplace_back(
std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
}
@@ -1509,7 +1508,7 @@ static LogicalResult getTileSizePos(
expr.walk([&](AffineExpr e) {
if (e == floordivExprLHS) {
if (expr.getKind() == AffineExprKind::Mod) {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
// If LHS and RHS of `mod` are the same with those of floordiv.
if (floordivExprLHS == binaryExpr.getLHS() &&
floordivExprRHS == binaryExpr.getRHS()) {
@@ -1569,7 +1568,7 @@ isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
// Check if affine expr of the dimension includes dynamic dimension of input
// memrefType.
expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
- if (e.isa<AffineDimExpr>()) {
+ if (isa<AffineDimExpr>(e)) {
for (unsigned dm : inMemrefTypeDynDims) {
if (e == getAffineDimExpr(dm, context)) {
isDynamicDim = true;
@@ -1590,11 +1589,11 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
AffineBinaryOpExpr binaryExpr = nullptr;
switch (pat) {
case TileExprPattern::TileMod:
- binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
+ binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
newMapOutput = binaryExpr.getRHS();
break;
case TileExprPattern::TileFloorDiv:
- binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
+ binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
newMapOutput = getAffineBinaryOpExpr(
AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
break;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8eddd811dbea4d9..8d9fd1478aa9e61 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -24,34 +24,34 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
switch (e.getKind()) {
case AffineExprKind::Constant:
return b.create<ConstantIndexOp>(loc,
- e.cast<AffineConstantExpr>().getValue());
+ cast<AffineConstantExpr>(e).getValue());
case AffineExprKind::DimId:
- return operands[e.cast<AffineDimExpr>().getPosition()];
+ return operands[cast<AffineDimExpr>(e).getPosition()];
case AffineExprKind::SymbolId:
- return operands[e.cast<AffineSymbolExpr>().getPosition() +
+ return operands[cast<AffineSymbolExpr>(e).getPosition() +
map.getNumDims()];
case AffineExprKind::Add: {
- auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+ auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mul: {
- auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+ auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::FloorDiv: {
- auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+ auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::CeilDiv: {
- auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+ auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
case AffineExprKind::Mod: {
- auto binaryExpr = e.cast<AffineBinaryOpExpr>();
+ auto binaryExpr = cast<AffineBinaryOpExpr>(e);
return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
buildExpr(binaryExpr.getRHS()));
}
@@ -106,9 +106,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
b.getIndexAttr(boundMap.getSingleConstantResult()));
}
// No arith ops are needed if the bound is a single SSA value.
- if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+ if (auto expr = dyn_cast<AffineDimExpr>(boundMap.getResult(0)))
return static_cast<OpFoldResult>(operands[expr.getPosition()]);
- if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+ if (auto expr = dyn_cast<AffineSymbolExpr>(boundMap.getResult(0)))
return static_cast<OpFoldResult>(
operands[expr.getPosition() + boundMap.getNumDims()]);
// General case: build Arith ops.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index d4908fa7e89e736..a668e2436a8d24d 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -151,7 +151,7 @@ gpuMmaUnrollOrder(vector::ContractionOp contract) {
llvm::SmallDenseSet<int64_t> dims;
for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
- dims.insert(expr.cast<AffineDimExpr>().getPosition());
+ dims.insert(cast<AffineDimExpr>(expr).getPosition());
}
// Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 08d46f236f8ab3b..ba419d32f22a3eb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -190,7 +190,7 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
- if (auto d = e.dyn_cast<AffineDimExpr>()) {
+ if (auto d = dyn_cast<AffineDimExpr>(e)) {
if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
@@ -363,8 +363,7 @@ LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
/// preference over `rhs`)
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
- return lhs.isa<T>() ? lhs.cast<T>()
- : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
+ return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}
namespace {
@@ -437,7 +436,7 @@ struct ConvAccessExprWalker
}
FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
int64_t dim = dimExpr.getPosition();
if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
return failure();
@@ -447,7 +446,7 @@ struct ConvAccessExprWalker
convolvedDims.insert(dim);
return dim;
}
- if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
if (symbolMulExpr.getKind() != AffineExprKind::Mul)
return failure();
auto lhsExpr = symbolMulExpr.getLHS();
@@ -479,7 +478,7 @@ static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
"expected map to have projected permutations");
llvm::SmallDenseSet<int64_t> preservedDims;
for (auto expr : map.getResults())
- preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
+ preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
return preservedDims;
}
@@ -487,7 +486,7 @@ static SmallVector<int64_t, 2>
getConstantsFromExprList(SmallVector<AffineExpr, 2> exprs) {
SmallVector<int64_t, 2> vals;
for (auto e : exprs) {
- auto constantExpr = e.dyn_cast<AffineConstantExpr>();
+ auto constantExpr = dyn_cast<AffineConstantExpr>(e);
assert(constantExpr && "Found non-constant stride/dilation");
vals.push_back(constantExpr.getValue());
}
@@ -684,7 +683,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// filter.
llvm::SmallDenseSet<int64_t> allLoopDims;
for (auto outputExpr : indexingMaps.back().getResults()) {
- int64_t outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
+ int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Batch dimension.
@@ -721,7 +720,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
return MatchConvolutionResult::NonConvolutionLoop;
}
for (auto filterExpr : indexingMaps[1].getResults()) {
- int64_t filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
+ int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
if (outputDims.count(filterDim) &&
!inputExprWalker.unConvolvedDims.count(filterDim) &&
!inputExprWalker.convolvedDims.count(filterDim)) {
@@ -871,7 +870,7 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
SmallVector<Range, 4> res(numDims);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
- if (auto d = result.dyn_cast<AffineDimExpr>()) {
+ if (auto d = dyn_cast<AffineDimExpr>(result)) {
if (res[d.getPosition()].offset)
continue;
res[d.getPosition()] =
@@ -888,7 +887,7 @@ SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
SmallVector<int64_t, 4> res(numDims, 0);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
- if (auto d = result.dyn_cast<AffineDimExpr>())
+ if (auto d = dyn_cast<AffineDimExpr>(result))
res[d.getPosition()] = allShapeSizes[idx];
}
return res;
@@ -1093,7 +1092,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
"unexpected result less than 0 at expression #")
<< dim << " in " << mapStr;
}
- if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
+ if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
if (inferredDimSize != shape[dim]) {
return op->emitOpError("inferred input/output operand #")
<< opOperand.getOperandNumber() << " has shape's dimension #"
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5a593fbb2b6024d..d12ba8c4c59b33f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2178,7 +2178,7 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
for (unsigned i = 0; i < sourceShape.size(); i++) {
if (sourceType.isDynamicDim(i))
continue;
- if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+ if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
}
}
diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index d6dc150584186f5..f56ef485069f856 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -29,10 +29,10 @@ struct IndexOpInterface
cstr.bound(value) >= 0;
// index < dim size
- int64_t flatDimPos = linalgOp.getShapesToLoopsMap()
- .getResult(indexOp.getDim())
- .cast<AffineDimExpr>()
- .getPosition();
+ int64_t flatDimPos =
+ cast<AffineDimExpr>(
+ linalgOp.getShapesToLoopsMap().getResult(indexOp.getDim()))
+ .getPosition();
// Find the `flatDimPos`-th operand dimension.
int64_t flatDimCtr = 0;
for (Value operand : linalgOp->getOperands()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 28377279b7ce94c..5c4bc9137c10a8a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -107,7 +107,7 @@ struct BubbleUpExtractSliceOpPattern
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> tileSizes = sizeBounds;
for (auto const &result : enumerate(indexingMap.getResults())) {
- unsigned position = result.value().cast<AffineDimExpr>().getPosition();
+ unsigned position = cast<AffineDimExpr>(result.value()).getPosition();
tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 4322b6e77eb8fcf..062751552b3cc6c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -148,7 +148,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
SmallVector<unsigned> dims;
dims.reserve(map.getNumResults());
for (AffineExpr result : map.getResults()) {
- dims.push_back(result.cast<AffineDimExpr>().getPosition());
+ dims.push_back(cast<AffineDimExpr>(result).getPosition());
}
return dims;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 95a20f2369f9e07..9e71a0765bbb5f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -76,10 +76,10 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
innerDimsPos, packOrUnPackOp.getMixedTiles())) {
auto expr = exprs[innerDimPos];
- if (!expr.template isa<AffineDimExpr>())
+ if (!isa<AffineDimExpr>(expr))
return failure();
int64_t domainDimPos =
- exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
+ cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
if (!isParallelIterator(iterators[domainDimPos]))
return failure();
packInfo.tiledDimsPos.push_back(domainDimPos);
@@ -99,7 +99,7 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
auto areAllAffineDimExpr = [&](int dim) {
for (AffineMap map : indexingMaps) {
if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
- return expr.isFunctionOfDim(dim) && !expr.isa<AffineDimExpr>();
+ return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
})) {
return false;
}
@@ -126,7 +126,7 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
SmallVector<int64_t> permutedOuterDims;
for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
auto permutedExpr = indexingMap.getResult(dim);
- if (auto dimExpr = permutedExpr.template dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
permutedOuterDims.push_back(dimExpr.getPosition());
continue;
}
@@ -177,7 +177,7 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
// Here we rely on the assumption that the outer dims permutation
// when propagating currently requires that non-affine dim expressions
// are not permuted, thus allowing the identity assignment below.
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
currentPositionTileLoops[dimExpr.getPosition()] = pos;
else
currentPositionTileLoops[pos] = pos;
@@ -238,7 +238,7 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
// Step 1. Construct the information of packing data dimensions; append inner
// dimensions to the indexing maps for the operand.
for (auto [index, expr] : llvm::enumerate(exprs)) {
- if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
int64_t dimPos = dimExpr.getPosition();
domainDimToOperandDim[dimPos] = index;
continue;
@@ -264,12 +264,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
SmallVector<int64_t> inversedOuterPerm =
invertPermutationVector(packInfo.outerDimsOnDomainPerm);
for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
- if (auto dimExpr = exprs[i].dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
int64_t dimPos = dimExpr.getPosition();
exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
continue;
}
- assert(exprs[i].isa<AffineConstantExpr>() &&
+ assert(isa<AffineConstantExpr>(exprs[i]) &&
"Attempted to permute non-constant and non-affine dim expression");
}
// Step 2.2: Undo the transposition on `exprs` and propagate the
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index eae03924fb5c7bd..1227478118fbef8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -122,7 +122,7 @@ SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
SmallVector<OpFoldResult> permutedValues(values.size());
for (const auto &position :
llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) {
- return expr.cast<AffineDimExpr>().getPosition();
+ return cast<AffineDimExpr>(expr).getPosition();
})))
permutedValues[position.value()] = values[position.index()];
return permutedValues;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2e3610b7c08d9da..6fbf35145578716 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -349,14 +349,14 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
auto isUnitDim = [&](unsigned dim) {
- if (auto dimExpr = exprs[dim].dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
unsigned oldPosition = dimExpr.getPosition();
return !oldDimsToNewDimsMap.count(oldPosition);
}
// Handle the other case where the shape is 1, and is accessed using a
// constant 0.
if (operandShape[dim] == 1) {
- auto constAffineExpr = exprs[dim].dyn_cast<AffineConstantExpr>();
+ auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
return constAffineExpr && constAffineExpr.getValue() == 0;
}
return false;
@@ -411,7 +411,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
allowedUnitDims.end());
llvm::SmallDenseSet<unsigned> unitDims;
for (const auto &expr : enumerate(invertedMap.getResults())) {
- if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>()) {
+ if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
if (dims[dimExpr.getPosition()] == 1 &&
unitDimsFilter.count(expr.index()))
unitDims.insert(expr.index());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index d5b8c6c16c8589a..f0393e44fc00c27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -122,7 +122,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
auto addToCoveredDims = [&](AffineMap map) {
for (auto result : map.getResults())
- if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
coveredDims[dimExpr.getPosition()] = true;
};
@@ -587,7 +587,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
expandedShapeMap.resize(fusedIndexMap.getNumDims());
for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
- unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
+ unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
ArrayRef<int64_t> shape =
@@ -645,7 +645,7 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
SmallVector<AffineExpr> newExprs;
for (AffineExpr expr : indexingMap.getResults()) {
- unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ unsigned pos = cast<AffineDimExpr>(expr).getPosition();
SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
return builder.getAffineDimExpr(static_cast<unsigned>(v));
@@ -664,7 +664,7 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
const ExpansionInfo &expansionInfo) {
SmallVector<int64_t> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
- unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+ unsigned dim = cast<AffineDimExpr>(expr).getPosition();
auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
@@ -683,7 +683,7 @@ getReassociationForExpansion(AffineMap indexingMap,
SmallVector<ReassociationIndices> reassociation;
unsigned numReshapeDims = 0;
for (AffineExpr expr : indexingMap.getResults()) {
- unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+ unsigned dim = cast<AffineDimExpr>(expr).getPosition();
auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
@@ -1002,9 +1002,7 @@ getDomainReassociation(AffineMap indexingMap,
ReassociationIndices domainReassociation = llvm::to_vector<4>(
llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
- return indexingMap.getResults()[pos]
- .cast<AffineDimExpr>()
- .getPosition();
+ return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
}));
// The projected permutation semantics ensures that there is no repetition of
// the domain indices.
@@ -1026,7 +1024,7 @@ bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
unsigned dimSequenceStart = dimSequence[0];
for (const auto &expr : enumerate(indexingMap.getResults())) {
- unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
+ unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
// 1. Check if this is the start of the sequence.
if (dimInMapStart == dimSequenceStart) {
if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
@@ -1034,8 +1032,8 @@ bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
// 1a. Check if sequence is preserved.
for (const auto &dimInSequence : enumerate(dimSequence)) {
unsigned dimInMap =
- indexingMap.getResult(expr.index() + dimInSequence.index())
- .cast<AffineDimExpr>()
+ cast<AffineDimExpr>(
+ indexingMap.getResult(expr.index() + dimInSequence.index()))
.getPosition();
if (dimInMap != dimInSequence.value())
return false;
@@ -1330,7 +1328,7 @@ getCollapsedOpIndexingMap(AffineMap indexingMap,
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
for (auto expr : indexingMap.getResults()) {
- unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+ unsigned dim = cast<AffineDimExpr>(expr).getPosition();
// If the dim is not the first of the collapsed dim, do nothing.
if (origOpToCollapsedOpMapping[dim].second != 0)
continue;
@@ -1356,7 +1354,7 @@ getOperandReassociation(AffineMap indexingMap,
collapsingInfo.getCollapsedOpToOrigOpMapping();
while (counter < indexingMap.getNumResults()) {
unsigned dim =
- indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
+ cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
// This is the start of the collapsed dimensions of the iteration that
// is guaranteed to be preserved in the indexing map. The number of folded
// dims is obtained from the collapsed op to original op mapping.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d83ec725e082092..11bd886c36e5379 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -87,10 +87,10 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
<< "getShapeDefiningLoopRange map: " << map << "\n");
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
for (const auto &en : llvm::enumerate(map.getResults())) {
- auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
+ auto dimExpr = dyn_cast<AffineDimExpr>(en.value());
if (!dimExpr)
continue;
- if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
+ if (loopDepth == cast<AffineDimExpr>(en.value()).getPosition()) {
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
@@ -196,7 +196,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
DenseMap<unsigned, Range> fusedLoopsAndRanges;
Value shapedOperand = consumerOpOperand.get();
for (const auto &en : llvm::enumerate(producerMap.getResults())) {
- unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+ unsigned posInProducerLoop = cast<AffineDimExpr>(en.value()).getPosition();
fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 79e295b937b9374..5a56e914ea4c77c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -298,13 +298,13 @@ struct FoldAffineOp : public RewritePattern {
AffineExpr expr = map.getResult(0);
if (map.getNumInputs() == 0) {
- if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
return success();
}
return failure();
}
- if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
+ if (dyn_cast<AffineDimExpr>(expr) || dyn_cast<AffineSymbolExpr>(expr)) {
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 472e6fa3ab27b22..7f3ab1f1a24b2f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -553,7 +553,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
if (!options.interchangeVector.empty()) {
for (AffineExpr result : invPermutationMap.getResults())
interchangedIvs.push_back(
- ivs[result.cast<AffineDimExpr>().getPosition()]);
+ ivs[cast<AffineDimExpr>(result).getPosition()]);
} else {
interchangedIvs.assign(ivs.begin(), ivs.end());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5f566d8b10aef73..cae7b50b0fb3b47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -192,7 +192,7 @@ struct LinalgOpTilingInterface
}
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition =
- resultExpr.value().template cast<AffineDimExpr>().getPosition();
+ cast<AffineDimExpr>(resultExpr.value()).getPosition();
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bca343cf8777149..10dfbe6cec781d5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -176,8 +176,7 @@ packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
}
// We can only pack AffineDimExpr atm.
- if (!map.getResult(maybeOperandDimensionToPack.value())
- .isa<AffineDimExpr>())
+ if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
return failure();
// Add `newDim` to the results of the map.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b8d82159856825f..f9a53a8451a6016 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1297,7 +1297,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<int64_t> zeroPos;
auto results = indexingMap.getResults();
for (const auto &result : llvm::enumerate(results)) {
- if (result.value().isa<AffineConstantExpr>()) {
+ if (isa<AffineConstantExpr>(result.value())) {
zeroPos.push_back(result.index());
}
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index f177235acff7238..75c8cd3e1d95a10 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -65,7 +65,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
visit(expr.getLHS());
visit(expr.getRHS());
if (expr.getKind() == mlir::AffineExprKind::Mul)
- assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
+ assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
"nonpositive multiplying coefficient");
}
bool isTiled = false;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..484fe92d682f753 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3188,7 +3188,7 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
SmallVector<int64_t> sizes(rank, 0);
SmallVector<int64_t> strides(rank, 1);
for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
- unsigned position = en.value().cast<AffineDimExpr>().getPosition();
+ unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
sizes[en.index()] = originalSizes[position];
strides[en.index()] = originalStrides[position];
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 3e0df660d5c46d3..902ad8fc19c5d88 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -34,7 +34,7 @@ checkOperandAffineExprRecursively(AffineExpr expr,
SmallVectorImpl<bool> &seenIds) {
switch (expr.getKind()) {
case AffineExprKind::Add: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binOpExpr.getLHS();
AffineExpr rhs = binOpExpr.getRHS();
if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
@@ -44,7 +44,7 @@ checkOperandAffineExprRecursively(AffineExpr expr,
return success();
}
case AffineExprKind::Mul: {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binOpExpr.getLHS();
AffineExpr rhs = binOpExpr.getRHS();
AffineExpr dimExpr;
@@ -56,14 +56,14 @@ checkOperandAffineExprRecursively(AffineExpr expr,
dimExpr = rhs;
} else
return failure();
- unsigned position = dimExpr.cast<AffineDimExpr>().getPosition();
+ unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
if ((size_t)position >= seenIds.size() || seenIds[position])
return failure();
seenIds[position] = true;
return success();
}
case AffineExprKind::DimId: {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
+ unsigned position = cast<AffineDimExpr>(expr).getPosition();
if ((size_t)position >= seenIds.size() || seenIds[position])
return failure();
seenIds[position] = true;
@@ -280,7 +280,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
- auto dim = expr.cast<AffineDimExpr>();
+ auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(),
@@ -416,7 +416,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
AffineExpr expr = it.value();
// `expr` must be an `AffineDimExpr` because `map` is verified by
// isProjectedPermutation
- auto dim = expr.cast<AffineDimExpr>();
+ auto dim = cast<AffineDimExpr>(expr);
unsigned loopIdx = dim.getPosition();
if (loopIdx < shardingOption.shardingArray.size())
splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 88c6f3da656f3ba..cb36e0cecf0d24e 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,11 +75,11 @@ struct ForOpInterface
// Check if computed bound equals the corresponding iter_arg.
Value singleValue = nullptr;
std::optional<int64_t> singleDim;
- if (auto dimExpr = bound.getResult(0).dyn_cast<AffineDimExpr>()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
int64_t idx = dimExpr.getPosition();
singleValue = boundOperands[idx].first;
singleDim = boundOperands[idx].second;
- } else if (auto symExpr = bound.getResult(0).dyn_cast<AffineSymbolExpr>()) {
+ } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
int64_t idx = symExpr.getPosition() + bound.getNumDims();
singleValue = boundOperands[idx].first;
singleDim = boundOperands[idx].second;
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index f208e5245977d83..23646f42eb5fe3d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -51,7 +51,7 @@ static void specializeParallelLoopForUnrolling(ParallelOp op) {
return;
int64_t minConstant = std::numeric_limits<int64_t>::max();
for (AffineExpr expr : minOp.getMap().getResults()) {
- if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
+ if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
minConstant = std::min(minConstant, constantIndex.getValue());
}
if (minConstant == std::numeric_limits<int64_t>::max())
@@ -87,7 +87,7 @@ static void specializeForLoopForUnrolling(ForOp op) {
return;
int64_t minConstant = std::numeric_limits<int64_t>::max();
for (AffineExpr expr : minOp.getMap().getResults()) {
- if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
+ if (auto constantIndex = dyn_cast<AffineConstantExpr>(expr))
minConstant = std::min(minConstant, constantIndex.getValue());
}
if (minConstant == std::numeric_limits<int64_t>::max())
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 6a81a11a932f94a..9757a599bd1eb60 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -24,61 +24,61 @@ Var DimLvlExpr::castAnyVar() const {
}
std::optional<Var> DimLvlExpr::dyn_castAnyVar() const {
- if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
+ if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
return SymVar(s);
- if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
+ if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
return Var(getAllowedVarKind(), x);
return std::nullopt;
}
SymVar DimLvlExpr::castSymVar() const {
- return SymVar(expr.cast<AffineSymbolExpr>());
+ return SymVar(llvm::cast<AffineSymbolExpr>(expr));
}
std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
- if (const auto s = expr.dyn_cast_or_null<AffineSymbolExpr>())
+ if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
return SymVar(s);
return std::nullopt;
}
Var DimLvlExpr::castDimLvlVar() const {
- return Var(getAllowedVarKind(), expr.cast<AffineDimExpr>());
+ return Var(getAllowedVarKind(), llvm::cast<AffineDimExpr>(expr));
}
std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
- if (const auto x = expr.dyn_cast_or_null<AffineDimExpr>())
+ if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
return Var(getAllowedVarKind(), x);
return std::nullopt;
}
int64_t DimLvlExpr::castConstantValue() const {
- return expr.cast<AffineConstantExpr>().getValue();
+ return llvm::cast<AffineConstantExpr>(expr).getValue();
}
std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
- const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
+ const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
return k ? std::make_optional(k.getValue()) : std::nullopt;
}
bool DimLvlExpr::hasConstantValue(int64_t val) const {
- const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
+ const auto k = dyn_cast_or_null<AffineConstantExpr>(expr);
return k && k.getValue() == val;
}
DimLvlExpr DimLvlExpr::getLHS() const {
- const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>();
+ const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
return DimLvlExpr(kind, binop ? binop.getLHS() : nullptr);
}
DimLvlExpr DimLvlExpr::getRHS() const {
- const auto binop = expr.dyn_cast_or_null<AffineBinaryOpExpr>();
+ const auto binop = dyn_cast_or_null<AffineBinaryOpExpr>(expr);
return DimLvlExpr(kind, binop ? binop.getRHS() : nullptr);
}
std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
DimLvlExpr::unpackBinop() const {
const auto ak = getAffineKind();
- const auto binop = expr.dyn_cast<AffineBinaryOpExpr>();
+ const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
return {lhs, ak, rhs};
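
The sparse-tensor hunks that follow use the same free-function casts to pull dimension positions and constants out of block-sparsity expressions such as d0 floordiv 4 and d0 mod 4. A small illustrative sketch of that decomposition (hypothetical helper, not code from the patch):

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/Support/raw_ostream.h"

  // Decompose `d floordiv c` into its dimension position and constant,
  // mirroring the cast pattern the hunks below switch to.
  static void printFloorDiv(mlir::AffineExpr e) {
    if (e.getKind() != mlir::AffineExprKind::FloorDiv)
      return;
    auto bin = llvm::cast<mlir::AffineBinaryOpExpr>(e);
    auto dim = llvm::dyn_cast<mlir::AffineDimExpr>(bin.getLHS());
    auto cst = llvm::dyn_cast<mlir::AffineConstantExpr>(bin.getRHS());
    if (dim && cst)
      llvm::outs() << "d" << dim.getPosition() << " floordiv "
                   << cst.getValue() << "\n";
  }

  int main() {
    mlir::MLIRContext ctx;
    mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
    printFloorDiv(d0.floorDiv(4)); // prints: d0 floordiv 4
    return 0;
  }
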
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index aedac67865aacec..92bf8ec6468e532 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -438,14 +438,15 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
// Do constant propagation on the affine map.
AffineExpr evalExp =
simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
- if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
+  // Use the llvm namespace here to avoid ambiguity.
+ if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
ret.push_back(c.getValue() + 1);
} else {
- if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
+ if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
mod && mod.getKind() == AffineExprKind::Mod) {
// We can still infer a static bound for expressions of the form
// "d % constant" since d % constant \in [0, constant).
- if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
ret.push_back(bound.getValue());
continue;
}
@@ -765,10 +766,10 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
for (unsigned i = 0, n = numLvls; i < n; i++) {
auto result = dimToLvl.getResult(i);
- if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
if (result.getKind() == AffineExprKind::FloorDiv) {
// Position of the dimension in dimToLvl.
- auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+ auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
"expected only one floordiv for each dimension");
SmallVector<AffineExpr, 3> components;
@@ -779,7 +780,7 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
// Map key is the position of the dimension.
lvlExprComponents[pos] = components;
} else if (result.getKind() == AffineExprKind::Mod) {
- auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+ auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
"expected floordiv before mod");
// Add level variable for mod to the same vector
@@ -813,10 +814,10 @@ SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
"expected dimToLvl to be block sparsity for calling getBlockSize");
SmallVector<unsigned> blockSize;
for (auto result : dimToLvl.getResults()) {
- if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
if (result.getKind() == AffineExprKind::Mod) {
blockSize.push_back(
- binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue());
+ dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
}
} else {
blockSize.push_back(0);
@@ -830,20 +831,20 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
return false;
std::map<unsigned, int64_t> coeffientMap;
for (auto result : dimToLvl.getResults()) {
- if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
- auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
+ auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
if (result.getKind() == AffineExprKind::FloorDiv) {
// Expect only one floordiv for each dimension.
if (coeffientMap.find(pos) != coeffientMap.end())
return false;
coeffientMap[pos] =
- binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue();
+ dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue();
} else if (result.getKind() == AffineExprKind::Mod) {
// Expect floordiv before mod.
if (coeffientMap.find(pos) == coeffientMap.end())
return false;
// Expect mod to have the same coefficient as floordiv.
- if (binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue() !=
+ if (dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue() !=
coeffientMap[pos]) {
return false;
}
@@ -1197,7 +1198,7 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
? getEncoder().getDimToLvl()
: getEncoder().getLvlToDim();
for (AffineExpr exp : perm.getResults())
- results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+ results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d5c9ee41215ae97..bcb923d94a34818 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -712,19 +712,19 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
uint64_t cf = 0, cm = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId: {
- d = exp.cast<AffineDimExpr>().getPosition();
+ d = cast<AffineDimExpr>(exp).getPosition();
break;
}
case AffineExprKind::FloorDiv: {
- auto floor = exp.cast<AffineBinaryOpExpr>();
- d = floor.getLHS().cast<AffineDimExpr>().getPosition();
- cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
+ auto floor = cast<AffineBinaryOpExpr>(exp);
+ d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
+ cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
break;
}
case AffineExprKind::Mod: {
- auto mod = exp.cast<AffineBinaryOpExpr>();
- d = mod.getLHS().cast<AffineDimExpr>().getPosition();
- cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
+ auto mod = cast<AffineBinaryOpExpr>(exp);
+ d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
+ cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
break;
}
default:
@@ -760,17 +760,17 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
uint64_t c = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId: {
- l = exp.cast<AffineDimExpr>().getPosition();
+ l = cast<AffineDimExpr>(exp).getPosition();
break;
}
case AffineExprKind::Add: {
// Always mul on lhs, symbol/constant on rhs.
- auto add = exp.cast<AffineBinaryOpExpr>();
+ auto add = cast<AffineBinaryOpExpr>(exp);
assert(add.getLHS().getKind() == AffineExprKind::Mul);
- auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
- ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
- c = mul.getRHS().cast<AffineConstantExpr>().getValue();
- l = add.getRHS().cast<AffineDimExpr>().getPosition();
+ auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
+ ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
+ c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
+ l = cast<AffineDimExpr>(add.getRHS()).getPosition();
break;
}
default:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb3c6fb56f692d9..bb8ccc336b31059 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -609,22 +609,22 @@ Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
// level-expression, the `getPosition` must in fact be a `Dimension`.
// However, elsewhere we have been led to expect that `loopIdToOrd`
// should be indexed by `LoopId`...
- const auto loopId = a.cast<AffineDimExpr>().getPosition();
+ const auto loopId = cast<AffineDimExpr>(a).getPosition();
assert(loopId < loopIdToOrd.size());
return loopStack[loopIdToOrd[loopId]].iv;
}
case AffineExprKind::Add: {
- auto binOp = a.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(a);
return ADDI(genAffine(builder, loc, binOp.getLHS()),
genAffine(builder, loc, binOp.getRHS()));
}
case AffineExprKind::Mul: {
- auto binOp = a.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(a);
return MULI(genAffine(builder, loc, binOp.getLHS()),
genAffine(builder, loc, binOp.getRHS()));
}
case AffineExprKind::Constant: {
- int64_t c = a.cast<AffineConstantExpr>().getValue();
+ int64_t c = cast<AffineConstantExpr>(a).getValue();
return C_IDX(c);
}
default:
@@ -1157,7 +1157,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
OpBuilder &builder, Location loc, TensorId tid, Level lvl,
AffineExpr affine, MutableArrayRef<Value> reduc) {
assert(isValidLevel(tid, lvl));
- assert(!affine.isa<AffineDimExpr>() && !isDenseDLT(lvlTypes[tid][lvl]));
+ assert(!isa<AffineDimExpr>(affine) && !isDenseDLT(lvlTypes[tid][lvl]));
// We can not re-enter the same level.
assert(!coords[tid][lvl]);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index ced7983af324c1f..463a49f52283a73 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -56,7 +56,7 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
uint64_t ny, ValueRange operands) {
nameOstream << namePrefix;
for (auto res : xPerm.getResults())
- nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
+ nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
nameOstream << "_coo_" << ny;
@@ -114,7 +114,7 @@ static void forEachIJPairInXs(
Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
- unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+ unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
Value ak = constantIndex(builder, loc, actualK);
Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 307a609fd1b7746..029af87136d4422 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -33,7 +33,7 @@ static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
assert(lvl2dim.getNumInputs() == lvlRank);
SmallVector<AffineExpr> exps;
for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
- unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
+ unsigned pos = cast<AffineDimExpr>(map.getResult(i)).getPosition();
exps.push_back(lvl2dim.getResult(pos));
}
return AffineMap::get(lvlRank, 0, exps, builder.getContext());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 85d6a6ddabf9eb6..4a171f6cc816e1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -101,7 +101,7 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
}
/// Get the desired AffineDimExpr.
- AffineDimExpr getDimExpr() const { return pickedDim.cast<AffineDimExpr>(); }
+ AffineDimExpr getDimExpr() const { return cast<AffineDimExpr>(pickedDim); }
private:
/// The picked AffineDimExpr after visit. This must be stored as
@@ -137,7 +137,7 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
LoopId ldx, bool &isAtLoop) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- const LoopId i = a.cast<AffineDimExpr>().getPosition();
+ const LoopId i = cast<AffineDimExpr>(a).getPosition();
if (i == ldx) {
isAtLoop = true;
// Must be invariant if we are at the given loop.
@@ -153,12 +153,12 @@ static bool isInvariantAffine(AffineExpr a, ArrayRef<LoopId> loopStack,
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
- auto binOp = a.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(a);
return isInvariantAffine(binOp.getLHS(), loopStack, ldx, isAtLoop) &&
isInvariantAffine(binOp.getRHS(), loopStack, ldx, isAtLoop);
}
default: {
- assert(a.isa<AffineConstantExpr>());
+ assert(isa<AffineConstantExpr>(a));
return true;
}
}
@@ -197,7 +197,7 @@ static AffineMap permute(CodegenEnv &env, AffineMap m) {
const unsigned preSize = perm.size();
for (unsigned dim : worklist.set_bits()) {
bool isAtLoop = false;
- if (m.getResult(dim).isa<AffineConstantExpr>() ||
+ if (isa<AffineConstantExpr>(m.getResult(dim)) ||
(isInvariantAffine(m.getResult(dim), env.getLoopStackUpTo(loopDepth),
env.topSortAt(loopDepth - 1), isAtLoop) &&
isAtLoop)) {
@@ -231,7 +231,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
bool setLvlFormat = true) {
switch (a.getKind()) {
case AffineExprKind::DimId: {
- const LoopId idx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
+ const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
if (!isUndefDLT(merger.getLvlType(tid, idx)))
return false; // used more than once
@@ -249,7 +249,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
++filterLdx;
}
- if (auto binOp = a.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
// We do not set dim level format for affine expressions like d0 + d1 on
// either loop index at d0 or d1.
// We continue the recursion merely to check whether current affine is
@@ -290,7 +290,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
if (coefficient <= 0)
return false;
- const LoopId ldx = merger.makeLoopId(a.cast<AffineDimExpr>().getPosition());
+ const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
if (!isUndefDLT(merger.getLvlType(tensor, ldx)))
return false; // used more than once, e.g., A[i][i]
@@ -329,17 +329,17 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
// complicated cases like `2 * d0 + d1`.
if (!isSubExp)
return false;
- auto binOp = a.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(a);
auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
- if (rhs.isa<AffineConstantExpr>())
+ if (isa<AffineConstantExpr>(rhs))
std::swap(lhs, rhs);
// Must be of the form `constant * d`.
- assert(lhs.isa<AffineConstantExpr>() && rhs.isa<AffineDimExpr>());
- int64_t coefficient = lhs.cast<AffineConstantExpr>().getValue();
+ assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
+ int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
return findDepIdxSet(merger, tensor, lvl, rhs, dlt, isSubExp, coefficient);
}
case AffineExprKind::Add: {
- auto binOp = a.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(a);
return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), dlt, true) &&
findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), dlt, true);
}
@@ -391,7 +391,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
for (Level l = 0; l < lvlRank; l++) {
// FIXME: `toOrigDim` is deprecated.
const Dimension d = toOrigDim(stt.getEncoding(), l);
- if (!exprs[d].isa<AffineDimExpr>() && !stt.isDenseLvl(l))
+ if (!isa<AffineDimExpr>(exprs[d]) && !stt.isDenseLvl(l))
num++;
}
return num;
@@ -567,7 +567,7 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
switch (toExpand.getKind()) {
case AffineExprKind::DimId: {
const std::optional<LoopId> idx{
- toExpand.cast<AffineDimExpr>().getPosition()};
+ cast<AffineDimExpr>(toExpand).getPosition()};
if (toExpand == a)
addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
else // toExpand == b
@@ -576,7 +576,7 @@ static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
- auto binOp = toExpand.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(toExpand);
if (toExpand == a) {
addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx, tidx);
addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx, tidx);
@@ -606,7 +606,7 @@ static void tryRelaxAffineConstraints(linalg::GenericOp op,
// require both d0 < d2 and d1 < d2 to ensure correct ordering (i.e.,
// no ordering like d0->d2->d1).
// TODO: this is obviously a suboptimal solution.
- if (!fldx && !fa.isa<AffineConstantExpr>()) {
+ if (!fldx && !isa<AffineConstantExpr>(fa)) {
// Heuristic: we prefer parallel loop for lhs to reduce the chance
// we add reduce < parallel ordering.
finder.setPickedIterType(utils::IteratorType::parallel);
@@ -614,7 +614,7 @@ static void tryRelaxAffineConstraints(linalg::GenericOp op,
fa = finder.getDimExpr();
fldx = finder.getDimExpr().getPosition();
}
- if (!ta.isa<AffineConstantExpr>()) {
+ if (!isa<AffineConstantExpr>(ta)) {
// Heuristic: we prefer a reduction loop for rhs to reduce the chance of
// adding reduce < parallel ordering.
finder.setPickedIterType(utils::IteratorType::reduction);
@@ -647,7 +647,7 @@ static void addFilterLoopBasedConstraints(CodegenEnv &env, OpOperand &t,
// Filter loops should be constructed after all the dependent loops,
// i.e., d0 + d1 < filter_loop(d0 + d1)
if (tldx && env.merger().isFilterLoop(*tldx)) {
- assert(!ta.isa<AffineDimExpr>() && !isDenseDLT(enc.getLvlTypes()[lvl]));
+ assert(!isa<AffineDimExpr>(ta) && !isDenseDLT(enc.getLvlTypes()[lvl]));
addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx);
// Now that the ordering of affine expression is captured by filter
// loop idx, we only need to ensure the affine ordering against filter
@@ -716,7 +716,7 @@ static void addSliceBasedConstraints(CodegenEnv &env, OpOperand &t,
const AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1));
const AffineExpr ta = map.getResult(toOrigDim(enc, lvl));
- if (fa.isa<AffineDimExpr>() || ta.isa<AffineDimExpr>()) {
+ if (isa<AffineDimExpr>(fa) || isa<AffineDimExpr>(ta)) {
AffineDimCollector fCollector;
fCollector.walkPostOrder(fa);
@@ -924,7 +924,7 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
// but this is assuming there are in fact `dimRank` many results instead.
const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1));
assert(a.getKind() == AffineExprKind::DimId);
- const LoopId idx = env.makeLoopId(a.cast<AffineDimExpr>().getPosition());
+ const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
return env.getLoopVar(idx);
}
@@ -1631,7 +1631,7 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
for (Level l = startLvl; l < lvlRank; l++) {
// FIXME: `toOrigDim` is deprecated.
AffineExpr lvlExpr = lvlExprs[toOrigDim(enc, l)];
- if (enc.isDenseLvl(l) && lvlExpr.isa<AffineConstantExpr>())
+ if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
env.emitter().genDenseAffineAddress(
builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
else
@@ -1722,11 +1722,11 @@ static bool translateBitsToTidLvlPairs(
AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
// Skip simple affine expressions and non-dense levels (which
// have their own filter loop).
- if (exp.isa<AffineDimExpr>() || !stt.isDenseLvl(l))
+ if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
continue;
// Constant affine expressions are handled in genLoop
- if (!exp.isa<AffineConstantExpr>()) {
+ if (!isa<AffineConstantExpr>(exp)) {
bool isAtLoop = false;
if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) {
// If the compound affine is invariant and we are right at the
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index c1358e18a5b2306..7ff435a033985cc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -24,9 +24,9 @@ getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
for (const auto &map : enumerate(reassociation)) {
unsigned startPos =
- map.value().getResults().front().cast<AffineDimExpr>().getPosition();
+ cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
unsigned endPos =
- map.value().getResults().back().cast<AffineDimExpr>().getPosition();
+ cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
expandedDimToCollapsedDim[dim] = map.index();
}
@@ -47,8 +47,8 @@ static OpFoldResult getCollapsedOutputDimFromInputShape(
}
AffineMap map = reassociationMap[dimIndex];
unsigned startPos =
- map.getResults().front().cast<AffineDimExpr>().getPosition();
- unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
+ cast<AffineDimExpr>(map.getResults().front()).getPosition();
+ unsigned endPos = cast<AffineDimExpr>(map.getResults().back()).getPosition();
AffineExpr expr;
SmallVector<OpFoldResult> dynamicDims;
for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
@@ -87,16 +87,12 @@ static OpFoldResult getExpandedOutputDimFromInputShape(
return builder.getIndexAttr(dstStaticShape[dimIndex]);
}
unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
- unsigned startPos = reassociation[sourceDimPos]
- .getResults()
- .front()
- .cast<AffineDimExpr>()
- .getPosition();
- unsigned endPos = reassociation[sourceDimPos]
- .getResults()
- .back()
- .cast<AffineDimExpr>()
- .getPosition();
+ unsigned startPos =
+ cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
+ .getPosition();
+ unsigned endPos =
+ cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
+ .getPosition();
int64_t linearizedStaticDim = 1;
for (auto d :
llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 853889269d0fbca..41c7af4593c77ce 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -149,7 +149,7 @@ unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
for (const auto &exprs : exprArrays) {
for (auto expr : exprs) {
expr.walk([&pos](AffineExpr e) {
- if (auto d = e.dyn_cast<AffineExprTy>())
+ if (auto d = dyn_cast<AffineExprTy>(e))
pos = std::max(pos, d.getPosition());
});
}
@@ -174,7 +174,7 @@ SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
ReassociationIndices indices;
indices.reserve(exprs.size());
for (const auto &expr : exprs)
- indices.push_back(expr.cast<AffineDimExpr>().getPosition());
+ indices.push_back(cast<AffineDimExpr>(expr).getPosition());
reassociationIndices.push_back(indices);
}
return reassociationIndices;
@@ -208,7 +208,7 @@ bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
return false;
}
for (auto e : m.getResults()) {
- auto d = e.dyn_cast<AffineDimExpr>();
+ auto d = dyn_cast<AffineDimExpr>(e);
if (!d || d.getPosition() != nextExpectedDim++) {
if (invalidIndex)
*invalidIndex = it.index();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 69cbdcd3f536f98..ad241f4c48ef284 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -888,14 +888,14 @@ static LogicalResult verifyOutputShape(
/*symCount=*/0, extents, ctx);
// Compose the resMap with the extentsMap, which is a constant map.
AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
- assert(llvm::all_of(
- expectedMap.getResults(),
- [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
- "expected constant extent along all dimensions.");
+ assert(
+ llvm::all_of(expectedMap.getResults(),
+ [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
+ "expected constant extent along all dimensions.");
// Extract the expected shape and build the type.
auto expectedShape = llvm::to_vector<4>(
llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
- return e.cast<AffineConstantExpr>().getValue();
+ return cast<AffineConstantExpr>(e).getValue();
}));
auto expected =
VectorType::get(expectedShape, resVectorType.getElementType(),
@@ -1076,7 +1076,7 @@ void ContractionOp::getIterationIndexMap(
auto index = it.index();
auto map = it.value();
for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
- auto dim = map.getResult(i).cast<AffineDimExpr>();
+ auto dim = cast<AffineDimExpr>(map.getResult(i));
iterationIndexMap[index][dim.getPosition()] = i;
}
}
@@ -3626,8 +3626,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
for (auto expr : permutationMap.getResults()) {
- auto dim = expr.dyn_cast<AffineDimExpr>();
- auto zero = expr.dyn_cast<AffineConstantExpr>();
+ auto dim = dyn_cast<AffineDimExpr>(expr);
+ auto zero = dyn_cast<AffineConstantExpr>(expr);
if (zero) {
if (zero.getValue() != 0) {
return emitOpError(
@@ -3728,7 +3728,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
<< AffineMapAttr::get(permutationMap)
<< " vs inBounds of size: " << inBounds.size();
for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
- if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
+ if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
!llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
return op->emitOpError("requires broadcast dimensions to be in-bounds");
}
@@ -3920,7 +3920,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
}
// Currently out-of-bounds, check whether we can statically determine it is
// inBounds.
- auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
+ auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
assert(dimExpr && "Broadcast dims must be in-bounds");
auto inBounds =
isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 6df85978a7e8601..4a5e8fcfb6edaf5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -189,7 +189,7 @@ struct TransferWritePermutationLowering
SmallVector<int64_t> indices;
llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
[](AffineExpr expr) {
- return expr.dyn_cast<AffineDimExpr>().getPosition();
+ return dyn_cast<AffineDimExpr>(expr).getPosition();
});
// Transpose in_bounds attribute.
@@ -248,7 +248,7 @@ struct TransferWriteNonPermutationLowering
// dimension then deduce the missing inner dimensions.
SmallVector<bool> foundDim(map.getNumDims(), false);
for (AffineExpr exp : map.getResults())
- foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
+ foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
SmallVector<AffineExpr> exprs;
bool foundFirstDim = false;
SmallVector<int64_t> missingInnerDim;
@@ -308,7 +308,7 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
AffineMap map = op.getPermutationMap();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
- auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
+ auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
if (!dimExpr || dimExpr.getValue() != 0)
break;
numLeadingBroadcast++;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..b85797d789b9939 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -97,7 +97,7 @@ struct DistributedLoadStoreHelper {
SmallVector<Value> indices(rank, zero);
if (val == distributedVal) {
for (auto dimExpr : distributionMap.getResults()) {
- int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+ int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
indices[index] = buildDistributedOffset(b, loc, index);
}
}
@@ -142,7 +142,7 @@ struct DistributedLoadStoreHelper {
SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
if (type == distributedVectorType) {
for (auto dimExpr : distributionMap.getResults()) {
- int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+ int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
indices[index] = buildDistributedOffset(b, loc, index);
}
}
@@ -530,11 +530,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(newWarpOp.getContext(), d0, d1);
- auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
- unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+ unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
auto scale =
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
indices[indexPos] = affine::makeComposedAffineApply(
@@ -834,11 +834,11 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
- auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
- unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+ unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
int64_t scale = distributedType.getDimSize(vectorPos);
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 6a45231eb80bcea..4cfac7de29ee76f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -37,7 +37,7 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
auto isBroadcast = [](AffineExpr expr) {
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
return constExpr.getValue() == 0;
return false;
};
@@ -46,7 +46,7 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
if (isBroadcast(dim.value()))
continue;
- unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
+ unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
auto expr = getAffineDimExpr(0, builder.getContext()) +
getAffineConstantExpr(elementOffsets[dim.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
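
The AffineExpr.cpp hunks below differ slightly from the rest of the patch: the casts are applied to *this inside AffineExpr's own member functions and are written as llvm::cast<...>(*this), presumably so the call does not resolve to the deprecated member templates still declared on AffineExpr. A hedged sketch of the same qualified-call style outside the class (hypothetical ExprView wrapper, not from the patch):

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/MLIRContext.h"

  // Hypothetical wrapper, used only to illustrate the explicitly qualified
  // llvm::cast call style adopted in AffineExpr.cpp.
  struct ExprView {
    mlir::AffineExpr expr;
    unsigned dimPosition() const {
      return llvm::cast<mlir::AffineDimExpr>(expr).getPosition();
    }
  };

  int main() {
    mlir::MLIRContext ctx;
    ExprView view{mlir::getAffineDimExpr(3, &ctx)};
    return view.dimPosition() == 3 ? 0 : 1;
  }
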
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index a2850b2f6d34270..c9be1af034f8e3f 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -69,13 +69,13 @@ AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
case AffineExprKind::Constant:
return *this;
case AffineExprKind::DimId: {
- unsigned dimId = cast<AffineDimExpr>().getPosition();
+ unsigned dimId = llvm::cast<AffineDimExpr>(*this).getPosition();
if (dimId >= dimReplacements.size())
return *this;
return dimReplacements[dimId];
}
case AffineExprKind::SymbolId: {
- unsigned symId = cast<AffineSymbolExpr>().getPosition();
+ unsigned symId = llvm::cast<AffineSymbolExpr>(*this).getPosition();
if (symId >= symReplacements.size())
return *this;
return symReplacements[symId];
@@ -85,7 +85,7 @@ AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod:
- auto binOp = cast<AffineBinaryOpExpr>();
+ auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
@@ -143,7 +143,7 @@ AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod:
- auto binOp = cast<AffineBinaryOpExpr>();
+ auto binOp = llvm::cast<AffineBinaryOpExpr>(*this);
auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
auto newLHS = lhs.replace(map);
auto newRHS = rhs.replace(map);
@@ -176,7 +176,7 @@ bool AffineExpr::isSymbolicOrConstant() const {
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
- auto expr = this->cast<AffineBinaryOpExpr>();
+ auto expr = llvm::cast<AffineBinaryOpExpr>(*this);
return expr.getLHS().isSymbolicOrConstant() &&
expr.getRHS().isSymbolicOrConstant();
}
@@ -193,24 +193,24 @@ bool AffineExpr::isPureAffine() const {
case AffineExprKind::Constant:
return true;
case AffineExprKind::Add: {
- auto op = cast<AffineBinaryOpExpr>();
+ auto op = llvm::cast<AffineBinaryOpExpr>(*this);
return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
}
case AffineExprKind::Mul: {
// TODO: Canonicalize the constants in binary operators to the RHS when
// possible, allowing this to merge into the next case.
- auto op = cast<AffineBinaryOpExpr>();
+ auto op = llvm::cast<AffineBinaryOpExpr>(*this);
return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
- (op.getLHS().template isa<AffineConstantExpr>() ||
- op.getRHS().template isa<AffineConstantExpr>());
+ (llvm::isa<AffineConstantExpr>(op.getLHS()) ||
+ llvm::isa<AffineConstantExpr>(op.getRHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
- auto op = cast<AffineBinaryOpExpr>();
+ auto op = llvm::cast<AffineBinaryOpExpr>(*this);
return op.getLHS().isPureAffine() &&
- op.getRHS().template isa<AffineConstantExpr>();
+ llvm::isa<AffineConstantExpr>(op.getRHS());
}
}
llvm_unreachable("Unknown AffineExpr");
@@ -229,8 +229,8 @@ int64_t AffineExpr::getLargestKnownDivisor() const {
case AffineExprKind::FloorDiv: {
// If the RHS is a constant and divides the known divisor on the LHS, the
// quotient is a known divisor of the expression.
- binExpr = this->cast<AffineBinaryOpExpr>();
- auto rhs = binExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
+ auto rhs = llvm::dyn_cast<AffineConstantExpr>(binExpr.getRHS());
// Leave alone undefined expressions.
if (rhs && rhs.getValue() != 0) {
int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
@@ -240,16 +240,16 @@ int64_t AffineExpr::getLargestKnownDivisor() const {
return 1;
}
case AffineExprKind::Constant:
- return std::abs(this->cast<AffineConstantExpr>().getValue());
+ return std::abs(llvm::cast<AffineConstantExpr>(*this).getValue());
case AffineExprKind::Mul: {
- binExpr = this->cast<AffineBinaryOpExpr>();
+ binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
return binExpr.getLHS().getLargestKnownDivisor() *
binExpr.getRHS().getLargestKnownDivisor();
}
case AffineExprKind::Add:
[[fallthrough]];
case AffineExprKind::Mod: {
- binExpr = cast<AffineBinaryOpExpr>();
+ binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
(uint64_t)binExpr.getRHS().getLargestKnownDivisor());
}
@@ -266,9 +266,9 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
case AffineExprKind::DimId:
return factor * factor == 1;
case AffineExprKind::Constant:
- return cast<AffineConstantExpr>().getValue() % factor == 0;
+ return llvm::cast<AffineConstantExpr>(*this).getValue() % factor == 0;
case AffineExprKind::Mul: {
- binExpr = cast<AffineBinaryOpExpr>();
+ binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
// It's probably not worth optimizing this further (to not traverse the
// whole sub-tree under it - that would require a version of isMultipleOf
// that on a 'false' return also returns the largest known divisor).
@@ -280,7 +280,7 @@ bool AffineExpr::isMultipleOf(int64_t factor) const {
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
- binExpr = cast<AffineBinaryOpExpr>();
+ binExpr = llvm::cast<AffineBinaryOpExpr>(*this);
return std::gcd((uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
(uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
factor ==
@@ -294,7 +294,7 @@ bool AffineExpr::isFunctionOfDim(unsigned position) const {
if (getKind() == AffineExprKind::DimId) {
return *this == mlir::getAffineDimExpr(position, getContext());
}
- if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
return expr.getLHS().isFunctionOfDim(position) ||
expr.getRHS().isFunctionOfDim(position);
}
@@ -305,7 +305,7 @@ bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
if (getKind() == AffineExprKind::SymbolId) {
return *this == mlir::getAffineSymbolExpr(position, getContext());
}
- if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(*this)) {
return expr.getLHS().isFunctionOfSymbol(position) ||
expr.getRHS().isFunctionOfSymbol(position);
}
@@ -341,14 +341,14 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
"unexpected opKind");
switch (expr.getKind()) {
case AffineExprKind::Constant:
- return expr.cast<AffineConstantExpr>().getValue() == 0;
+ return cast<AffineConstantExpr>(expr).getValue() == 0;
case AffineExprKind::DimId:
return false;
case AffineExprKind::SymbolId:
- return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
+ return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
// Checks divisibility by the given symbol for both operands.
case AffineExprKind::Add: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
@@ -358,7 +358,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// s1 but it is not divisible by s1 always. The third argument is
// `AffineExprKind::Mod` for this reason.
case AffineExprKind::Mod: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
AffineExprKind::Mod) &&
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
@@ -366,7 +366,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
}
// Checks if any of the operands is divisible by the given symbol.
case AffineExprKind::Mul: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
}
@@ -380,7 +380,7 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
// (exp1 ceildiv exp2) floordiv exp3 cannot be simplified.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (opKind != expr.getKind())
return false;
return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
@@ -400,7 +400,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
"unexpected opKind");
switch (expr.getKind()) {
case AffineExprKind::Constant:
- if (expr.cast<AffineConstantExpr>().getValue() != 0)
+ if (cast<AffineConstantExpr>(expr).getValue() != 0)
return nullptr;
return getAffineConstantExpr(0, expr.getContext());
case AffineExprKind::DimId:
@@ -409,14 +409,14 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
return getAffineConstantExpr(1, expr.getContext());
// Dividing both operands by the given symbol.
case AffineExprKind::Add: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return getAffineBinaryOpExpr(
expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
}
// Dividing both operands by the given symbol.
case AffineExprKind::Mod: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return getAffineBinaryOpExpr(
expr.getKind(),
symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
@@ -424,7 +424,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
}
// Dividing either one of the operands by the given symbol.
case AffineExprKind::Mul: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
return binaryExpr.getLHS() *
symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
@@ -434,7 +434,7 @@ static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
// Dividing first operand only by the given symbol.
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return getAffineBinaryOpExpr(
expr.getKind(),
symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
@@ -457,7 +457,7 @@ static AffineExpr simplifySemiAffine(AffineExpr expr) {
return expr;
case AffineExprKind::Add:
case AffineExprKind::Mul: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
return getAffineBinaryOpExpr(expr.getKind(),
simplifySemiAffine(binaryExpr.getLHS()),
simplifySemiAffine(binaryExpr.getRHS()));
@@ -470,11 +470,11 @@ static AffineExpr simplifySemiAffine(AffineExpr expr) {
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
- AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
AffineSymbolExpr symbolExpr =
- simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
+ dyn_cast<AffineSymbolExpr>(simplifySemiAffine(binaryExpr.getRHS()));
if (!symbolExpr)
return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
unsigned symbolPos = symbolExpr.getPosition();
@@ -542,8 +542,8 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
/// Simplify add expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
// Fold if both LHS, RHS are a constant.
if (lhsConst && rhsConst)
return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
@@ -551,7 +551,7 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
// If only one of them is a symbolic expressions, make it the RHS.
- if (lhs.isa<AffineConstantExpr>() ||
+ if (isa<AffineConstantExpr>(lhs) ||
(lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
return rhs + lhs;
}
@@ -564,9 +564,9 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
return lhs;
}
// Fold successive additions like (d0 + 2) + 3 into d0 + 5.
- auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
- if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
+ if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
}
@@ -576,9 +576,9 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
std::optional<int64_t> rLhsConst, rRhsConst;
AffineExpr firstExpr, secondExpr;
AffineConstantExpr rLhsConstExpr;
- auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
- (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
+ (rLhsConstExpr = dyn_cast<AffineConstantExpr>(lBinOpExpr.getRHS()))) {
rLhsConst = rLhsConstExpr.getValue();
firstExpr = lBinOpExpr.getLHS();
} else {
@@ -586,10 +586,10 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
firstExpr = lhs;
}
- auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+ auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(rhs);
AffineConstantExpr rRhsConstExpr;
if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
- (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
+ (rRhsConstExpr = dyn_cast<AffineConstantExpr>(rBinOpExpr.getRHS()))) {
rRhsConst = rRhsConstExpr.getValue();
secondExpr = rBinOpExpr.getLHS();
} else {
@@ -605,7 +605,7 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
// When doing successive additions, bring constant to the right: turn (d0 + 2)
// + d1 into (d0 + d1) + 2.
if (lBin && lBin.getKind() == AffineExprKind::Add) {
- if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
return lBin.getLHS() + rhs + lrhs;
}
}
@@ -626,16 +626,16 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
// Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
// symbolic expression.
- auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
// Check rrhsConstOpExpr = -1.
- auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
+ auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rrhs);
if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
// Check llrhs = expr floordiv q.
llrhs = lrhsBinOpExpr.getLHS();
// Check rlrhs = q.
rlrhs = lrhsBinOpExpr.getRHS();
- auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
+ auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(llrhs);
if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
return nullptr;
if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
@@ -643,7 +643,7 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
}
// Process lrhs, which is 'expr floordiv c'.
- AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
+ AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(lrhs);
if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
return nullptr;
@@ -670,8 +670,8 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (lhsConst && rhsConst)
return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
@@ -682,7 +682,7 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
// Canonicalize the mul expression so that the constant/symbolic term is the
// RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
// constant. (Note that a constant is trivially symbolic).
- if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
+ if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
// At least one of them has to be symbolic.
return rhs * lhs;
}
@@ -699,16 +699,16 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
}
// Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
- auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
- if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
+ if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS()))
return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
}
// When doing successive multiplication, bring constant to the right: turn (d0
// * 2) * d1 into (d0 * d1) * 2.
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
- if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
return (lBin.getLHS() * rhs) * lrhs;
}
}
@@ -740,8 +740,8 @@ AffineExpr AffineExpr::operator-(AffineExpr other) const {
}
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
// mlir floordiv by zero or negative numbers is undefined and preserved as is.
if (!rhsConst || rhsConst.getValue() < 1)
@@ -758,9 +758,9 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
// Simplify (expr * const) floordiv divConst when expr is known to be a
// multiple of divConst.
- auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
- if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
// rhsConst is known to be a positive constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
@@ -796,8 +796,8 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
}
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
@@ -813,9 +813,9 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
// Simplify (expr * const) ceildiv divConst when const is known to be a
// multiple of divConst.
- auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
- if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+ if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
// rhsConst is known to be a positive constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
@@ -839,8 +839,8 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
}
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
- auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
- auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
// mod w.r.t zero or negative numbers is undefined and preserved as is.
if (!rhsConst || rhsConst.getValue() < 1)
@@ -858,7 +858,7 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
// Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
// known to be a multiple of divConst.
- auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Add) {
int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
@@ -871,7 +871,7 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
// Simplify (e % a) % b to e % b when b evenly divides a
if (lBin && lBin.getKind() == AffineExprKind::Mod) {
- auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
+ auto intermediate = dyn_cast<AffineConstantExpr>(lBin.getRHS());
if (intermediate && intermediate.getValue() >= 1 &&
mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
return lBin.getLHS() % rhsConst.getValue();
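
The simplifiers patched above (simplifyAdd, simplifyMul, simplifyFloorDiv, simplifyCeilDiv, simplifyMod) are static, but they run eagerly through AffineExpr's operator overloads, so the folds are observable from user code. A small sketch; the printed forms in the comments are my reading of the logic above, not output recorded as part of the patch.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/Support/raw_ostream.h"

  int main() {
    mlir::MLIRContext ctx;
    mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);

    llvm::outs() << ((d0 + 2) + 3) << "\n";       // d0 + 5  (simplifyAdd)
    llvm::outs() << ((d0 * 2) * 3) << "\n";       // d0 * 6  (simplifyMul)
    llvm::outs() << (d0 * 8).floorDiv(2) << "\n"; // d0 * 4  (simplifyFloorDiv)
    llvm::outs() << ((d0 * 4) % 2) << "\n";       // 0       (simplifyMod)
  }
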
@@ -1036,38 +1036,38 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
AffineExpr expr = it.value();
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
- AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
- AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
- if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
- (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
- rhs.isa<AffineConstantExpr>()))) {
+ AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
+ AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+ if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
+ (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
+ isa<AffineConstantExpr>(rhs)))) {
continue;
}
- if (rhs.isa<AffineConstantExpr>()) {
+ if (isa<AffineConstantExpr>(rhs)) {
// For product/modulo/division expressions, when the rhs of the
// modulo/division expression is constant, we put 0 in place of keyB
// because we want these terms to appear earlier in the semi-affine
// expression we are constructing.
- if (lhs.isa<AffineDimExpr>()) {
- lhsPos = lhs.cast<AffineDimExpr>().getPosition();
+ if (isa<AffineDimExpr>(lhs)) {
+ lhsPos = cast<AffineDimExpr>(lhs).getPosition();
std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
expr);
} else {
- lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
+ lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
std::pair<unsigned, signed> indexEntry(
lhsPos, std::max(numDims, numSymbols) + offsetSym++);
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
expr);
}
- } else if (lhs.isa<AffineDimExpr>()) {
+ } else if (isa<AffineDimExpr>(lhs)) {
// For product/modulo/division expressions having lhs as dimension and rhs
// as symbol, we order the terms in the semi-affine expression based on
// the pair: <keyA, keyB> for expressions of the form dimension * symbol,
// where keyA is the position number of the dimension and keyB is the
// position number of the symbol.
- lhsPos = lhs.cast<AffineDimExpr>().getPosition();
- rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
+ lhsPos = cast<AffineDimExpr>(lhs).getPosition();
+ rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
} else {
@@ -1075,8 +1075,8 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
// symbol, we design indices as a pair: <keyA, keyB> for expressions
// of the form dimension * symbol, where keyA is the position number of
// the dimension and keyB is the position number of the symbol.
- lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
- rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
+ lhsPos = cast<AffineSymbolExpr>(lhs).getPosition();
+ rhsPos = cast<AffineSymbolExpr>(rhs).getPosition();
std::pair<unsigned, signed> indexEntry(
lhsPos, std::max(numDims, numSymbols) + offsetSym++);
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
@@ -1143,7 +1143,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
// Flatten semi-affine multiplication expressions by introducing a local
// variable in place of the product; the affine expression
// corresponding to the quantifier is added to `localExprs`.
- if (!expr.getRHS().isa<AffineConstantExpr>()) {
+ if (!isa<AffineConstantExpr>(expr.getRHS())) {
MLIRContext *context = expr.getContext();
AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
localExprs, context);
@@ -1194,7 +1194,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
// Flatten semi affine modulo expressions by introducing a local
// variable in place of the modulo value, and the affine expression
// corresponding to the quantifier is added to `localExprs`.
- if (!expr.getRHS().isa<AffineConstantExpr>()) {
+ if (!isa<AffineConstantExpr>(expr.getRHS())) {
AffineExpr dividendExpr = getAffineExprFromFlatForm(
lhs, numDims, numSymbols, localExprs, context);
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
@@ -1318,7 +1318,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
// Flatten semi affine division expressions by introducing a local
// variable in place of the quotient, and the affine expression corresponding
// to the quantifier is added to `localExprs`.
- if (!expr.getRHS().isa<AffineConstantExpr>()) {
+ if (!isa<AffineConstantExpr>(expr.getRHS())) {
AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
localExprs, context);
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
@@ -1443,11 +1443,11 @@ std::optional<int64_t> mlir::getBoundForAffineExpr(
ArrayRef<std::optional<int64_t>> constLowerBounds,
ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
// Handle divs and mods.
- if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
// If the LHS of a floor or ceil is bounded and the RHS is a constant, we
// can compute an upper bound.
if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
- auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
if (!rhsConst || rhsConst.getValue() < 1)
return std::nullopt;
auto bound =
@@ -1458,7 +1458,7 @@ std::optional<int64_t> mlir::getBoundForAffineExpr(
return mlir::floorDiv(*bound, rhsConst.getValue());
}
if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
- auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
if (rhsConst && rhsConst.getValue() >= 1) {
auto bound =
getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
@@ -1473,7 +1473,7 @@ std::optional<int64_t> mlir::getBoundForAffineExpr(
// lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
// bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
// (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
- auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(binOpExpr.getRHS());
if (rhsConst && rhsConst.getValue() >= 1) {
int64_t rhsConstVal = rhsConst.getValue();
auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
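
For context on the getBoundForAffineExpr hunks above: given constant lower/upper bounds for each dimension and symbol, the helper derives a constant bound for the whole expression, including the floordiv/ceildiv/mod handling shown in the diff. A hedged usage sketch, assuming the declaration is visible through mlir/IR/AffineExpr.h (where this definition lives):

  #include <optional>

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/Support/raw_ostream.h"

  int main() {
    mlir::MLIRContext ctx;
    // One dimension d0 with 0 <= d0 <= 10, no symbols.
    mlir::AffineExpr expr = mlir::getAffineDimExpr(0, &ctx).floorDiv(4);
    std::optional<int64_t> lbs[] = {0};
    std::optional<int64_t> ubs[] = {10};
    std::optional<int64_t> ub = mlir::getBoundForAffineExpr(
        expr, /*numDims=*/1, /*numSymbols=*/0, lbs, ubs, /*isUpper=*/true);
    if (ub)
      llvm::outs() << "upper bound of d0 floordiv 4: " << *ub << "\n"; // 2
  }
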
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index cdcd71cdd7cd151..86f5b610f9ba595 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -46,6 +46,8 @@ class AffineExprConstantFolder {
return nullptr;
}
+ bool hasPoison() const { return hasPoison_; }
+
private:
std::optional<int64_t> constantFoldImpl(AffineExpr expr) {
switch (expr.getKind()) {
@@ -57,24 +59,33 @@ class AffineExprConstantFolder {
expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
case AffineExprKind::Mod:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
+ expr,
+ [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ return mod(lhs, rhs);
+ });
case AffineExprKind::FloorDiv:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
+ expr,
+ [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ return floorDiv(lhs, rhs);
+ });
case AffineExprKind::CeilDiv:
return constantFoldBinExpr(
- expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
+ expr,
+ [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
+ return ceilDiv(lhs, rhs);
+ });
case AffineExprKind::Constant:
- return expr.cast<AffineConstantExpr>().getValue();
+ return cast<AffineConstantExpr>(expr).getValue();
case AffineExprKind::DimId:
if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
- operandConsts[expr.cast<AffineDimExpr>().getPosition()]))
+ operandConsts[cast<AffineDimExpr>(expr).getPosition()]))
return attr.getInt();
return std::nullopt;
case AffineExprKind::SymbolId:
if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
operandConsts[numDims +
- expr.cast<AffineSymbolExpr>().getPosition()]))
+ cast<AffineSymbolExpr>(expr).getPosition()]))
return attr.getInt();
return std::nullopt;
}
@@ -82,9 +93,10 @@ class AffineExprConstantFolder {
}
// TODO: Change these to operate on APInts too.
- std::optional<int64_t> constantFoldBinExpr(AffineExpr expr,
- int64_t (*op)(int64_t, int64_t)) {
- auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ std::optional<int64_t> constantFoldBinExpr(
+ AffineExpr expr,
+ llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
+ auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
return op(*lhs, *rhs);
@@ -95,6 +107,7 @@ class AffineExprConstantFolder {
unsigned numDims;
// The constant valued operands used to evaluate this AffineExpr.
ArrayRef<Attribute> operandConsts;
+ bool hasPoison_{false};
};
} // namespace
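
The hunk above also widens constantFoldBinExpr from a plain function pointer to an llvm::function_ref returning std::optional<int64_t>, so a folding callback can signal failure (e.g. for poison-producing divisions) instead of being forced to return a value. That class is file-local, so the following is only a generic sketch of the callback shape with made-up names (foldBinary, safeFloorDiv), not the MLIR class itself:

  #include <cstdint>
  #include <optional>

  #include "llvm/ADT/STLFunctionalExtras.h" // llvm::function_ref
  #include "llvm/Support/raw_ostream.h"

  // Made-up stand-in for constantFoldBinExpr: fold two already-known constants
  // through a callback that may refuse to produce a value.
  static std::optional<int64_t>
  foldBinary(std::optional<int64_t> lhs, std::optional<int64_t> rhs,
             llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
    if (lhs && rhs)
      return op(*lhs, *rhs);
    return std::nullopt;
  }

  int main() {
    // A division callback that reports failure on a zero divisor rather than
    // folding to an arbitrary value.
    auto safeFloorDiv = [](int64_t a, int64_t b) -> std::optional<int64_t> {
      if (b == 0)
        return std::nullopt;
      return (a - (((a % b) + b) % b)) / b; // floor division for int64_t
    };
    llvm::outs() << foldBinary(7, 2, safeFloorDiv).value_or(-1) << "\n"; // 3
    llvm::outs() << foldBinary(7, 0, safeFloorDiv).value_or(-1) << "\n"; // -1
  }
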
@@ -122,7 +135,7 @@ AffineMap AffineMap::getFilteredIdentityMap(
// Apply filter to results.
llvm::SmallBitVector dropDimResults(numDims);
for (auto [idx, resultExpr] : llvm::enumerate(identityMap.getResults()))
- dropDimResults[idx] = !keepDimFilter(resultExpr.cast<AffineDimExpr>());
+ dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(resultExpr));
return identityMap.dropResults(dropDimResults);
}
@@ -145,13 +158,13 @@ bool AffineMap::isMinorIdentityWithBroadcasting(
for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
unsigned resIdx = idxAndExpr.index();
AffineExpr expr = idxAndExpr.value();
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
// Each result may be either a constant 0 (broadcasted dimension).
if (constExpr.getValue() != 0)
return false;
if (broadcastedDims)
broadcastedDims->push_back(resIdx);
- } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
// Or it may be the input dimension corresponding to this result position.
if (dimExpr.getPosition() != suffixStart + resIdx)
return false;
@@ -194,11 +207,11 @@ bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
AffineExpr expr = idxAndExpr.value();
// Each result may be either a constant 0 (broadcast dimension) or a
// dimension.
- if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
if (constExpr.getValue() != 0)
return false;
broadcastDims.push_back(resIdx);
- } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
if (dimExpr.getPosition() < projectionStart)
return false;
unsigned newPosition =
@@ -297,7 +310,7 @@ bool AffineMap::isIdentity() const {
return false;
ArrayRef<AffineExpr> results = getResults();
for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
- auto expr = results[i].dyn_cast<AffineDimExpr>();
+ auto expr = dyn_cast<AffineDimExpr>(results[i]);
if (!expr || expr.getPosition() != i)
return false;
}
@@ -309,7 +322,7 @@ bool AffineMap::isSymbolIdentity() const {
return false;
ArrayRef<AffineExpr> results = getResults();
for (unsigned i = 0, numSymbols = getNumSymbols(); i < numSymbols; ++i) {
- auto expr = results[i].dyn_cast<AffineDimExpr>();
+ auto expr = dyn_cast<AffineDimExpr>(results[i]);
if (!expr || expr.getPosition() != i)
return false;
}
@@ -321,25 +334,25 @@ bool AffineMap::isEmpty() const {
}
bool AffineMap::isSingleConstant() const {
- return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
+ return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
}
bool AffineMap::isConstant() const {
return llvm::all_of(getResults(), [](AffineExpr expr) {
- return expr.isa<AffineConstantExpr>();
+ return isa<AffineConstantExpr>(expr);
});
}
int64_t AffineMap::getSingleConstantResult() const {
assert(isSingleConstant() && "map must have a single constant result");
- return getResult(0).cast<AffineConstantExpr>().getValue();
+ return cast<AffineConstantExpr>(getResult(0)).getValue();
}
SmallVector<int64_t> AffineMap::getConstantResults() const {
assert(isConstant() && "map must have only constant results");
SmallVector<int64_t> result;
for (auto expr : getResults())
- result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
+ result.emplace_back(cast<AffineConstantExpr>(expr).getValue());
return result;
}
@@ -365,11 +378,11 @@ AffineExpr AffineMap::getResult(unsigned idx) const {
}
unsigned AffineMap::getDimPosition(unsigned idx) const {
- return getResult(idx).cast<AffineDimExpr>().getPosition();
+ return cast<AffineDimExpr>(getResult(idx)).getPosition();
}
std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
- if (!input.isa<AffineDimExpr>())
+ if (!isa<AffineDimExpr>(input))
return std::nullopt;
for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) {
@@ -536,7 +549,7 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
SmallVector<int64_t, 4> res;
res.reserve(resMap.getNumResults());
for (auto e : resMap.getResults())
- res.push_back(e.cast<AffineConstantExpr>().getValue());
+ res.push_back(cast<AffineConstantExpr>(e).getValue());
return res;
}
@@ -555,12 +568,12 @@ bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
// number of result expressions is lower or equal than the number of input
// expressions.
for (auto expr : getResults()) {
- if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
+ if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
if (seen[dim.getPosition()])
return false;
seen[dim.getPosition()] = true;
} else {
- auto constExpr = expr.dyn_cast<AffineConstantExpr>();
+ auto constExpr = dyn_cast<AffineConstantExpr>(expr);
if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
return false;
}
@@ -722,7 +735,7 @@ AffineMap mlir::inversePermutation(AffineMap map) {
for (const auto &en : llvm::enumerate(map.getResults())) {
auto expr = en.value();
// Skip non-permutations.
- if (auto d = expr.dyn_cast<AffineDimExpr>()) {
+ if (auto d = dyn_cast<AffineDimExpr>(expr)) {
if (exprs[d.getPosition()])
continue;
exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
@@ -746,7 +759,7 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
// Skip zeros from input map. 'exprs' is already initialized to zero.
- if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(i))) {
assert(constExpr.getValue() == 0 &&
"Unexpected constant in projected permutation");
(void)constExpr;
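
Most of the AffineMap entry points touched in this file are public, so the migrated queries can be exercised directly. A small sketch; the (d0, d1) -> (d1, d0) permutation is an arbitrary example and the commented results follow from the definitions above:

  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Support/raw_ostream.h"

  using namespace mlir;

  int main() {
    MLIRContext ctx;
    AffineExpr d0 = getAffineDimExpr(0, &ctx);
    AffineExpr d1 = getAffineDimExpr(1, &ctx);

    // (d0, d1) -> (d1, d0): a (projected) permutation, not the identity.
    AffineMap swapMap = AffineMap::get(2, 0, {d1, d0}, &ctx);
    llvm::outs() << swapMap.isIdentity() << "\n";              // 0
    llvm::outs() << swapMap.isProjectedPermutation() << "\n";  // 1
    llvm::outs() << swapMap.getDimPosition(0) << "\n";         // 1

    // Evaluate the map at constant operands, as in AffineMap::compose above.
    SmallVector<int64_t, 4> vals = swapMap.compose({3, 5});
    llvm::outs() << vals[0] << ", " << vals[1] << "\n";        // 5, 3

    // inversePermutation is also touched by this patch.
    llvm::outs() << inversePermutation(swapMap) << "\n";       // (d0, d1) -> (d1, d0)
  }
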
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 82e1e96229b79e0..83cdf0d22410aa6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2815,7 +2815,7 @@ void AsmPrinter::Impl::printAffineExprInternal(
const char *binopSpelling = nullptr;
switch (expr.getKind()) {
case AffineExprKind::SymbolId: {
- unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
+ unsigned pos = cast<AffineSymbolExpr>(expr).getPosition();
if (printValueName)
printValueName(pos, /*isSymbol=*/true);
else
@@ -2823,7 +2823,7 @@ void AsmPrinter::Impl::printAffineExprInternal(
return;
}
case AffineExprKind::DimId: {
- unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ unsigned pos = cast<AffineDimExpr>(expr).getPosition();
if (printValueName)
printValueName(pos, /*isSymbol=*/false);
else
@@ -2831,7 +2831,7 @@ void AsmPrinter::Impl::printAffineExprInternal(
return;
}
case AffineExprKind::Constant:
- os << expr.cast<AffineConstantExpr>().getValue();
+ os << cast<AffineConstantExpr>(expr).getValue();
return;
case AffineExprKind::Add:
binopSpelling = " + ";
@@ -2850,7 +2850,7 @@ void AsmPrinter::Impl::printAffineExprInternal(
break;
}
- auto binOp = expr.cast<AffineBinaryOpExpr>();
+ auto binOp = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhsExpr = binOp.getLHS();
AffineExpr rhsExpr = binOp.getRHS();
@@ -2860,7 +2860,7 @@ void AsmPrinter::Impl::printAffineExprInternal(
os << '(';
// Pretty print multiplication with -1.
- auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
+ auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr);
if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
rhsConst.getValue() == -1) {
os << "-";
@@ -2886,10 +2886,10 @@ void AsmPrinter::Impl::printAffineExprInternal(
// Pretty print addition to a product that has a negative operand as a
// subtraction.
- if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
+ if (auto rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
if (rhs.getKind() == AffineExprKind::Mul) {
AffineExpr rrhsExpr = rhs.getRHS();
- if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
+ if (auto rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
if (rrhs.getValue() == -1) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
printValueName);
@@ -2923,7 +2923,7 @@ void AsmPrinter::Impl::printAffineExprInternal(
}
// Pretty print addition to a negative number as a subtraction.
- if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
+ if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr)) {
if (rhsConst.getValue() < 0) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
os << " - " << -rhsConst.getValue();
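
The printer changes above are what turn x + (-1) * y and x + (-c) back into subtractions when an affine expression is printed. A quick sketch; the printed forms in the comments are what I expect from the code above, not output captured from the patch:

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/Support/raw_ostream.h"

  int main() {
    mlir::MLIRContext ctx;
    mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
    mlir::AffineExpr s0 = mlir::getAffineSymbolExpr(0, &ctx);

    // Internally d0 + s0 * -1, pretty printed as a subtraction.
    llvm::outs() << (d0 - s0) << "\n"; // d0 - s0
    // Addition of a negative constant is printed as a subtraction too.
    llvm::outs() << (d0 - 2) << "\n";  // d0 - 2
  }
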
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a9284d5714637bc..9b8ee3d45280353 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -678,7 +678,7 @@ static void extractStridesFromTerm(AffineExpr e,
AffineExpr multiplicativeFactor,
MutableArrayRef<AffineExpr> strides,
AffineExpr &offset) {
- if (auto dim = e.dyn_cast<AffineDimExpr>())
+ if (auto dim = dyn_cast<AffineDimExpr>(e))
strides[dim.getPosition()] =
strides[dim.getPosition()] + multiplicativeFactor;
else
@@ -693,7 +693,7 @@ static LogicalResult extractStrides(AffineExpr e,
AffineExpr multiplicativeFactor,
MutableArrayRef<AffineExpr> strides,
AffineExpr &offset) {
- auto bin = e.dyn_cast<AffineBinaryOpExpr>();
+ auto bin = dyn_cast<AffineBinaryOpExpr>(e);
if (!bin) {
extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
return success();
@@ -705,7 +705,7 @@ static LogicalResult extractStrides(AffineExpr e,
return failure();
if (bin.getKind() == AffineExprKind::Mul) {
- auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
+ auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
if (dim) {
strides[dim.getPosition()] =
strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
@@ -820,12 +820,12 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVector<AffineExpr, 4> strideExprs;
if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
return failure();
- if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
+ if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
offset = cst.getValue();
else
offset = ShapedType::kDynamic;
for (auto e : strideExprs) {
- if (auto c = e.dyn_cast<AffineConstantExpr>())
+ if (auto c = dyn_cast<AffineConstantExpr>(e))
strides.push_back(c.getValue());
else
strides.push_back(ShapedType::kDynamic);
@@ -888,7 +888,7 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
// Corner-case for 0-D affine maps.
if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
- if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
+ if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
if (cst.getValue() == 0)
return MemRefType::Builder(t).setLayout({});
return t;
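
The stride-extraction helpers above ultimately feed mlir::getStridesAndOffset, which is public and easy to try on a statically shaped memref with the default row-major layout. A sketch, not part of the patch; the expected strides follow from the 4x8 shape:

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Support/raw_ostream.h"

  int main() {
    mlir::MLIRContext ctx;
    // memref<4x8xf32> with the default (identity) layout.
    auto memrefTy =
        mlir::MemRefType::get({4, 8}, mlir::Float32Type::get(&ctx));

    llvm::SmallVector<int64_t, 4> strides;
    int64_t offset = 0;
    if (succeeded(mlir::getStridesAndOffset(memrefTy, strides, offset)))
      llvm::outs() << "strides: " << strides[0] << ", " << strides[1]
                   << "  offset: " << offset << "\n"; // strides: 8, 1  offset: 0
  }
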
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 5898b0f7d69e832..fb3c9d48f9a9821 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -869,7 +869,7 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
assert(arg.indexAttrMap);
for (auto [idx, result] :
llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
- if (auto symbol = result.dyn_cast<AffineSymbolExpr>()) {
+ if (auto symbol = dyn_cast<AffineSymbolExpr>(result)) {
std::string argName = arg.name;
argName[0] = toupper(argName[0]);
symbolBindings[symbol.getPosition()] =