[Mlir-commits] [mlir] [mlir][NFC] Simplify type checks with isa predicates (PR #87183)

Jakub Kuderski llvmlistbot at llvm.org
Sat Mar 30 23:09:00 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/87183

For more context on isa predicates, see: https://github.com/llvm/llvm-project/pull/83753.

From 383887c49fd08be5f8f07a1c2a6cde60e6bf3a6d Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 02:07:35 -0400
Subject: [PATCH] [mlir][NFC] Simplify type checks with isa predicates

For more context on isa predicates, see: https://github.com/llvm/llvm-project/pull/83753.
---
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   |  3 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    | 10 ++---
 .../Affine/Analysis/AffineAnalysis.cpp        |  3 +-
 .../Affine/Transforms/SuperVectorize.cpp      |  5 +--
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  |  3 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       |  2 +-
 .../Transforms/OneShotModuleBufferize.cpp     |  8 ++--
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  3 +-
 .../GPU/TransformOps/GPUTransformOps.cpp      | 24 ++++-------
 .../GPU/Transforms/AsyncRegionRewriter.cpp    |  5 +--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  6 +--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  6 +--
 .../Linalg/Transforms/ElementwiseToLinalg.cpp |  3 +-
 .../Linalg/Transforms/Vectorization.cpp       |  9 ++--
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       |  2 +-
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp |  3 +-
 mlir/lib/Dialect/Shape/IR/Shape.cpp           |  5 +--
 mlir/lib/Dialect/Traits.cpp                   | 15 ++++---
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  9 ++--
 .../Vector/Transforms/VectorDistribute.cpp    | 41 +++++++++----------
 mlir/lib/IR/AffineMap.cpp                     |  4 +-
 mlir/lib/IR/Operation.cpp                     |  4 +-
 mlir/lib/TableGen/Class.cpp                   |  4 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  5 +--
 .../Target/SPIRV/Serialization/Serializer.cpp |  8 ++--
 .../Transforms/Utils/DialectConversion.cpp    |  6 +--
 .../Transforms/Utils/OneToNTypeConversion.cpp |  3 +-
 28 files changed, 83 insertions(+), 118 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418cb841327..993c09b03c0fde 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                       ConversionPatternRewriter &rewriter,
                                       const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes,
-                    [](Type type) { return isa<VectorType>(type); })) {
+  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
     return rewriter.notifyMatchFailure(op, "expected vector operand");
   }
   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 85fb8a539912f7..399c0450824ee5 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -202,9 +202,7 @@ template <typename ExtOpTy>
 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
     return false;
-  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
-    return isa<vector::ContractionOp>(user);
-  });
+  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
 }
 
 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
@@ -345,15 +343,13 @@ getSliceContract(Operation *op,
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                              bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
-    return llvm::any_of(op->getResultTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
   };
   BackwardSliceOptions backwardSliceOptions;
   backwardSliceOptions.filter = hasVectorDest;
 
   auto hasVectorSrc = [](Operation *op) {
-    return llvm::any_of(op->getOperandTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
   };
   ForwardSliceOptions forwardSliceOptions;
   forwardSliceOptions.filter = hasVectorSrc;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 61244921bc38ac..69b3d41e17c2d4 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
 
 bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
   // Any memref-typed iteration arguments are treated as serializing.
-  if (llvm::any_of(forOp.getResultTypes(),
-                   [](Type type) { return isa<BaseMemRefType>(type); }))
+  if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
     return false;
 
   // Collect all load and store ops in loop nest rooted at 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 46c7871f40232f..71e9648a5e00fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
 }
 
 static NestedPattern &vectorTransferPattern() {
-  static auto pattern = affine::matcher::Op([](Operation &op) {
-    return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
-  });
+  static auto pattern = affine::matcher::Op(
+      llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
   return pattern;
 }
 
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index fb45528ad5e7d1..84ae4b52dcf4e8 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
 
   // Return common loop depth for loads if there are no store ops.
-  if (all_of(targetDstOps,
-             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
+  if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
     return loopDepth;
 
   // Check dependences on all pairs of ops in 'targetDstOps' and store the
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 4cdbbf35dc876b..053ea7935260a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -326,7 +326,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = [](Type type) { return isa<TensorType>(type); };
+    auto isaTensor = llvm::IsaPred<TensorType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 33feea0b956ca0..0a4072605c265f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -67,6 +67,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 
 /// Return "true" if the given function signature has tensor semantics.
 static bool hasTensorSignature(func::FuncOp funcOp) {
-  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
-  return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+                      llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(),
+                      llvm::IsaPred<TensorType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f4a9dc3ca509c8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
     }
   }
 
-  if (llvm::any_of(getResultTypes(),
-                   [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
     return emitOpError() << "cannot return array type";
   }
 
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index fc3a4375694588..b584f63f16e0aa 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                                  "scf.forall op requires a mapping attribute");
   }
 
-  bool hasBlockMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUBlockMappingAttr>(attr);
-      });
-  bool hasWarpgroupMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpgroupMappingAttr>(attr);
-      });
-  bool hasWarpMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpMappingAttr>(attr);
-      });
-  bool hasThreadMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUThreadMappingAttr>(attr);
-      });
+  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
+                                      llvm::IsaPred<GPUBlockMappingAttr>);
+  bool hasWarpgroupMapping = llvm::any_of(
+      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
+  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
+                                     llvm::IsaPred<GPUWarpMappingAttr>);
+  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
+                                       llvm::IsaPred<GPUThreadMappingAttr>);
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasWarpgroupMapping ? 1 : 0;
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 40903f199afddd..b2fa3a99c53fc3 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -232,9 +232,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
   // control flow code.
   static bool areAllUsersExecuteOrAwait(Value token) {
     return !token.use_empty() &&
-           llvm::all_of(token.getUsers(), [](Operation *user) {
-             return isa<async::ExecuteOp, async::AwaitOp>(user);
-           });
+           llvm::all_of(token.getUsers(),
+                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
   }
 
   // Add the `asyncToken` as dependency as needed after `op`.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3ba6ac6ccc8142..e5c19a916392e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
   if (!resultType)
     return success();
 
-  auto isVector = [](Type type) {
-    return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
-        type);
-  };
+  auto isVector =
+      llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
 
   // Due to bitcast requiring both operands to be of the same size, it is not
   // possible for only one of the two to be a pointer of vectors.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6954eee93efd14..2d7219fef87c64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes,
-                      [](Type t) { return llvm::isa<ShapedType>(t); }));
+  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
 
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
       resultTensorTypes.value_or(TypeRange());
   if (!resultTensorTypes)
     copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
-            [](Type type) { return llvm::isa<RankedTensorType>(type); });
+            llvm::IsaPred<RankedTensorType>);
 
   state.addOperands(inputs);
   state.addOperands(outputs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 5508aaf9d87537..28d6752fc2d388 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
 
   // TODO: The conversion pattern can be made to work for `any_of` here, but
   // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(),
-                      [](Type type) { return isa<RankedTensorType>(type); });
+  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c74ab1e6448bec..25785653a71675 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3537,15 +3537,14 @@ struct Conv1DGenerator
   // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
   // must be block arguments or extension of block arguments.
   bool setOperKind(Operation *reduceOp) {
-    int numBlockArguments = llvm::count_if(
-        reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
+    int numBlockArguments =
+        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
       // Otherwise, if it can be pooling.
-      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
-        return !isa<BlockArgument>(v);
-      });
+      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                         llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
         oper = Pool;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c09a3403f9a3e3..9ba96e4be7d1fc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
 }
 
 static bool isComputeOperation(Operation *op) {
-  return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
+  return isa<acc::ParallelOp, acc::LoopOp>(op);
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d532d466334a56..2ff3efdc96a7f8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
   if (getMatrixOperands()) {
     Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
                            typeC.getElementType()};
-    if (!llvm::all_of(elementTypes,
-                      [](Type ty) { return isa<IntegerType>(ty); })) {
+    if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
       return emitOpError("Matrix Operands require all matrix element types to "
                          "be Integer Types");
     }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f5a3717f815de5..58c3f4c334577c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
 }
 
 static bool isErrorPropagationPossible(TypeRange operandTypes) {
-  return llvm::any_of(operandTypes, [](Type ty) {
-    return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
-  });
+  return llvm::any_of(operandTypes,
+                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
 }
 
 static LogicalResult verifySizeOrIndexOp(Operation *op) {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index d4e0f8a3137053..2efc157ce79617 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 /// Returns a tuple corresponding to whether range has tensor or vector type.
 template <typename iterator_range>
 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
-  return std::make_tuple(
-      llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
-      llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
+  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
+          llvm::any_of(types, llvm::IsaPred<VectorType>)};
 }
 
 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
   };
   if (inferred.size() != existing.size())
     return false;
-  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
     if (!isCompatible(inferredDim, existingDim))
       return false;
   return true;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
        std::get<1>(resultsHasTensorVectorType)))
     return op->emitError("cannot broadcast vector with tensor");
 
-  auto rankedOperands = make_filter_range(
-      op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedOperands =
+      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 
   // If all operands are unranked, then all result shapes are possible.
   if (rankedOperands.empty())
@@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
       return op->emitOpError("operands don't have broadcast-compatible shapes");
   }
 
-  auto rankedResults = make_filter_range(
-      op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedResults =
+      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
 
   // If all of the results are unranked then no further verification.
   if (rankedResults.empty())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 578b2492bbab46..c8d06ba157b904 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   assert(outputs.size() == 1 && "expected one output");
   return llvm::all_of(
       std::initializer_list<Type>{inputs.front(), outputs.front()},
-      [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
+      llvm::IsaPred<transform::TransformHandleTypeInterface>);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e566bfacf37984..3e6425879cc67f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(
 
     AffineMap resMap = op.getIndexingMapsArray()[2];
     auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
-                                     /*symCount=*/0, extents, ctx);
+                                     /*symbolCount=*/0, extents, ctx);
     // Compose the resMap with the extentsMap, which is a constant map.
     AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
-    assert(
-        llvm::all_of(expectedMap.getResults(),
-                     [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
-        "expected constant extent along all dimensions.");
+    assert(llvm::all_of(expectedMap.getResults(),
+                        llvm::IsaPred<AffineConstantExpr>) &&
+           "expected constant extent along all dimensions.");
     // Extract the expected shape and build the type.
     auto expectedShape = llvm::to_vector<4>(
         llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b3ab4a916121e3..a67e03e85f7145 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -598,9 +598,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
 
     // Do not process warp ops that contain only TransferWriteOps.
-    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
-          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
-        }))
+    if (llvm::all_of(warpOp.getOps(),
+                     llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
       return failure();
 
     SmallVector<Value> yieldValues = {writeOp.getVector()};
@@ -746,8 +745,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
     if (!yieldOperand)
       return failure();
     auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
@@ -1060,8 +1059,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1097,8 +1096,8 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
 
@@ -1156,8 +1155,8 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
     if (!yieldOperand)
       return failure();
 
@@ -1222,8 +1221,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1325,9 +1324,8 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
         warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
-      return isa<vector::ExtractElementOp>(op);
-    });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1422,8 +1420,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1503,8 +1501,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1808,8 +1805,8 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
     if (!yieldOperand)
       return failure();
 
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 00a0f05b633303..6cdc2682753fc7 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -359,9 +359,7 @@ bool AffineMap::isSingleConstant() const {
 }
 
 bool AffineMap::isConstant() const {
-  return llvm::all_of(getResults(), [](AffineExpr expr) {
-    return isa<AffineConstantExpr>(expr);
-  });
+  return llvm::all_of(getResults(), llvm::IsaPred<AffineConstantExpr>);
 }
 
 int64_t AffineMap::getSingleConstantResult() const {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index d6d59837d48ac8..ca5ff9f72e3e29 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1288,9 +1288,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
-  auto isMappableType = [](Type type) {
-    return llvm::isa<VectorType, TensorType>(type);
-  };
+  auto isMappableType = llvm::IsaPred<VectorType, TensorType>;
   auto resultMappableTypes = llvm::to_vector<1>(
       llvm::make_filter_range(op->getResultTypes(), isMappableType));
   auto operandMappableTypes = llvm::to_vector<2>(
diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index 9092adcc627c08..fedf64fd96b0d4 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -369,9 +369,7 @@ void Class::finalize() {
 
 Visibility Class::getLastVisibilityDecl() const {
   auto reverseDecls = llvm::reverse(declarations);
-  auto it = llvm::find_if(reverseDecls, [](auto &decl) {
-    return isa<VisibilityDeclaration>(decl);
-  });
+  auto it = llvm::find_if(reverseDecls, llvm::IsaPred<VisibilityDeclaration>);
   return it == reverseDecls.end()
              ? (isStruct ? Visibility::Public : Visibility::Private)
              : cast<VisibilityDeclaration>(**it).getVisibility();
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..0b07b4b06dfc71 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1000,8 +1000,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
         "with multiple blocks needs variables declared at top");
   }
 
-  if (llvm::any_of(functionOp.getResultTypes(),
-                   [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
     return functionOp.emitOpError() << "cannot emit array type as result type";
   }
 
@@ -1576,7 +1575,7 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
 }
 
 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
-  if (llvm::any_of(types, [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
     return emitError(loc, "cannot emit tuple of array type");
   }
   os << "std::tuple<";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4a4e878d8af915..9a74ac115e9555 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1031,9 +1031,9 @@ Serializer::processBlock(Block *block, bool omitLabel,
   // into multiple basic blocks. If that's the case, we need to emit the merge
   // right now and then create new blocks for further serialization of the ops
   // in this block.
-  if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
-        return isa<spirv::LoopOp, spirv::SelectionOp>(op);
-      })) {
+  if (emitMerge &&
+      llvm::any_of(block->getOperations(),
+                   llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
     if (failed(emitMerge()))
       return failure();
     emitMerge = nullptr;
@@ -1045,7 +1045,7 @@ Serializer::processBlock(Block *block, bool omitLabel,
   }
 
   // Process each op in this block except the terminator.
-  for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
+  for (Operation &op : llvm::drop_end(*block)) {
     if (failed(processOperation(&op)))
       return failure();
   }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3c72c8789e8ec5..3d309f15b3cc3a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2827,9 +2827,9 @@ static void computeNecessaryMaterializations(
     }
 
     // Check to see if this is an argument materialization.
-    auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
-    if (llvm::any_of(op->getOperands(), isBlockArg) ||
-        llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
+    if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) ||
+        llvm::any_of(inverseMapping[op->getResult(0)],
+                     llvm::IsaPred<BlockArgument>)) {
       mat->setMaterializationKind(MaterializationKind::Argument);
     }
 
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 7cb957d5ec29ea..fef9d8eb0fef74 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -392,8 +392,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
         // Argument materialization.
         assert(castKind == getCastKindName(CastKind::Argument) &&
                "unexpected value of cast kind attribute");
-        assert(llvm::all_of(operands,
-                            [&](Value v) { return isa<BlockArgument>(v); }));
+        assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
         maybeResult = typeConverter.materializeArgumentConversion(
             rewriter, castOp->getLoc(), resultTypes.front(),
             castOp.getOperands());



More information about the Mlir-commits mailing list