[Mlir-commits] [mlir] [mlir][NFC] Simplify type checks with isa predicates (PR #87183)
Jakub Kuderski
llvmlistbot at llvm.org
Sat Mar 30 23:09:00 PDT 2024
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/87183
For more context on isa predicates, see: https://github.com/llvm/llvm-project/pull/83753.
From 383887c49fd08be5f8f07a1c2a6cde60e6bf3a6d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 02:07:35 -0400
Subject: [PATCH] [mlir][NFC] Simplify type checks with isa predicates
For more context on isa predicates, see: https://github.com/llvm/llvm-project/pull/83753.
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 3 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 10 ++---
.../Affine/Analysis/AffineAnalysis.cpp | 3 +-
.../Affine/Transforms/SuperVectorize.cpp | 5 +--
.../Dialect/Affine/Utils/LoopFusionUtils.cpp | 3 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 2 +-
.../Transforms/OneShotModuleBufferize.cpp | 8 ++--
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 3 +-
.../GPU/TransformOps/GPUTransformOps.cpp | 24 ++++-------
.../GPU/Transforms/AsyncRegionRewriter.cpp | 5 +--
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 6 +--
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +--
.../Linalg/Transforms/ElementwiseToLinalg.cpp | 3 +-
.../Linalg/Transforms/Vectorization.cpp | 9 ++--
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 2 +-
.../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 3 +-
mlir/lib/Dialect/Shape/IR/Shape.cpp | 5 +--
mlir/lib/Dialect/Traits.cpp | 15 ++++---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 ++--
.../Vector/Transforms/VectorDistribute.cpp | 41 +++++++++----------
mlir/lib/IR/AffineMap.cpp | 4 +-
mlir/lib/IR/Operation.cpp | 4 +-
mlir/lib/TableGen/Class.cpp | 4 +-
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 5 +--
.../Target/SPIRV/Serialization/Serializer.cpp | 8 ++--
.../Transforms/Utils/DialectConversion.cpp | 6 +--
.../Transforms/Utils/OneToNTypeConversion.cpp | 3 +-
28 files changed, 83 insertions(+), 118 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418cb841327..993c09b03c0fde 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
- if (llvm::none_of(operandTypes,
- [](Type type) { return isa<VectorType>(type); })) {
+ if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
return rewriter.notifyMatchFailure(op, "expected vector operand");
}
if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 85fb8a539912f7..399c0450824ee5 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -202,9 +202,7 @@ template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
return false;
- return llvm::all_of(extOp->getUsers(), [](Operation *user) {
- return isa<vector::ContractionOp>(user);
- });
+ return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
@@ -345,15 +343,13 @@ getSliceContract(Operation *op,
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
bool useNvGpu) {
auto hasVectorDest = [](Operation *op) {
- return llvm::any_of(op->getResultTypes(),
- [](Type t) { return isa<VectorType>(t); });
+ return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
};
BackwardSliceOptions backwardSliceOptions;
backwardSliceOptions.filter = hasVectorDest;
auto hasVectorSrc = [](Operation *op) {
- return llvm::any_of(op->getOperandTypes(),
- [](Type t) { return isa<VectorType>(t); });
+ return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
};
ForwardSliceOptions forwardSliceOptions;
forwardSliceOptions.filter = hasVectorSrc;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 61244921bc38ac..69b3d41e17c2d4 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
// Any memref-typed iteration arguments are treated as serializing.
- if (llvm::any_of(forOp.getResultTypes(),
- [](Type type) { return isa<BaseMemRefType>(type); }))
+ if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
return false;
// Collect all load and store ops in loop nest rooted at 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 46c7871f40232f..71e9648a5e00fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
}
static NestedPattern &vectorTransferPattern() {
- static auto pattern = affine::matcher::Op([](Operation &op) {
- return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
- });
+ static auto pattern = affine::matcher::Op(
+ llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
return pattern;
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index fb45528ad5e7d1..84ae4b52dcf4e8 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
// Return common loop depth for loads if there are no store ops.
- if (all_of(targetDstOps,
- [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
+ if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
return loopDepth;
// Check dependences on all pairs of ops in 'targetDstOps' and store the
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 4cdbbf35dc876b..053ea7935260a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -326,7 +326,7 @@ struct FuncOpInterface
static bool supportsUnstructuredControlFlow() { return true; }
bool hasTensorSemantics(Operation *op) const {
- auto isaTensor = [](Type type) { return isa<TensorType>(type); };
+ auto isaTensor = llvm::IsaPred<TensorType>;
// A function has tensor semantics if it has tensor arguments/results.
auto funcOp = cast<FuncOp>(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 33feea0b956ca0..0a4072605c265f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -67,6 +67,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
@@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
- auto isaTensor = [](Type t) { return isa<TensorType>(t); };
- return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
- llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+ return llvm::any_of(funcOp.getFunctionType().getInputs(),
+ llvm::IsaPred<TensorType>) ||
+ llvm::any_of(funcOp.getFunctionType().getResults(),
+ llvm::IsaPred<TensorType>);
}
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f4a9dc3ca509c8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
}
}
- if (llvm::any_of(getResultTypes(),
- [](Type type) { return isa<ArrayType>(type); })) {
+ if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
return emitOpError() << "cannot return array type";
}
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index fc3a4375694588..b584f63f16e0aa 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
"scf.forall op requires a mapping attribute");
}
- bool hasBlockMapping =
- llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return isa<GPUBlockMappingAttr>(attr);
- });
- bool hasWarpgroupMapping =
- llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return isa<GPUWarpgroupMappingAttr>(attr);
- });
- bool hasWarpMapping =
- llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return isa<GPUWarpMappingAttr>(attr);
- });
- bool hasThreadMapping =
- llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return isa<GPUThreadMappingAttr>(attr);
- });
+ bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
+ llvm::IsaPred<GPUBlockMappingAttr>);
+ bool hasWarpgroupMapping = llvm::any_of(
+ forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
+ bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
+ llvm::IsaPred<GPUWarpMappingAttr>);
+ bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
+ llvm::IsaPred<GPUThreadMappingAttr>);
int64_t countMappingTypes = 0;
countMappingTypes += hasBlockMapping ? 1 : 0;
countMappingTypes += hasWarpgroupMapping ? 1 : 0;
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 40903f199afddd..b2fa3a99c53fc3 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -232,9 +232,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
// control flow code.
static bool areAllUsersExecuteOrAwait(Value token) {
return !token.use_empty() &&
- llvm::all_of(token.getUsers(), [](Operation *user) {
- return isa<async::ExecuteOp, async::AwaitOp>(user);
- });
+ llvm::all_of(token.getUsers(),
+ llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
}
// Add the `asyncToken` as dependency as needed after `op`.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3ba6ac6ccc8142..e5c19a916392e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
if (!resultType)
return success();
- auto isVector = [](Type type) {
- return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
- type);
- };
+ auto isVector =
+ llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
// Due to bitcast requiring both operands to be of the same size, it is not
// possible for only one of the two to be a pointer of vectors.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6954eee93efd14..2d7219fef87c64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
@@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) {
- assert(llvm::all_of(outputTypes,
- [](Type t) { return llvm::isa<ShapedType>(t); }));
+ assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
@@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
resultTensorTypes.value_or(TypeRange());
if (!resultTensorTypes)
copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
- [](Type type) { return llvm::isa<RankedTensorType>(type); });
+ llvm::IsaPred<RankedTensorType>);
state.addOperands(inputs);
state.addOperands(outputs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 5508aaf9d87537..28d6752fc2d388 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
// TODO: The conversion pattern can be made to work for `any_of` here, but
// it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(),
- [](Type type) { return isa<RankedTensorType>(type); });
+ return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c74ab1e6448bec..25785653a71675 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3537,15 +3537,14 @@ struct Conv1DGenerator
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
// must be block arguments or extension of block arguments.
bool setOperKind(Operation *reduceOp) {
- int numBlockArguments = llvm::count_if(
- reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
+ int numBlockArguments =
+ llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
// Otherwise, if it can be pooling.
- auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
- return !isa<BlockArgument>(v);
- });
+ auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+ llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
oper = Pool;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c09a3403f9a3e3..9ba96e4be7d1fc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
}
static bool isComputeOperation(Operation *op) {
- return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
+ return isa<acc::ParallelOp, acc::LoopOp>(op);
}
namespace {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d532d466334a56..2ff3efdc96a7f8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
if (getMatrixOperands()) {
Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
typeC.getElementType()};
- if (!llvm::all_of(elementTypes,
- [](Type ty) { return isa<IntegerType>(ty); })) {
+ if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
return emitOpError("Matrix Operands require all matrix element types to "
"be Integer Types");
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f5a3717f815de5..58c3f4c334577c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
}
static bool isErrorPropagationPossible(TypeRange operandTypes) {
- return llvm::any_of(operandTypes, [](Type ty) {
- return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
- });
+ return llvm::any_of(operandTypes,
+ llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
}
static LogicalResult verifySizeOrIndexOp(Operation *op) {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index d4e0f8a3137053..2efc157ce79617 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
/// Returns a tuple corresponding to whether range has tensor or vector type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
- return std::make_tuple(
- llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
- llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
+ return {llvm::any_of(types, llvm::IsaPred<TensorType>),
+ llvm::any_of(types, llvm::IsaPred<VectorType>)};
}
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
};
if (inferred.size() != existing.size())
return false;
- for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+ for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
if (!isCompatible(inferredDim, existingDim))
return false;
return true;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
std::get<1>(resultsHasTensorVectorType)))
return op->emitError("cannot broadcast vector with tensor");
- auto rankedOperands = make_filter_range(
- op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+ auto rankedOperands =
+ make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
// If all operands are unranked, then all result shapes are possible.
if (rankedOperands.empty())
@@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
return op->emitOpError("operands don't have broadcast-compatible shapes");
}
- auto rankedResults = make_filter_range(
- op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+ auto rankedResults =
+ make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
// If all of the results are unranked then no further verification.
if (rankedResults.empty())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 578b2492bbab46..c8d06ba157b904 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
assert(outputs.size() == 1 && "expected one output");
return llvm::all_of(
std::initializer_list<Type>{inputs.front(), outputs.front()},
- [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
+ llvm::IsaPred<transform::TransformHandleTypeInterface>);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e566bfacf37984..3e6425879cc67f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(
AffineMap resMap = op.getIndexingMapsArray()[2];
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
- /*symCount=*/0, extents, ctx);
+ /*symbolCount=*/0, extents, ctx);
// Compose the resMap with the extentsMap, which is a constant map.
AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
- assert(
- llvm::all_of(expectedMap.getResults(),
- [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
- "expected constant extent along all dimensions.");
+ assert(llvm::all_of(expectedMap.getResults(),
+ llvm::IsaPred<AffineConstantExpr>) &&
+ "expected constant extent along all dimensions.");
// Extract the expected shape and build the type.
auto expectedShape = llvm::to_vector<4>(
llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b3ab4a916121e3..a67e03e85f7145 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -598,9 +598,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
// Do not process warp ops that contain only TransferWriteOps.
- if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
- return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
- }))
+ if (llvm::all_of(warpOp.getOps(),
+ llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
return failure();
SmallVector<Value> yieldValues = {writeOp.getVector()};
@@ -746,8 +745,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *yieldOperand = getWarpResult(
- warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
if (!yieldOperand)
return failure();
auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
@@ -1060,8 +1059,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1097,8 +1096,8 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();
@@ -1156,8 +1155,8 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *yieldOperand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
if (!yieldOperand)
return failure();
@@ -1222,8 +1221,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1325,9 +1324,8 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
warpShuffleFromIdxFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
- return isa<vector::ExtractElementOp>(op);
- });
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1422,8 +1420,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1503,8 +1501,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1808,8 +1805,8 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *yieldOperand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
+ OpOperand *yieldOperand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
if (!yieldOperand)
return failure();
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 00a0f05b633303..6cdc2682753fc7 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -359,9 +359,7 @@ bool AffineMap::isSingleConstant() const {
}
bool AffineMap::isConstant() const {
- return llvm::all_of(getResults(), [](AffineExpr expr) {
- return isa<AffineConstantExpr>(expr);
- });
+ return llvm::all_of(getResults(), llvm::IsaPred<AffineConstantExpr>);
}
int64_t AffineMap::getSingleConstantResult() const {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index d6d59837d48ac8..ca5ff9f72e3e29 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1288,9 +1288,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
}
LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
- auto isMappableType = [](Type type) {
- return llvm::isa<VectorType, TensorType>(type);
- };
+ auto isMappableType = llvm::IsaPred<VectorType, TensorType>;
auto resultMappableTypes = llvm::to_vector<1>(
llvm::make_filter_range(op->getResultTypes(), isMappableType));
auto operandMappableTypes = llvm::to_vector<2>(
diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index 9092adcc627c08..fedf64fd96b0d4 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -369,9 +369,7 @@ void Class::finalize() {
Visibility Class::getLastVisibilityDecl() const {
auto reverseDecls = llvm::reverse(declarations);
- auto it = llvm::find_if(reverseDecls, [](auto &decl) {
- return isa<VisibilityDeclaration>(decl);
- });
+ auto it = llvm::find_if(reverseDecls, llvm::IsaPred<VisibilityDeclaration>);
return it == reverseDecls.end()
? (isStruct ? Visibility::Public : Visibility::Private)
: cast<VisibilityDeclaration>(**it).getVisibility();
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..0b07b4b06dfc71 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1000,8 +1000,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
"with multiple blocks needs variables declared at top");
}
- if (llvm::any_of(functionOp.getResultTypes(),
- [](Type type) { return isa<ArrayType>(type); })) {
+ if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
return functionOp.emitOpError() << "cannot emit array type as result type";
}
@@ -1576,7 +1575,7 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
}
LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
- if (llvm::any_of(types, [](Type type) { return isa<ArrayType>(type); })) {
+ if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
return emitError(loc, "cannot emit tuple of array type");
}
os << "std::tuple<";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4a4e878d8af915..9a74ac115e9555 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1031,9 +1031,9 @@ Serializer::processBlock(Block *block, bool omitLabel,
// into multiple basic blocks. If that's the case, we need to emit the merge
// right now and then create new blocks for further serialization of the ops
// in this block.
- if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
- return isa<spirv::LoopOp, spirv::SelectionOp>(op);
- })) {
+ if (emitMerge &&
+ llvm::any_of(block->getOperations(),
+ llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
if (failed(emitMerge()))
return failure();
emitMerge = nullptr;
@@ -1045,7 +1045,7 @@ Serializer::processBlock(Block *block, bool omitLabel,
}
// Process each op in this block except the terminator.
- for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
+ for (Operation &op : llvm::drop_end(*block)) {
if (failed(processOperation(&op)))
return failure();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3c72c8789e8ec5..3d309f15b3cc3a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2827,9 +2827,9 @@ static void computeNecessaryMaterializations(
}
// Check to see if this is an argument materialization.
- auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
- if (llvm::any_of(op->getOperands(), isBlockArg) ||
- llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
+ if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) ||
+ llvm::any_of(inverseMapping[op->getResult(0)],
+ llvm::IsaPred<BlockArgument>)) {
mat->setMaterializationKind(MaterializationKind::Argument);
}
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 7cb957d5ec29ea..fef9d8eb0fef74 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -392,8 +392,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Argument materialization.
assert(castKind == getCastKindName(CastKind::Argument) &&
"unexpected value of cast kind attribute");
- assert(llvm::all_of(operands,
- [&](Value v) { return isa<BlockArgument>(v); }));
+ assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
maybeResult = typeConverter.materializeArgumentConversion(
rewriter, castOp->getLoc(), resultTypes.front(),
castOp.getOperands());
More information about the Mlir-commits
mailing list