[Mlir-commits] [mlir] ee394e6 - [MLIR] Add variadic isa<> for Type, Value, and Attribute

Rahul Joshi llvmlistbot at llvm.org
Mon Jun 29 15:05:20 PDT 2020


Author: Rahul Joshi
Date: 2020-06-29T15:04:48-07:00
New Revision: ee394e6842733a38ee0953d8ee018547ecbef8fd

URL: https://github.com/llvm/llvm-project/commit/ee394e6842733a38ee0953d8ee018547ecbef8fd
DIFF: https://github.com/llvm/llvm-project/commit/ee394e6842733a38ee0953d8ee018547ecbef8fd.diff

LOG: [MLIR] Add variadic isa<> for Type, Value, and Attribute

- Also adopt variadic llvm::isa<> in more places.
- Fixes https://bugs.llvm.org/show_bug.cgi?id=46445

Differential Revision: https://reviews.llvm.org/D82769
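
The new overload expands recursively: isa<First, Second, Rest...>() is
isa<First>() || isa<Second, Rest...>(), bottoming out in the existing
single-type isa<U>(). A minimal self-contained sketch of the pattern
(illustration only; the TypeLike class below is a hypothetical stand-in for
the real Type/Attribute/Value):

    #include <iostream>

    struct TypeLike {
      int kind; // stand-in for the real type discriminator

      // Base case: mirrors the existing unary isa<>.
      template <typename U> bool isa() const { return U::classof(*this); }

      // Variadic overload: true if this is any of the listed types. It is
      // only viable when two or more template arguments are supplied, so
      // it never conflicts with the unary form.
      template <typename First, typename Second, typename... Rest>
      bool isa() const {
        return isa<First>() || isa<Second, Rest...>();
      }
    };

    struct IntegerLike { static bool classof(TypeLike t) { return t.kind == 0; } };
    struct FloatLike   { static bool classof(TypeLike t) { return t.kind == 1; } };
    struct IndexLike   { static bool classof(TypeLike t) { return t.kind == 2; } };

    int main() {
      TypeLike t{1};
      // Previously: t.isa<IntegerLike>() || t.isa<FloatLike>() || t.isa<IndexLike>()
      std::cout << t.isa<IntegerLike, FloatLike, IndexLike>() << '\n'; // prints 1
      std::cout << t.isa<IndexLike>() << '\n';                         // prints 0
    }

One practical consequence visible in the diff below: inside assert() the
call needs an extra set of parentheses, e.g.
assert((type.isa<RankedTensorType, VectorType>()) && "..."), because the
comma in the template argument list would otherwise be parsed as a macro
argument separator.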

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-7.md
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/EDSC/Builders.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/Matchers.h
    mlir/include/mlir/IR/StandardTypes.h
    mlir/include/mlir/IR/Types.h
    mlir/include/mlir/IR/Value.h
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Dialect/Affine/EDSC/Builders.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Quant/IR/QuantOps.cpp
    mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Traits.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/TypeParser.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
    mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index 0dce6d25c904..733e22c5b0a5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -287,8 +287,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
       return nullptr;
 
     // Check that the type is either a TensorType or another StructType.
-    if (!elementType.isa<mlir::TensorType>() &&
-        !elementType.isa<StructType>()) {
+    if (!elementType.isa<mlir::TensorType, StructType>()) {
       parser.emitError(typeLoc, "element type for a struct must either "
                                 "be a TensorType or a StructType, got: ")
           << elementType;

diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 867da9c202cc..fc7bf2a2375c 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -510,8 +510,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
       return nullptr;
 
     // Check that the type is either a TensorType or another StructType.
-    if (!elementType.isa<mlir::TensorType>() &&
-        !elementType.isa<StructType>()) {
+    if (!elementType.isa<mlir::TensorType, StructType>()) {
       parser.emitError(typeLoc, "element type for a struct must either "
                                 "be a TensorType or a StructType, got: ")
           << elementType;

diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index 64df2c9fe367..1f21af617e4d 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -139,15 +139,12 @@ struct StructuredIndexed {
 
   StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
       : value(v), exprs(indexings.begin(), indexings.end()) {
-    assert((v.getType().isa<MemRefType>() ||
-            v.getType().isa<RankedTensorType>() ||
-            v.getType().isa<VectorType>()) &&
+    assert((v.getType().isa<MemRefType, RankedTensorType, VectorType>()) &&
            "MemRef, RankedTensor or Vector expected");
   }
   StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
       : type(t), exprs(indexings.begin(), indexings.end()) {
-    assert((t.isa<MemRefType>() || t.isa<RankedTensorType>() ||
-            t.isa<VectorType>()) &&
+    assert((t.isa<MemRefType, RankedTensorType, VectorType>()) &&
            "MemRef, RankedTensor or Vector expected");
   }
 

diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index f9d8efd42272..ea3011f0fdc7 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -85,6 +85,8 @@ class Attribute {
   bool operator!() const { return impl == nullptr; }
 
   template <typename U> bool isa() const;
+  template <typename First, typename Second, typename... Rest>
+  bool isa() const;
   template <typename U> U dyn_cast() const;
   template <typename U> U dyn_cast_or_null() const;
   template <typename U> U cast() const;
@@ -1630,6 +1632,12 @@ template <typename U> bool Attribute::isa() const {
   assert(impl && "isa<> used on a null attribute.");
   return U::classof(*this);
 }
+
+template <typename First, typename Second, typename... Rest>
+bool Attribute::isa() const {
+  return isa<First>() || isa<Second, Rest...>();
+}
+
 template <typename U> U Attribute::dyn_cast() const {
   return isa<U>() ? U(impl) : U(nullptr);
 }

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 0f74f1b9cd43..72e17e1699a0 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -97,9 +97,9 @@ struct constant_int_op_binder {
       return false;
     auto type = op->getResult(0).getType();
 
-    if (type.isa<IntegerType>() || type.isa<IndexType>())
+    if (type.isa<IntegerType, IndexType>())
       return attr_value_binder<IntegerAttr>(bind_value).match(attr);
-    if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
+    if (type.isa<VectorType, RankedTensorType>()) {
       if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
         return attr_value_binder<IntegerAttr>(bind_value)
             .match(splatAttr.getSplatValue());

diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 99ad3b10f6d9..cde43843e8b5 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -357,7 +357,7 @@ class VectorType
   /// Returns true if the given type can be used as an element of a vector type.
   /// In particular, vectors can consist of integer or float primitives.
   static bool isValidElementType(Type t) {
-    return t.isa<IntegerType>() || t.isa<FloatType>();
+    return t.isa<IntegerType, FloatType>();
   }
 
   ArrayRef<int64_t> getShape() const;
@@ -381,9 +381,8 @@ class TensorType : public ShapedType {
     // Note: Non standard/builtin types are allowed to exist within tensor
     // types. Dialects are expected to verify that tensor types have a valid
     // element type within that dialect.
-    return type.isa<ComplexType>() || type.isa<FloatType>() ||
-           type.isa<IntegerType>() || type.isa<OpaqueType>() ||
-           type.isa<VectorType>() || type.isa<IndexType>() ||
+    return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
+                    IndexType>() ||
            (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
   }
 

diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 32ae13f86dc9..60bc04a8708c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -121,6 +121,8 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename U> bool isa() const;
+  template <typename First, typename Second, typename... Rest>
+  bool isa() const;
   template <typename U> U dyn_cast() const;
   template <typename U> U dyn_cast_or_null() const;
   template <typename U> U cast() const;
@@ -271,6 +273,12 @@ template <typename U> bool Type::isa() const {
   assert(impl && "isa<> used on a null type.");
   return U::classof(*this);
 }
+
+template <typename First, typename Second, typename... Rest>
+bool Type::isa() const {
+  return isa<First>() || isa<Second, Rest...>();
+}
+
 template <typename U> U Type::dyn_cast() const {
   return isa<U>() ? U(impl) : U(nullptr);
 }

diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index f5cb16f347ed..c22741ee6cb6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -81,6 +81,12 @@ class Value {
     assert(*this && "isa<> used on a null type.");
     return U::classof(*this);
   }
+
+  template <typename First, typename Second, typename... Rest>
+  bool isa() const {
+    return isa<First>() || isa<Second, Rest...>();
+  }
+
   template <typename U> U dyn_cast() const {
     return isa<U>() ? U(ownerAndKind) : U(nullptr);
   }

diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 8a29fdbfd00b..ab273f8d95d5 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -956,8 +956,7 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
 
   // Walk this 'affine.for' operation to gather all memory regions.
   auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
-    if (!isa<AffineReadOpInterface>(opInst) &&
-        !isa<AffineWriteOpInterface>(opInst)) {
+    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
       // Neither a load nor a store op.
       return WalkResult::advance();
     }
@@ -1017,11 +1016,9 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
   // Collect all load and store ops in loop nest rooted at 'forOp'.
   SmallVector<Operation *, 8> loadAndStoreOpInsts;
   auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
-    if (isa<AffineReadOpInterface>(opInst) ||
-        isa<AffineWriteOpInterface>(opInst))
+    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
       loadAndStoreOpInsts.push_back(opInst);
-    else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
-             !isa<AffineIfOp>(opInst) &&
+    else if (!isa<AffineForOp, AffineTerminatorOp, AffineIfOp>(opInst) &&
              !MemoryEffectOpInterface::hasNoEffect(opInst))
       return WalkResult::interrupt();
 

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 9b651bb8b80a..b6900d13094c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -302,7 +302,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
     auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
     if (!converted)
       return {};
-    if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
+    if (t.isa<MemRefType, UnrankedMemRefType>())
       converted = converted.getPointerTo();
     inputs.push_back(converted);
   }
@@ -1044,7 +1044,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
       FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
     argsInfo.reserve(type.getNumInputs());
     for (auto en : llvm::enumerate(type.getInputs())) {
-      if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
+      if (en.value().isa<MemRefType, UnrankedMemRefType>())
         argsInfo.push_back({en.index(), en.value()});
     }
   }

diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 45992c888d72..aac275548891 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -518,7 +518,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     return failure();
 
   // std.constant should only have vector or tensor types.
-  assert(srcType.isa<VectorType>() || srcType.isa<RankedTensorType>());
+  assert((srcType.isa<VectorType, RankedTensorType>()));
 
   auto dstType = typeConverter.convertType(srcType);
   if (!dstType)

diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
index e5bf1c015e02..beeaeaa9cf27 100644
--- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
@@ -117,7 +117,7 @@ static Value createBinaryHandle(
     return ValueBuilder<IOp>(lhs, rhs);
   } else if (thisType.isa<FloatType>()) {
     return ValueBuilder<FOp>(lhs, rhs);
-  } else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
+  } else if (thisType.isa<VectorType, TensorType>()) {
     auto aggregateType = thisType.cast<ShapedType>();
     if (aggregateType.getElementType().isSignlessInteger())
       return ValueBuilder<IOp>(lhs, rhs);

diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 89cbca0444f5..ea66fcb3b090 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -218,7 +218,7 @@ void AffineDataCopyGeneration::runOnFunction() {
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
         promoteIfSingleIteration(forOp);
-      else if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+      else if (isa<AffineLoadOp, AffineStoreOp>(op))
         copyOps.push_back(op);
     });
 

diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index e060aac03e44..aaa21104e1fd 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -80,7 +80,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar,
     // If the body of a predicated region has a for loop, we don't hoist the
     // 'affine.if'.
     return false;
-  } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+  } else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
     // TODO(asabne): Support DMA ops.
     return false;
   } else if (!isa<ConstantOp>(op)) {
@@ -91,7 +91,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar,
       for (auto *user : memref.getUsers()) {
         // If this memref has a user that is a DMA, give up because these
         // operations write to this memref.
-        if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+        if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
           return false;
         }
         // If the memref used by the load/store is used in a store elsewhere in

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 48231642eae7..e37146e73954 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -923,11 +923,11 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
     return nullptr;
 
   // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (isa<GenericOp>(consumer) || isa<IndexedGenericOp>(consumer)) {
+  if (isa<GenericOp, IndexedGenericOp>(consumer)) {
     auto linalgOpConsumer = cast<LinalgOp>(consumer);
     if (!linalgOpConsumer.hasTensorSemantics())
       return nullptr;
-    if (isa<GenericOp>(producer) || isa<IndexedGenericOp>(producer)) {
+    if (isa<GenericOp, IndexedGenericOp>(producer)) {
       auto linalgOpProducer = cast<LinalgOp>(producer);
       if (linalgOpProducer.hasTensorSemantics())
         return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,

diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index b0dc1fa10679..07f881fbc52c 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -46,7 +46,7 @@ OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
   if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
     Type spec = typeAttr.getValue();
-    if (spec.isa<TensorType>() || spec.isa<VectorType>())
+    if (spec.isa<TensorType, VectorType>())
       return false;
 
     // The spec should be either a quantized type which is compatible to the

diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index 2ff23123b474..88eb314a852b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -69,8 +69,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
   }
 
   // Is the constant value a type expressed in a way that we support?
-  if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
-      !value.isa<SparseElementsAttr>()) {
+  if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
     return failure();
   }
 

diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 8368a8e1857b..1ac6a1e6d75b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1292,7 +1292,7 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
     return failure();
 
   Type type = value.getType();
-  if (type.isa<NoneType>() || type.isa<TensorType>()) {
+  if (type.isa<NoneType, TensorType>()) {
     if (parser.parseColonType(type))
       return failure();
   }
@@ -1827,8 +1827,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
     // TODO: Currently only variable initialization with specialization
     // constants and other variables is supported. They could be normal
     // constants in the module scope as well.
-    if (!initOp || !(isa<spirv::GlobalVariableOp>(initOp) ||
-                     isa<spirv::SpecConstantOp>(initOp))) {
+    if (!initOp ||
+        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
       return varOp.emitOpError("initializer must be result of a "
                                "spv.specConstant or spv.globalVariable op");
     }
@@ -2093,8 +2093,7 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
 
 static LogicalResult verify(spirv::MergeOp mergeOp) {
   auto *parentOp = mergeOp.getParentOp();
-  if (!parentOp ||
-      (!isa<spirv::SelectionOp>(parentOp) && !isa<spirv::LoopOp>(parentOp)))
+  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
     return mergeOp.emitOpError(
         "expected parent op to be 'spv.selection' or 'spv.loop'");
 
@@ -2620,9 +2619,9 @@ static LogicalResult verify(spirv::VariableOp varOp) {
     // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
     // a global (module scope) OpVariable instruction".
     auto *initOp = varOp.getOperand(0).getDefiningOp();
-    if (!initOp || !(isa<spirv::ConstantOp>(initOp) ||    // for normal constant
-                     isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
-                     isa<spirv::AddressOfOp>(initOp)))
+    if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
+                        spirv::ReferenceOfOp, // for spec constant
+                        spirv::AddressOfOp>(initOp))
       return varOp.emitOpError("initializer must be the result of a "
                                "constant or spv.globalVariable op");
   }

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 8df99378868b..b81f7f4c7387 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1176,8 +1176,7 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
   if (value.getType() != type)
     return false;
   // Finally, check that the attribute kind is handled.
-  return value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
-         value.isa<ElementsAttr>() || value.isa<UnitAttr>();
+  return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
 }
 
 void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
@@ -2103,7 +2102,7 @@ static LogicalResult verify(SelectOp op) {
   // If the result type is a vector or tensor, the type can be a mask with the
   // same elements.
   Type resultType = op.getType();
-  if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
+  if (!resultType.isa<TensorType, VectorType>())
     return op.emitOpError()
            << "expected condition to be a signless i1, but got "
            << conditionType;
@@ -2222,8 +2221,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "splat takes one operand");
 
   auto constOperand = operands.front();
-  if (!constOperand ||
-      (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+  if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
     return {};
 
   auto shapedType = getType().cast<ShapedType>();

diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index dc721adc7472..c974e2fc097b 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -107,7 +107,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
   // Returns the type kind if the given type is a vector or ranked tensor type.
   // Returns llvm::None otherwise.
   auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
-    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
+    if (type.isa<VectorType, RankedTensorType>())
       return static_cast<StandardTypes::Kind>(type.getKind());
     return llvm::None;
   };

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index bc929bcb0c74..a7613fa4ad33 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -337,7 +337,7 @@ uint64_t IntegerAttr::getUInt() const {
 }
 
 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
-  if (type.isa<IntegerType>() || type.isa<IndexType>())
+  if (type.isa<IntegerType, IndexType>())
     return success();
   return emitError(loc, "expected integer or index type");
 }
@@ -1090,7 +1090,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                                                    ArrayRef<char> data,
                                                    bool isSplat) {
-  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+  assert((type.isa<RankedTensorType, VectorType>()) &&
          "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
   return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
@@ -1247,7 +1247,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
                                            DenseElementsAttr values) {
   assert(indices.getType().getElementType().isInteger(64) &&
          "expected sparse indices to be 64-bit integer values");
-  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+  assert((type.isa<RankedTensorType, VectorType>()) &&
          "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,

diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 117cd6810968..c76ff30d6c79 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -72,11 +72,11 @@ bool Type::isUnsignedInteger(unsigned width) {
 }
 
 bool Type::isSignlessIntOrIndex() {
-  return isa<IndexType>() || isSignlessInteger();
+  return isSignlessInteger() || isa<IndexType>();
 }
 
 bool Type::isSignlessIntOrIndexOrFloat() {
-  return isa<IndexType>() || isSignlessInteger() || isa<FloatType>();
+  return isSignlessInteger() || isa<IndexType, FloatType>();
 }
 
 bool Type::isSignlessIntOrFloat() {
@@ -85,7 +85,7 @@ bool Type::isSignlessIntOrFloat() {
 
 bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
 
-bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
 
 bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
 
@@ -200,7 +200,7 @@ int64_t ShapedType::getNumElements() const {
 int64_t ShapedType::getRank() const { return getShape().size(); }
 
 bool ShapedType::hasRank() const {
-  return !isa<UnrankedMemRefType>() && !isa<UnrankedTensorType>();
+  return !isa<UnrankedMemRefType, UnrankedTensorType>();
 }
 
 int64_t ShapedType::getDimSize(unsigned idx) const {
@@ -233,7 +233,7 @@ int64_t ShapedType::getSizeInBits() const {
   // Tensors can have vectors and other tensors as elements, other shaped types
   // cannot.
   assert(isa<TensorType>() && "unsupported element type");
-  assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
+  assert((elementType.isa<VectorType, TensorType>()) &&
          "unsupported tensor element type");
   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
 }
@@ -398,8 +398,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
   auto *context = elementType.getContext();
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
-      !elementType.isa<ComplexType>())
+  if (!elementType.isIntOrFloat() &&
+      !elementType.isa<VectorType, ComplexType>())
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();
 
@@ -476,8 +476,8 @@ LogicalResult
 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                  unsigned memorySpace) {
   // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
-      !elementType.isa<ComplexType>())
+  if (!elementType.isIntOrFloat() &&
+      !elementType.isa<VectorType, ComplexType>())
     return emitError(loc, "invalid memref element type");
   return success();
 }

diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 16e548f4430c..b064d83b5faa 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -397,7 +397,7 @@ static WalkResult walkSymbolRefs(
     for (Attribute attr : llvm::drop_begin(attrRange, index)) {
       /// Check for a nested container attribute, these will also need to be
       /// walked.
-      if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) {
+      if (attr.isa<ArrayAttr, DictionaryAttr>()) {
         attrWorklist.push_back(attr);
         curAccessChain.push_back(-1);
         return WalkResult::advance();

diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index ebbb1293f19d..609d7ad3f8d2 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -345,7 +345,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
   }
 
-  if (!type.isa<IntegerType>() && !type.isa<IndexType>())
+  if (!type.isa<IntegerType, IndexType>())
     return emitError(loc, "integer literal not valid for specified type"),
            nullptr;
 
@@ -823,7 +823,7 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
       return nullptr;
   }
 
-  if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
+  if (!type.isa<RankedTensorType, VectorType>()) {
     emitError("elements literal must be a ranked tensor or vector type");
     return nullptr;
   }

diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 68d381f968ad..9d8d198aa1c8 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -217,8 +217,8 @@ Type Parser::parseMemRefType() {
     return nullptr;
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
-      !elementType.isa<ComplexType>())
+  if (!elementType.isIntOrFloat() &&
+      !elementType.isa<VectorType, ComplexType>())
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 
   // Parse semi-affine-map-composition.

diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 633fe5e19703..075ce9f6089f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -778,8 +778,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
 LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
   for (Operation &o : getModuleBody(m).getOperations())
-    if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
-        !o.isKnownTerminator())
+    if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) && !o.isKnownTerminator())
       return o.emitOpError("unsupported module-level operation");
   return success();
 }

diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 2a735a58a8b0..18fc872cdf7f 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -294,7 +294,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
     unsigned count = 0;
     stats->opCountMap[childForOp] = 0;
     for (auto &op : *forOp.getBody()) {
-      if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
+      if (!isa<AffineForOp, AffineIfOp>(op))
         ++count;
     }
     stats->opCountMap[childForOp] = count;

diff --git a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
index 34db53b6ce1e..7a67bef93bc2 100644
--- a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
+++ b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
@@ -103,7 +103,7 @@ void TestMemRefDependenceCheck::runOnFunction() {
   // Collect the loads and stores within the function.
   loadsAndStores.clear();
   getFunction().walk([&](Operation *op) {
-    if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+    if (isa<AffineLoadOp, AffineStoreOp>(op))
       loadsAndStores.push_back(op);
   });
 
