[Mlir-commits] [mlir] ee394e6 - [MLIR] Add variadic isa<> for Type, Value, and Attribute
Rahul Joshi
llvmlistbot at llvm.org
Mon Jun 29 15:05:20 PDT 2020
Author: Rahul Joshi
Date: 2020-06-29T15:04:48-07:00
New Revision: ee394e6842733a38ee0953d8ee018547ecbef8fd
URL: https://github.com/llvm/llvm-project/commit/ee394e6842733a38ee0953d8ee018547ecbef8fd
DIFF: https://github.com/llvm/llvm-project/commit/ee394e6842733a38ee0953d8ee018547ecbef8fd.diff
LOG: [MLIR] Add variadic isa<> for Type, Value, and Attribute
- Also adopt variadic llvm::isa<> in more places.
- Fixes https://bugs.llvm.org/show_bug.cgi?id=46445
Differential Revision: https://reviews.llvm.org/D82769
Added:
Modified:
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/EDSC/Builders.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/Types.h
mlir/include/mlir/IR/Value.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/Affine/EDSC/Builders.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Quant/IR/QuantOps.cpp
mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Traits.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/TypeParser.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index 0dce6d25c904..733e22c5b0a5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -287,8 +287,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
return nullptr;
// Check that the type is either a TensorType or another StructType.
- if (!elementType.isa<mlir::TensorType>() &&
- !elementType.isa<StructType>()) {
+ if (!elementType.isa<mlir::TensorType, StructType>()) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 867da9c202cc..fc7bf2a2375c 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -510,8 +510,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
return nullptr;
// Check that the type is either a TensorType or another StructType.
- if (!elementType.isa<mlir::TensorType>() &&
- !elementType.isa<StructType>()) {
+ if (!elementType.isa<mlir::TensorType, StructType>()) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index 64df2c9fe367..1f21af617e4d 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -139,15 +139,12 @@ struct StructuredIndexed {
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
: value(v), exprs(indexings.begin(), indexings.end()) {
- assert((v.getType().isa<MemRefType>() ||
- v.getType().isa<RankedTensorType>() ||
- v.getType().isa<VectorType>()) &&
+ assert((v.getType().isa<MemRefType, RankedTensorType, VectorType>()) &&
"MemRef, RankedTensor or Vector expected");
}
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
: type(t), exprs(indexings.begin(), indexings.end()) {
- assert((t.isa<MemRefType>() || t.isa<RankedTensorType>() ||
- t.isa<VectorType>()) &&
+ assert((t.isa<MemRefType, RankedTensorType, VectorType>()) &&
"MemRef, RankedTensor or Vector expected");
}
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index f9d8efd42272..ea3011f0fdc7 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -85,6 +85,8 @@ class Attribute {
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
+ template <typename First, typename Second, typename... Rest>
+ bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
@@ -1630,6 +1632,12 @@ template <typename U> bool Attribute::isa() const {
assert(impl && "isa<> used on a null attribute.");
return U::classof(*this);
}
+
+template <typename First, typename Second, typename... Rest>
+bool Attribute::isa() const {
+ return isa<First>() || isa<Second, Rest...>();
+}
+
template <typename U> U Attribute::dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 0f74f1b9cd43..72e17e1699a0 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -97,9 +97,9 @@ struct constant_int_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isa<IntegerType>() || type.isa<IndexType>())
+ if (type.isa<IntegerType, IndexType>())
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
+ if (type.isa<VectorType, RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getSplatValue());
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 99ad3b10f6d9..cde43843e8b5 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -357,7 +357,7 @@ class VectorType
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
static bool isValidElementType(Type t) {
- return t.isa<IntegerType>() || t.isa<FloatType>();
+ return t.isa<IntegerType, FloatType>();
}
ArrayRef<int64_t> getShape() const;
@@ -381,9 +381,8 @@ class TensorType : public ShapedType {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
- return type.isa<ComplexType>() || type.isa<FloatType>() ||
- type.isa<IntegerType>() || type.isa<OpaqueType>() ||
- type.isa<VectorType>() || type.isa<IndexType>() ||
+ return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
+ IndexType>() ||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 32ae13f86dc9..60bc04a8708c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -121,6 +121,8 @@ class Type {
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
+ template <typename First, typename Second, typename... Rest>
+ bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
@@ -271,6 +273,12 @@ template <typename U> bool Type::isa() const {
assert(impl && "isa<> used on a null type.");
return U::classof(*this);
}
+
+template <typename First, typename Second, typename... Rest>
+bool Type::isa() const {
+ return isa<First>() || isa<Second, Rest...>();
+}
+
template <typename U> U Type::dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index f5cb16f347ed..c22741ee6cb6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -81,6 +81,12 @@ class Value {
assert(*this && "isa<> used on a null type.");
return U::classof(*this);
}
+
+ template <typename First, typename Second, typename... Rest>
+ bool isa() const {
+ return isa<First>() || isa<Second, Rest...>();
+ }
+
template <typename U> U dyn_cast() const {
return isa<U>() ? U(ownerAndKind) : U(nullptr);
}
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 8a29fdbfd00b..ab273f8d95d5 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -956,8 +956,7 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
// Walk this 'affine.for' operation to gather all memory regions.
auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
- if (!isa<AffineReadOpInterface>(opInst) &&
- !isa<AffineWriteOpInterface>(opInst)) {
+ if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
// Neither load nor a store op.
return WalkResult::advance();
}
@@ -1017,11 +1016,9 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Operation *, 8> loadAndStoreOpInsts;
auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
- if (isa<AffineReadOpInterface>(opInst) ||
- isa<AffineWriteOpInterface>(opInst))
+ if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
loadAndStoreOpInsts.push_back(opInst);
- else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
- !isa<AffineIfOp>(opInst) &&
+ else if (!isa<AffineForOp, AffineTerminatorOp, AffineIfOp>(opInst) &&
!MemoryEffectOpInterface::hasNoEffect(opInst))
return WalkResult::interrupt();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 9b651bb8b80a..b6900d13094c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -302,7 +302,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
if (!converted)
return {};
- if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
+ if (t.isa<MemRefType, UnrankedMemRefType>())
converted = converted.getPointerTo();
inputs.push_back(converted);
}
@@ -1044,7 +1044,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
argsInfo.reserve(type.getNumInputs());
for (auto en : llvm::enumerate(type.getInputs())) {
- if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
+ if (en.value().isa<MemRefType, UnrankedMemRefType>())
argsInfo.push_back({en.index(), en.value()});
}
}
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 45992c888d72..aac275548891 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -518,7 +518,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
return failure();
// std.constant should only have vector or tenor types.
- assert(srcType.isa<VectorType>() || srcType.isa<RankedTensorType>());
+ assert((srcType.isa<VectorType, RankedTensorType>()));
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
index e5bf1c015e02..beeaeaa9cf27 100644
--- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
@@ -117,7 +117,7 @@ static Value createBinaryHandle(
return ValueBuilder<IOp>(lhs, rhs);
} else if (thisType.isa<FloatType>()) {
return ValueBuilder<FOp>(lhs, rhs);
- } else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
+ } else if (thisType.isa<VectorType, TensorType>()) {
auto aggregateType = thisType.cast<ShapedType>();
if (aggregateType.getElementType().isSignlessInteger())
return ValueBuilder<IOp>(lhs, rhs);
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 89cbca0444f5..ea66fcb3b090 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -218,7 +218,7 @@ void AffineDataCopyGeneration::runOnFunction() {
nest->walk([&](Operation *op) {
if (auto forOp = dyn_cast<AffineForOp>(op))
promoteIfSingleIteration(forOp);
- else if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+ else if (isa<AffineLoadOp, AffineStoreOp>(op))
copyOps.push_back(op);
});
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index e060aac03e44..aaa21104e1fd 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -80,7 +80,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar,
// If the body of a predicated region has a for loop, we don't hoist the
// 'affine.if'.
return false;
- } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+ } else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
// TODO(asabne): Support DMA ops.
return false;
} else if (!isa<ConstantOp>(op)) {
@@ -91,7 +91,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar,
for (auto *user : memref.getUsers()) {
// If this memref has a user that is a DMA, give up because these
// operations write to this memref.
- if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+ if (isa<AffineDmaStartOp, AffineDmaWaitOp>(user)) {
return false;
}
// If the memref used by the load/store is used in a store elsewhere in
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 48231642eae7..e37146e73954 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -923,11 +923,11 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
return nullptr;
// Fuse when consumer is GenericOp or IndexedGenericOp.
- if (isa<GenericOp>(consumer) || isa<IndexedGenericOp>(consumer)) {
+ if (isa<GenericOp, IndexedGenericOp>(consumer)) {
auto linalgOpConsumer = cast<LinalgOp>(consumer);
if (!linalgOpConsumer.hasTensorSemantics())
return nullptr;
- if (isa<GenericOp>(producer) || isa<IndexedGenericOp>(producer)) {
+ if (isa<GenericOp, IndexedGenericOp>(producer)) {
auto linalgOpProducer = cast<LinalgOp>(producer);
if (linalgOpProducer.hasTensorSemantics())
return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
index b0dc1fa10679..07f881fbc52c 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp
@@ -46,7 +46,7 @@ OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
Type spec = typeAttr.getValue();
- if (spec.isa<TensorType>() || spec.isa<VectorType>())
+ if (spec.isa<TensorType, VectorType>())
return false;
// The spec should be either a quantized type which is compatible to the
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index 2ff23123b474..88eb314a852b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -69,8 +69,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
}
// Is the constant value a type expressed in a way that we support?
- if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
- !value.isa<SparseElementsAttr>()) {
+ if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
return failure();
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 8368a8e1857b..1ac6a1e6d75b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1292,7 +1292,7 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
return failure();
Type type = value.getType();
- if (type.isa<NoneType>() || type.isa<TensorType>()) {
+ if (type.isa<NoneType, TensorType>()) {
if (parser.parseColonType(type))
return failure();
}
@@ -1827,8 +1827,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
- if (!initOp || !(isa<spirv::GlobalVariableOp>(initOp) ||
- isa<spirv::SpecConstantOp>(initOp))) {
+ if (!initOp ||
+ !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
return varOp.emitOpError("initializer must be result of a "
"spv.specConstant or spv.globalVariable op");
}
@@ -2093,8 +2093,7 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
static LogicalResult verify(spirv::MergeOp mergeOp) {
auto *parentOp = mergeOp.getParentOp();
- if (!parentOp ||
- (!isa<spirv::SelectionOp>(parentOp) && !isa<spirv::LoopOp>(parentOp)))
+ if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
return mergeOp.emitOpError(
"expected parent op to be 'spv.selection' or 'spv.loop'");
@@ -2620,9 +2619,9 @@ static LogicalResult verify(spirv::VariableOp varOp) {
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
// a global (module scope) OpVariable instruction".
auto *initOp = varOp.getOperand(0).getDefiningOp();
- if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant
- isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
- isa<spirv::AddressOfOp>(initOp)))
+ if (!initOp || !isa<spirv::ConstantOp, // for normal constant
+ spirv::ReferenceOfOp, // for spec constant
+ spirv::AddressOfOp>(initOp))
return varOp.emitOpError("initializer must be the result of a "
"constant or spv.globalVariable op");
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 8df99378868b..b81f7f4c7387 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1176,8 +1176,7 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
if (value.getType() != type)
return false;
// Finally, check that the attribute kind is handled.
- return value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
- value.isa<ElementsAttr>() || value.isa<UnitAttr>();
+ return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
}
void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
@@ -2103,7 +2102,7 @@ static LogicalResult verify(SelectOp op) {
// If the result type is a vector or tensor, the type can be a mask with the
// same elements.
Type resultType = op.getType();
- if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
+ if (!resultType.isa<TensorType, VectorType>())
return op.emitOpError()
<< "expected condition to be a signless i1, but got "
<< conditionType;
@@ -2222,8 +2221,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "splat takes one operand");
auto constOperand = operands.front();
- if (!constOperand ||
- (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+ if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
return {};
auto shapedType = getType().cast<ShapedType>();
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index dc721adc7472..c974e2fc097b 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -107,7 +107,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// Returns the type kind if the given type is a vector or ranked tensor type.
// Returns llvm::None otherwise.
auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
- if (type.isa<VectorType>() || type.isa<RankedTensorType>())
+ if (type.isa<VectorType, RankedTensorType>())
return static_cast<StandardTypes::Kind>(type.getKind());
return llvm::None;
};
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index bc929bcb0c74..a7613fa4ad33 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -337,7 +337,7 @@ uint64_t IntegerAttr::getUInt() const {
}
static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
- if (type.isa<IntegerType>() || type.isa<IndexType>())
+ if (type.isa<IntegerType, IndexType>())
return success();
return emitError(loc, "expected integer or index type");
}
@@ -1090,7 +1090,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data,
bool isSplat) {
- assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+ assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
@@ -1247,7 +1247,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
- assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+ assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 117cd6810968..c76ff30d6c79 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -72,11 +72,11 @@ bool Type::isUnsignedInteger(unsigned width) {
}
bool Type::isSignlessIntOrIndex() {
- return isa<IndexType>() || isSignlessInteger();
+ return isSignlessInteger() || isa<IndexType>();
}
bool Type::isSignlessIntOrIndexOrFloat() {
- return isa<IndexType>() || isSignlessInteger() || isa<FloatType>();
+ return isSignlessInteger() || isa<IndexType, FloatType>();
}
bool Type::isSignlessIntOrFloat() {
@@ -85,7 +85,7 @@ bool Type::isSignlessIntOrFloat() {
bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
-bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
@@ -200,7 +200,7 @@ int64_t ShapedType::getNumElements() const {
int64_t ShapedType::getRank() const { return getShape().size(); }
bool ShapedType::hasRank() const {
- return !isa<UnrankedMemRefType>() && !isa<UnrankedTensorType>();
+ return !isa<UnrankedMemRefType, UnrankedTensorType>();
}
int64_t ShapedType::getDimSize(unsigned idx) const {
@@ -233,7 +233,7 @@ int64_t ShapedType::getSizeInBits() const {
// Tensors can have vectors and other tensors as elements, other shaped types
// cannot.
assert(isa<TensorType>() && "unsupported element type");
- assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
+ assert((elementType.isa<VectorType, TensorType>()) &&
"unsupported tensor element type");
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
@@ -398,8 +398,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
+ if (!elementType.isIntOrFloat() &&
+ !elementType.isa<VectorType, ComplexType>())
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
@@ -476,8 +476,8 @@ LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
+ if (!elementType.isIntOrFloat() &&
+ !elementType.isa<VectorType, ComplexType>())
return emitError(loc, "invalid memref element type");
return success();
}
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 16e548f4430c..b064d83b5faa 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -397,7 +397,7 @@ static WalkResult walkSymbolRefs(
for (Attribute attr : llvm::drop_begin(attrRange, index)) {
/// Check for a nested container attribute, these will also need to be
/// walked.
- if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) {
+ if (attr.isa<ArrayAttr, DictionaryAttr>()) {
attrWorklist.push_back(attr);
curAccessChain.push_back(-1);
return WalkResult::advance();
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index ebbb1293f19d..609d7ad3f8d2 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -345,7 +345,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
}
- if (!type.isa<IntegerType>() && !type.isa<IndexType>())
+ if (!type.isa<IntegerType, IndexType>())
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
@@ -823,7 +823,7 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
return nullptr;
}
- if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
+ if (!type.isa<RankedTensorType, VectorType>()) {
emitError("elements literal must be a ranked tensor or vector type");
return nullptr;
}
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 68d381f968ad..9d8d198aa1c8 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -217,8 +217,8 @@ Type Parser::parseMemRefType() {
return nullptr;
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
+ if (!elementType.isIntOrFloat() &&
+ !elementType.isa<VectorType, ComplexType>())
return emitError(typeLoc, "invalid memref element type"), nullptr;
// Parse semi-affine-map-composition.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 633fe5e19703..075ce9f6089f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -778,8 +778,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
for (Operation &o : getModuleBody(m).getOperations())
- if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
- !o.isKnownTerminator())
+ if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) && !o.isKnownTerminator())
return o.emitOpError("unsupported module-level operation");
return success();
}
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 2a735a58a8b0..18fc872cdf7f 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -294,7 +294,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
unsigned count = 0;
stats->opCountMap[childForOp] = 0;
for (auto &op : *forOp.getBody()) {
- if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
+ if (!isa<AffineForOp, AffineIfOp>(op))
++count;
}
stats->opCountMap[childForOp] = count;
diff --git a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
index 34db53b6ce1e..7a67bef93bc2 100644
--- a/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
+++ b/mlir/test/lib/Transforms/TestMemRefDependenceCheck.cpp
@@ -103,7 +103,7 @@ void TestMemRefDependenceCheck::runOnFunction() {
// Collect the loads and stores within the function.
loadsAndStores.clear();
getFunction().walk([&](Operation *op) {
- if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+ if (isa<AffineLoadOp, AffineStoreOp>(op))
loadsAndStores.push_back(op);
});
More information about the Mlir-commits
mailing list