[Mlir-commits] [mlir] f741b8e - [mlir][spirv] Move type checks from dialect class to type hierarchy
Lei Zhang
llvmlistbot at llvm.org
Wed Mar 18 17:13:26 PDT 2020
Author: Lei Zhang
Date: 2020-03-18T20:11:05-04:00
New Revision: f741b8eabed393a559459c8af024014580170d17
URL: https://github.com/llvm/llvm-project/commit/f741b8eabed393a559459c8af024014580170d17
DIFF: https://github.com/llvm/llvm-project/commit/f741b8eabed393a559459c8af024014580170d17.diff
LOG: [mlir][spirv] Move type checks from dialect class to type hierarchy
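Type validity should be checked through the type hierarchy (e.g. `isa<spirv::ScalarType>()`) rather than through static helpers on the dialect class.
This gives a clearer division of responsibilities and a cleaner API surface.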
Differential Revision: https://reviews.llvm.org/D76243
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 26d8f1401c32..b6715dc9fcd7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -52,16 +52,6 @@ def SPIRV_Dialect : Dialect {
let hasRegionResultAttrVerify = 1;
let extraClassDeclaration = [{
- //===------------------------------------------------------------------===//
- // Type
- //===------------------------------------------------------------------===//
-
- /// Checks if the given `type` is valid in SPIR-V dialect.
- static bool isValidType(Type type);
-
- /// Checks if the given `scalar type` is valid in SPIR-V dialect.
- static bool isValidScalarType(Type type);
-
//===------------------------------------------------------------------===//
// Attribute
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 385e79a0445e..85b35f73f82c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -78,6 +78,8 @@ class SPIRVType : public Type {
static bool classof(Type type);
+ bool isScalarOrVector();
+
/// The extension requirements for each type are following the
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
@@ -109,6 +111,11 @@ class ScalarType : public SPIRVType {
static bool classof(Type type);
+ /// Returns true if the given float type is valid for the SPIR-V dialect.
+ static bool isValid(FloatType);
+ /// Returns true if the given integer type is valid for the SPIR-V dialect.
+ static bool isValid(IntegerType);
+
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
@@ -122,6 +129,9 @@ class CompositeType : public SPIRVType {
static bool classof(Type type);
+ /// Returns true if the given vector type is valid for the SPIR-V dialect.
+ static bool isValid(VectorType);
+
unsigned getNumElements() const;
Type getElementType(unsigned) const;
diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index 44930b91e0ff..d4ce17c93706 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -59,7 +59,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
VulkanLayoutUtils::Size &alignment) {
- if (spirv::SPIRVDialect::isValidScalarType(type)) {
+ if (type.isa<spirv::ScalarType>()) {
alignment = VulkanLayoutUtils::getScalarTypeAlignment(type);
// Vulkan spec does not specify any padding for a scalar type.
size = alignment;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index f378047f36ea..953d95b449d1 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/Functional.h"
@@ -358,15 +359,6 @@ struct ConvertSelectionOpToSelect
rhs.getOperation()->getAttrList().getDictionary();
}
- // Checks that given type is valid for `spv.SelectOp`.
- // According to SPIR-V spec:
- // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
- // Starting with version 1.4, Result Type can additionally be a composite type
- // other than a vector."
- bool isValidType(Type type) const {
- return spirv::SPIRVDialect::isValidScalarType(type) ||
- type.isa<VectorType>();
- }
// Returns a source value for the given block.
Value getSrcValue(Block *block) const {
@@ -401,11 +393,20 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
return failure();
}
+ // Checks that the given type is valid for `spv.SelectOp`.
+ // According to SPIR-V spec:
+ // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
+ // Starting with version 1.4, Result Type can additionally be a composite type
+ // other than a vector."
+ bool isScalarOrVector = trueBrStoreOp.value()
+ .getType()
+ .cast<spirv::SPIRVType>()
+ .isScalarOrVector();
+
// Check that each `spv.Store` uses the same pointer, memory access
// attributes and a valid type of the value.
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
- !isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
- !isValidType(trueBrStoreOp.value().getType())) {
+ !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
return failure();
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index f2868a34f076..8ed417cad58d 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -152,42 +152,6 @@ template <>
Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
DialectAsmParser &parser);
-static bool isValidSPIRVIntType(IntegerType type) {
- return llvm::is_contained(ArrayRef<unsigned>({1, 8, 16, 32, 64}),
- type.getWidth());
-}
-
-bool SPIRVDialect::isValidScalarType(Type type) {
- if (type.isa<FloatType>()) {
- return !type.isBF16();
- }
- if (auto intType = type.dyn_cast<IntegerType>()) {
- return isValidSPIRVIntType(intType);
- }
- return false;
-}
-
-static bool isValidSPIRVVectorType(VectorType type) {
- return type.getRank() == 1 &&
- SPIRVDialect::isValidScalarType(type.getElementType()) &&
- type.getNumElements() >= 2 && type.getNumElements() <= 4;
-}
-
-bool SPIRVDialect::isValidType(Type type) {
- // Allow SPIR-V dialect types
- if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
- type.getKind() <= TypeKind::LAST_SPIRV_TYPE) {
- return true;
- }
- if (SPIRVDialect::isValidScalarType(type)) {
- return true;
- }
- if (auto vectorType = type.dyn_cast<VectorType>()) {
- return isValidSPIRVVectorType(vectorType);
- }
- return false;
-}
-
static Type parseAndVerifyType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
Type type;
@@ -206,7 +170,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
return Type();
}
} else if (auto t = type.dyn_cast<IntegerType>()) {
- if (!isValidSPIRVIntType(t)) {
+ if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
"only 1/8/16/32/64-bit integer type allowed but found ")
<< type;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index e5b630b82fb1..5451048aabe0 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -99,7 +99,7 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
// TODO(ravishankarm): This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t> getTypeNumBytes(Type t) {
- if (spirv::SPIRVDialect::isValidScalarType(t)) {
+ if (t.isa<spirv::ScalarType>()) {
auto bitWidth = t.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
@@ -163,7 +163,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
: targetEnv(targetAttr) {
addConversion([](Type type) -> Optional<Type> {
// If the type is already valid in SPIR-V, directly return.
- return spirv::SPIRVDialect::isValidType(type) ? type : Optional<Type>();
+ return type.isa<spirv::SPIRVType>() ? type : Optional<Type>();
});
addConversion([](IndexType indexType) {
return SPIRVTypeConverter::getIndexType(indexType.getContext());
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 377242482b2a..f6b862156c49 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1373,7 +1373,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
bool spirv::ConstantOp::isBuildableWith(Type type) {
// Must be valid SPIR-V type first.
- if (!SPIRVDialect::isValidType(type))
+ if (!type.isa<spirv::SPIRVType>())
return false;
if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
@@ -2460,7 +2460,7 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
case StandardAttributes::Integer:
case StandardAttributes::Float: {
// Make sure bitwidth is allowed.
- if (!spirv::SPIRVDialect::isValidType(value.getType()))
+ if (!value.getType().isa<spirv::SPIRVType>())
return constOp.emitOpError("default value bitwidth disallowed");
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 92dc5b82bb8a..3f963bd1d8a8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -163,13 +163,19 @@ bool CompositeType::classof(Type type) {
case TypeKind::Array:
case TypeKind::RuntimeArray:
case TypeKind::Struct:
- case StandardTypes::Vector:
return true;
+ case StandardTypes::Vector:
+ return isValid(type.cast<VectorType>());
default:
return false;
}
}
+bool CompositeType::isValid(VectorType type) {
+ return type.getRank() == 1 && type.getElementType().isa<ScalarType>() &&
+ type.getNumElements() >= 2 && type.getNumElements() <= 4;
+}
+
Type CompositeType::getElementType(unsigned index) const {
switch (getKind()) {
case spirv::TypeKind::Array:
@@ -560,7 +566,30 @@ void RuntimeArrayType::getCapabilities(
// ScalarType
//===----------------------------------------------------------------------===//
-bool ScalarType::classof(Type type) { return type.isIntOrFloat(); }
+bool ScalarType::classof(Type type) {
+ if (auto floatType = type.dyn_cast<FloatType>()) {
+ return isValid(floatType);
+ }
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ return isValid(intType);
+ }
+ return false;
+}
+
+bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
+
+bool ScalarType::isValid(IntegerType type) {
+ switch (type.getWidth()) {
+ case 1:
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+}
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
@@ -678,9 +707,19 @@ void ScalarType::getCapabilities(
//===----------------------------------------------------------------------===//
bool SPIRVType::classof(Type type) {
- return type.isa<ScalarType>() || type.isa<VectorType>() ||
- (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
- type.getKind() <= TypeKind::LAST_SPIRV_TYPE);
+ // Allow SPIR-V dialect types
+ if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+ type.getKind() <= TypeKind::LAST_SPIRV_TYPE)
+ return true;
+ if (type.isa<ScalarType>())
+ return true;
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return CompositeType::isValid(vectorType);
+ return false;
+}
+
+bool SPIRVType::isScalarOrVector() {
+ return isIntOrFloat() || isa<VectorType>();
}
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 516b9eca8544..1ca9cad977af 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -21,12 +21,6 @@
using namespace mlir;
-/// Checks if the `type` is a scalar or vector type. It is assumed that they are
-/// valid for SPIR-V dialect already.
-static bool isScalarOrVectorType(Type type) {
- return spirv::SPIRVDialect::isValidScalarType(type) || type.isa<VectorType>();
-}
-
/// Creates a global variable for an argument based on the ABI info.
static spirv::GlobalVariableOp
createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
@@ -45,7 +39,7 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
// info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
// it must already be a !spv.ptr<!spv.struct<...>>.
auto varType = funcOp.getType().getInput(argIndex);
- if (isScalarOrVectorType(varType)) {
+ if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
auto storageClass =
static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
varType =
@@ -198,7 +192,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// at the start of the function. It is probably better to do the load just
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
- if (isScalarOrVectorType(argType.value())) {
+ if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
More information about the Mlir-commits
mailing list