[Mlir-commits] [mlir] f741b8e - [mlir][spirv] Move type checks from dialect class to type hierarchy

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:26 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:05-04:00
New Revision: f741b8eabed393a559459c8af024014580170d17

URL: https://github.com/llvm/llvm-project/commit/f741b8eabed393a559459c8af024014580170d17
DIFF: https://github.com/llvm/llvm-project/commit/f741b8eabed393a559459c8af024014580170d17.diff

LOG: [mlir][spirv] Move type checks from dialect class to type hierarchy

Types should be checked with the type hierarchy. This should result in
better responsibility division and API surface.

Differential Revision: https://reviews.llvm.org/D76243

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
    mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 26d8f1401c32..b6715dc9fcd7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -52,16 +52,6 @@ def SPIRV_Dialect : Dialect {
   let hasRegionResultAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    //===------------------------------------------------------------------===//
-    // Type
-    //===------------------------------------------------------------------===//
-
-    /// Checks if the given `type` is valid in SPIR-V dialect.
-    static bool isValidType(Type type);
-
-    /// Checks if the given `scalar type` is valid in SPIR-V dialect.
-    static bool isValidScalarType(Type type);
-
     //===------------------------------------------------------------------===//
     // Attribute
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 385e79a0445e..85b35f73f82c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -78,6 +78,8 @@ class SPIRVType : public Type {
 
   static bool classof(Type type);
 
+  bool isScalarOrVector();
+
   /// The extension requirements for each type are following the
   /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
   /// convention.
@@ -109,6 +111,11 @@ class ScalarType : public SPIRVType {
 
   static bool classof(Type type);
 
+  /// Returns true if the given float type is valid for the SPIR-V dialect.
+  static bool isValid(FloatType);
+  /// Returns true if the given integer type is valid for the SPIR-V dialect.
+  static bool isValid(IntegerType);
+
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<spirv::StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
@@ -122,6 +129,9 @@ class CompositeType : public SPIRVType {
 
   static bool classof(Type type);
 
+  /// Returns true if the given vector type is valid for the SPIR-V dialect.
+  static bool isValid(VectorType);
+
   unsigned getNumElements() const;
 
   Type getElementType(unsigned) const;

diff  --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index 44930b91e0ff..d4ce17c93706 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -59,7 +59,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
 
 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
                                      VulkanLayoutUtils::Size &alignment) {
-  if (spirv::SPIRVDialect::isValidScalarType(type)) {
+  if (type.isa<spirv::ScalarType>()) {
     alignment = VulkanLayoutUtils::getScalarTypeAlignment(type);
     // Vulkan spec does not specify any padding for a scalar type.
     size = alignment;

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index f378047f36ea..953d95b449d1 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/Functional.h"
@@ -358,15 +359,6 @@ struct ConvertSelectionOpToSelect
            rhs.getOperation()->getAttrList().getDictionary();
   }
 
-  // Checks that given type is valid for `spv.SelectOp`.
-  // According to SPIR-V spec:
-  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
-  // Starting with version 1.4, Result Type can additionally be a composite type
-  // other than a vector."
-  bool isValidType(Type type) const {
-    return spirv::SPIRVDialect::isValidScalarType(type) ||
-           type.isa<VectorType>();
-  }
 
   // Returns a source value for the given block.
   Value getSrcValue(Block *block) const {
@@ -401,11 +393,20 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
     return failure();
   }
 
+  // Checks that given type is valid for `spv.SelectOp`.
+  // According to SPIR-V spec:
+  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
+  // Starting with version 1.4, Result Type can additionally be a composite type
+  // other than a vector."
+  bool isScalarOrVector = trueBrStoreOp.value()
+                              .getType()
+                              .cast<spirv::SPIRVType>()
+                              .isScalarOrVector();
+
   // Check that each `spv.Store` uses the same pointer, memory access
   // attributes and a valid type of the value.
   if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
-      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
-      !isValidType(trueBrStoreOp.value().getType())) {
+      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
     return failure();
   }
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index f2868a34f076..8ed417cad58d 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -152,42 +152,6 @@ template <>
 Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
                                             DialectAsmParser &parser);
 
-static bool isValidSPIRVIntType(IntegerType type) {
-  return llvm::is_contained(ArrayRef<unsigned>({1, 8, 16, 32, 64}),
-                            type.getWidth());
-}
-
-bool SPIRVDialect::isValidScalarType(Type type) {
-  if (type.isa<FloatType>()) {
-    return !type.isBF16();
-  }
-  if (auto intType = type.dyn_cast<IntegerType>()) {
-    return isValidSPIRVIntType(intType);
-  }
-  return false;
-}
-
-static bool isValidSPIRVVectorType(VectorType type) {
-  return type.getRank() == 1 &&
-         SPIRVDialect::isValidScalarType(type.getElementType()) &&
-         type.getNumElements() >= 2 && type.getNumElements() <= 4;
-}
-
-bool SPIRVDialect::isValidType(Type type) {
-  // Allow SPIR-V dialect types
-  if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
-      type.getKind() <= TypeKind::LAST_SPIRV_TYPE) {
-    return true;
-  }
-  if (SPIRVDialect::isValidScalarType(type)) {
-    return true;
-  }
-  if (auto vectorType = type.dyn_cast<VectorType>()) {
-    return isValidSPIRVVectorType(vectorType);
-  }
-  return false;
-}
-
 static Type parseAndVerifyType(SPIRVDialect const &dialect,
                                DialectAsmParser &parser) {
   Type type;
@@ -206,7 +170,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
       return Type();
     }
   } else if (auto t = type.dyn_cast<IntegerType>()) {
-    if (!isValidSPIRVIntType(t)) {
+    if (!ScalarType::isValid(t)) {
       parser.emitError(typeLoc,
                        "only 1/8/16/32/64-bit integer type allowed but found ")
           << type;

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index e5b630b82fb1..5451048aabe0 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -99,7 +99,7 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
 // TODO(ravishankarm): This is a utility function that should probably be
 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t> getTypeNumBytes(Type t) {
-  if (spirv::SPIRVDialect::isValidScalarType(t)) {
+  if (t.isa<spirv::ScalarType>()) {
     auto bitWidth = t.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
     // "There is no physical size or bit pattern defined for values with boolean
@@ -163,7 +163,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
     : targetEnv(targetAttr) {
   addConversion([](Type type) -> Optional<Type> {
     // If the type is already valid in SPIR-V, directly return.
-    return spirv::SPIRVDialect::isValidType(type) ? type : Optional<Type>();
+    return type.isa<spirv::SPIRVType>() ? type : Optional<Type>();
   });
   addConversion([](IndexType indexType) {
     return SPIRVTypeConverter::getIndexType(indexType.getContext());

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 377242482b2a..f6b862156c49 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1373,7 +1373,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
 
 bool spirv::ConstantOp::isBuildableWith(Type type) {
   // Must be valid SPIR-V type first.
-  if (!SPIRVDialect::isValidType(type))
+  if (!type.isa<spirv::SPIRVType>())
     return false;
 
   if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
@@ -2460,7 +2460,7 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
   case StandardAttributes::Integer:
   case StandardAttributes::Float: {
     // Make sure bitwidth is allowed.
-    if (!spirv::SPIRVDialect::isValidType(value.getType()))
+    if (!value.getType().isa<spirv::SPIRVType>())
       return constOp.emitOpError("default value bitwidth disallowed");
     return success();
   }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 92dc5b82bb8a..3f963bd1d8a8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -163,13 +163,19 @@ bool CompositeType::classof(Type type) {
   case TypeKind::Array:
   case TypeKind::RuntimeArray:
   case TypeKind::Struct:
-  case StandardTypes::Vector:
     return true;
+  case StandardTypes::Vector:
+    return isValid(type.cast<VectorType>());
   default:
     return false;
   }
 }
 
+bool CompositeType::isValid(VectorType type) {
+  return type.getRank() == 1 && type.getElementType().isa<ScalarType>() &&
+         type.getNumElements() >= 2 && type.getNumElements() <= 4;
+}
+
 Type CompositeType::getElementType(unsigned index) const {
   switch (getKind()) {
   case spirv::TypeKind::Array:
@@ -560,7 +566,30 @@ void RuntimeArrayType::getCapabilities(
 // ScalarType
 //===----------------------------------------------------------------------===//
 
-bool ScalarType::classof(Type type) { return type.isIntOrFloat(); }
+bool ScalarType::classof(Type type) {
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    return isValid(floatType);
+  }
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    return isValid(intType);
+  }
+  return false;
+}
+
+bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
+
+bool ScalarType::isValid(IntegerType type) {
+  switch (type.getWidth()) {
+  case 1:
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  default:
+    return false;
+  }
+}
 
 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                Optional<StorageClass> storage) {
@@ -678,9 +707,19 @@ void ScalarType::getCapabilities(
 //===----------------------------------------------------------------------===//
 
 bool SPIRVType::classof(Type type) {
-  return type.isa<ScalarType>() || type.isa<VectorType>() ||
-         (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
-          type.getKind() <= TypeKind::LAST_SPIRV_TYPE);
+  // Allow SPIR-V dialect types
+  if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+      type.getKind() <= TypeKind::LAST_SPIRV_TYPE)
+    return true;
+  if (type.isa<ScalarType>())
+    return true;
+  if (auto vectorType = type.dyn_cast<VectorType>())
+    return CompositeType::isValid(vectorType);
+  return false;
+}
+
+bool SPIRVType::isScalarOrVector() {
+  return isIntOrFloat() || isa<VectorType>();
 }
 
 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 516b9eca8544..1ca9cad977af 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -21,12 +21,6 @@
 
 using namespace mlir;
 
-/// Checks if the `type` is a scalar or vector type. It is assumed that they are
-/// valid for SPIR-V dialect already.
-static bool isScalarOrVectorType(Type type) {
-  return spirv::SPIRVDialect::isValidScalarType(type) || type.isa<VectorType>();
-}
-
 /// Creates a global variable for an argument based on the ABI info.
 static spirv::GlobalVariableOp
 createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
@@ -45,7 +39,7 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
   // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
   // it must already be a !spv.ptr<!spv.struct<...>>.
   auto varType = funcOp.getType().getInput(argIndex);
-  if (isScalarOrVectorType(varType)) {
+  if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
     auto storageClass =
         static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
     varType =
@@ -198,7 +192,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     // at the start of the function. It is probably better to do the load just
     // before the use. There might be multiple loads and currently there is no
     // easy way to replace all uses with a sequence of operations.
-    if (isScalarOrVectorType(argType.value())) {
+    if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
       auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
       auto zero =
           spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);


        


More information about the Mlir-commits mailing list