[Mlir-commits] [mlir] 9efb4b4 - [mlir][spirv] Make SPIRVTypeConverter target environment aware

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:28 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:05-04:00
New Revision: 9efb4b4023272c46e10738c96936710d3ed70d5a

URL: https://github.com/llvm/llvm-project/commit/9efb4b4023272c46e10738c96936710d3ed70d5a
DIFF: https://github.com/llvm/llvm-project/commit/9efb4b4023272c46e10738c96936710d3ed70d5a.diff

LOG: [mlir][spirv] Make SPIRVTypeConverter target environment aware

Non-32-bit scalar types require special hardware support that may
not exist on all Vulkan-capable GPUs. This is reflected in SPIR-V by
non-32-bit scalar types requiring special capabilities or extensions.
This commit makes SPIRVTypeConverter target environment aware so
that it can properly convert standard types to what is accepted on
the target environment.

Right now if a scalar type's bitwidth is not supported in the target
environment, we use 32-bit unconditionally. This requires the Vulkan
runtime to also feed in data with a matched bitwidth and layout,
especially for interface types. The Vulkan runtime can do that by
inspecting the SPIR-V module. Longer term, we might want to introduce
a way to control how such cases are handled and explicitly fail
if wanted.

Differential Revision: https://reviews.llvm.org/D76244

Added: 
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/Dialect/SPIRV/TargetAndABI.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index a97c83b7a553..ba0b7ea0714c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -23,9 +23,20 @@ namespace mlir {
 
 /// Type conversion from standard types to SPIR-V types for shader interface.
 ///
-/// For composite types, this converter additionally performs type wrapping to
+/// Non-32-bit scalar types require special hardware support that may not exist
+/// on all GPUs. This is reflected in SPIR-V in that non-32-bit scalar types
+/// require special capabilities or extensions. Right now if a scalar type of a
+/// certain bitwidth is not supported in the target environment, we use 32-bit
+/// ones unconditionally. This requires the runtime to also feed in data with
+/// a matched bitwidth and layout for interface types. The runtime can do that
+/// by inspecting the SPIR-V module.
+///
+/// For memref types, this converter additionally performs type wrapping to
 /// satisfy shader interface requirements: shader interface types must be
 /// pointers to structs.
+///
+/// TODO(antiagainst): We might want to introduce a way to control how
+/// unsupported bitwidths are handled and explicitly fail if wanted.
 class SPIRVTypeConverter : public TypeConverter {
 public:
   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
index 27278b6d3f23..3f14addd9b6b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -44,7 +44,7 @@ class TargetEnv {
   Optional<Extension> allows(ArrayRef<Extension>) const;
 
   /// Returns the MLIRContext.
-  MLIRContext *getContext();
+  MLIRContext *getContext() const;
 
   /// Allows implicity converting to the underlying spirv::TargetEnvAttr.
   operator TargetEnvAttr() const { return targetAttr; }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 5451048aabe0..3fd987b0e565 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -24,6 +24,64 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Checks that `candidates` extension requirements are possible to be satisfied
+/// with the given `targetEnv`.
+///
+///  `candidates` is a vector of vector for extension requirements following
+/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+/// convention.
+template <typename LabelT>
+static LogicalResult checkExtensionRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    if (targetEnv.allows(ors))
+      continue;
+
+    SmallVector<StringRef, 4> extStrings;
+    for (spirv::Extension ext : ors)
+      extStrings.push_back(spirv::stringifyExtension(ext));
+
+    LLVM_DEBUG(llvm::dbgs()
+               << label << " illegal: requires at least one extension in ["
+               << llvm::join(extStrings, ", ")
+               << "] but none allowed in target environment\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Checks that the `candidates` capability requirements can be satisfied
+/// with the given `targetEnv`.
+///
+///  `candidates` is a vector of vector for capability requirements following
+/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+/// convention.
+template <typename LabelT>
+static LogicalResult checkCapabilityRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+  for (const auto &ors : candidates) {
+    if (targetEnv.allows(ors))
+      continue;
+
+    SmallVector<StringRef, 4> capStrings;
+    for (spirv::Capability cap : ors)
+      capStrings.push_back(spirv::stringifyCapability(cap));
+
+    LLVM_DEBUG(llvm::dbgs()
+               << label << " illegal: requires at least one capability in ["
+               << llvm::join(capStrings, ", ")
+               << "] but none allowed in target environment\n");
+    return failure();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//
@@ -159,62 +217,212 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
   return llvm::None;
 }
 
+/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
+static Optional<Type>
+convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
+                  Optional<spirv::StorageClass> storageClass = {}) {
+  // Get extension and capability requirements for the given type.
+  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
+  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
+  type.getExtensions(extensions, storageClass);
+  type.getCapabilities(capabilities, storageClass);
+
+  // If all requirements are met, then we can accept this type as-is.
+  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
+      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+    return type;
+
+  // Otherwise we need to adjust the type, which really means adjusting the
+  // bitwidth given this is a scalar type.
+  // TODO(antiagainst): We are unconditionally converting the bitwidth here,
+  // this might be okay for non-interface types (i.e., types used in
+  // Private/Function storage classes), but not for interface types (i.e.,
+  // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
+  // This is because changing the latter affects the ABI contract with the
+  // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
+  // conversion if we cannot change the bitwidth.
+
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
+    return Builder(targetEnv.getContext()).getF32Type();
+  }
+
+  auto intType = type.cast<IntegerType>();
+  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
+  return IntegerType::get(/*width=*/32, intType.getSignedness(),
+                          targetEnv.getContext());
+}
+
+/// Converts a vector `type` to a suitable type under the given `targetEnv`.
+static Optional<Type>
+convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
+                  Optional<spirv::StorageClass> storageClass = {}) {
+  if (!spirv::CompositeType::isValid(type)) {
+    // TODO(antiagainst): One-element vector types can be translated into scalar
+    // types. Vector types with more than four elements can be translated into
+    // array types.
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: 1- and > 4-element unimplemented\n");
+    return llvm::None;
+  }
+
+  // Get extension and capability requirements for the given type.
+  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
+  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
+  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
+  type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
+
+  // If all requirements are met, then we can accept this type as-is.
+  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
+      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+    return type;
+
+  auto elementType = convertScalarType(
+      targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
+  if (elementType)
+    return VectorType::get(type.getShape(), *elementType);
+  return llvm::None;
+}
+
+/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
+///
+/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
+/// create composite constants with OpConstantComposite to embed relatively large
+/// constant values and use OpCompositeExtract and OpCompositeInsert to
+/// manipulate, like what we do for vectors.
+static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
+                                        TensorType type) {
+  // TODO(ravishankarm) : Handle dynamic shapes.
+  if (!type.hasStaticShape()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: dynamic shape unimplemented\n");
+    return llvm::None;
+  }
+
+  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
+  if (!scalarType) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert non-scalar element type\n");
+    return llvm::None;
+  }
+
+  Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
+  Optional<int64_t> tensorSize = getTypeNumBytes(type);
+  if (!scalarSize || !tensorSize) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot deduce element count\n");
+    return llvm::None;
+  }
+
+  auto arrayElemCount = *tensorSize / *scalarSize;
+  auto arrayElemType = convertScalarType(targetEnv, scalarType);
+  if (!arrayElemType)
+    return llvm::None;
+  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
+  if (!arrayElemSize) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot deduce converted element size\n");
+    return llvm::None;
+  }
+
+  return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
+}
+
+static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
+                                        MemRefType type) {
+  // TODO(ravishankarm) : Handle dynamic shapes.
+  if (!type.hasStaticShape()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: dynamic shape unimplemented\n");
+    return llvm::None;
+  }
+
+  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
+  if (!scalarType) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert non-scalar element type\n");
+    return llvm::None;
+  }
+
+  Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
+  Optional<int64_t> memrefSize = getTypeNumBytes(type);
+  if (!scalarSize || !memrefSize) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot deduce element count\n");
+    return llvm::None;
+  }
+
+  auto arrayElemCount = *memrefSize / *scalarSize;
+
+  auto storageClass =
+      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
+  if (!storageClass) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert memory space\n");
+    return llvm::None;
+  }
+
+  auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
+  if (!arrayElemType)
+    return llvm::None;
+  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
+  if (!arrayElemSize) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot deduce converted element size\n");
+    return llvm::None;
+  }
+
+  auto arrayType =
+      spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
+
+  // Wrap in a struct to satisfy Vulkan interface requirements.
+  auto structType = spirv::StructType::get(arrayType, 0);
+  return spirv::PointerType::get(structType, *storageClass);
+}
+
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
     : targetEnv(targetAttr) {
-  addConversion([](Type type) -> Optional<Type> {
-    // If the type is already valid in SPIR-V, directly return.
-    return type.isa<spirv::SPIRVType>() ? type : Optional<Type>();
-  });
+  // Add conversions. The order matters here: later ones will be tried earlier.
+
+  // All other cases failed. Then we cannot convert this type.
+  addConversion([](Type type) { return llvm::None; });
+
+  // Allow all SPIR-V dialect specific types. This assumes all standard types
+  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
+  // were tried before.
+  //
+  // TODO(antiagainst): this assumes that the SPIR-V types are valid to use in
+  // the given target environment, which should be the case if the whole
+  // pipeline is driven by the same target environment. Still, we probably
+  // want to validate and convert to be safe.
+  addConversion([](spirv::SPIRVType type) { return type; });
+
   addConversion([](IndexType indexType) {
     return SPIRVTypeConverter::getIndexType(indexType.getContext());
   });
-  addConversion([this](MemRefType memRefType) -> Type {
-    auto elementType = convertType(memRefType.getElementType());
-    if (!elementType)
-      return Type();
-
-    auto elementSize = getTypeNumBytes(elementType);
-    if (!elementSize)
-      return Type();
-
-    // TODO(ravishankarm) : Handle dynamic shapes.
-    if (memRefType.hasStaticShape()) {
-      auto arraySize = getTypeNumBytes(memRefType);
-      if (!arraySize)
-        return Type();
-
-      auto arrayType = spirv::ArrayType::get(
-          elementType, arraySize.getValue() / elementSize.getValue(),
-          elementSize.getValue());
-
-      // Wrap in a struct to satisfy Vulkan interface requirements.
-      auto structType = spirv::StructType::get(arrayType, 0);
-      if (auto sc = getStorageClassForMemorySpace(memRefType.getMemorySpace()))
-        return spirv::PointerType::get(structType, *sc);
-      return Type();
-    }
-    return Type();
+
+  addConversion([this](IntegerType intType) -> Optional<Type> {
+    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
+      return convertScalarType(targetEnv, scalarType);
+    return llvm::None;
+  });
+
+  addConversion([this](FloatType floatType) -> Optional<Type> {
+    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
+      return convertScalarType(targetEnv, scalarType);
+    return llvm::None;
+  });
+
+  addConversion([this](VectorType vectorType) {
+    return convertVectorType(targetEnv, vectorType);
   });
-  addConversion([this](TensorType tensorType) -> Type {
-    // TODO(ravishankarm) : Handle dynamic shapes.
-    if (!tensorType.hasStaticShape())
-      return Type();
-
-    auto elementType = convertType(tensorType.getElementType());
-    if (!elementType)
-      return Type();
-
-    auto elementSize = getTypeNumBytes(elementType);
-    if (!elementSize)
-      return Type();
-
-    auto tensorSize = getTypeNumBytes(tensorType);
-    if (!tensorSize)
-      return Type();
-
-    return spirv::ArrayType::get(elementType,
-                                 tensorSize.getValue() / elementSize.getValue(),
-                                 elementSize.getValue());
+
+  addConversion([this](TensorType tensorType) {
+    return convertTensorType(targetEnv, tensorType);
+  });
+
+  addConversion([this](MemRefType memRefType) {
+    return convertMemrefType(targetEnv, memRefType);
   });
 }
 
@@ -429,58 +637,6 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget(
     spirv::TargetEnvAttr targetAttr)
     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
 
-/// Checks that `candidates` extension requirements are possible to be satisfied
-/// with the given `targetEnv`.
-///
-///  `candidates` is a vector of vector for extension requirements following
-/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
-/// convention.
-static LogicalResult checkExtensionRequirements(
-    Operation *op, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
-  for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
-      continue;
-
-    SmallVector<StringRef, 4> extStrings;
-    for (spirv::Extension ext : ors)
-      extStrings.push_back(spirv::stringifyExtension(ext));
-
-    LLVM_DEBUG(llvm::dbgs() << op->getName()
-                            << " illegal: requires at least one extension in ["
-                            << llvm::join(extStrings, ", ")
-                            << "] but none allowed in target environment\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Checks that `candidates`capability requirements are possible to be satisfied
-/// with the given `isAllowedFn`.
-///
-///  `candidates` is a vector of vector for capability requirements following
-/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
-/// convention.
-static LogicalResult checkCapabilityRequirements(
-    Operation *op, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
-  for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
-      continue;
-
-    SmallVector<StringRef, 4> capStrings;
-    for (spirv::Capability cap : ors)
-      capStrings.push_back(spirv::stringifyCapability(cap));
-
-    LLVM_DEBUG(llvm::dbgs() << op->getName()
-                            << " illegal: requires at least one capability in ["
-                            << llvm::join(capStrings, ", ")
-                            << "] but none allowed in target environment\n");
-    return failure();
-  }
-  return success();
-}
-
 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // Make sure this op is available at the given version. Ops not implementing
   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
@@ -506,7 +662,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // implementing QueryExtensionInterface do not require extensions to be
   // available.
   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
-    if (failed(checkExtensionRequirements(op, this->targetEnv,
+    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                           extensions.getExtensions())))
       return false;
 
@@ -514,7 +670,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // implementing QueryCapabilityInterface do not require capabilities to be
   // available.
   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
-    if (failed(checkCapabilityRequirements(op, this->targetEnv,
+    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                            capabilities.getCapabilities())))
       return false;
 
@@ -534,13 +690,14 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   for (Type valueType : valueTypes) {
     typeExtensions.clear();
     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
-    if (failed(checkExtensionRequirements(op, this->targetEnv, typeExtensions)))
+    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
+                                          typeExtensions)))
       return false;
 
     typeCapabilities.clear();
     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
-    if (failed(
-            checkCapabilityRequirements(op, this->targetEnv, typeCapabilities)))
+    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
+                                           typeCapabilities)))
       return false;
   }
 

diff  --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 779a752e2e6c..491fcf9a6f21 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -70,7 +70,9 @@ spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
   return llvm::None;
 }
 
-MLIRContext *spirv::TargetEnv::getContext() { return targetAttr.getContext(); }
+MLIRContext *spirv::TargetEnv::getContext() const {
+  return targetAttr.getContext();
+}
 
 //===----------------------------------------------------------------------===//
 // Utility functions

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
new file mode 100644
index 000000000000..a88678fd34ac
--- /dev/null
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -0,0 +1,552 @@
+// RUN: mlir-opt -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Integer types
+//===----------------------------------------------------------------------===//
+
+// Check that non-32-bit integer types are converted to 32-bit types if the
+// corresponding capabilities are not available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @integer8
+// CHECK-SAME: i32
+// CHECK-SAME: si32
+// CHECK-SAME: ui32
+func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return }
+
+// CHECK-LABEL: spv.func @integer16
+// CHECK-SAME: i32
+// CHECK-SAME: si32
+// CHECK-SAME: ui32
+func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }
+
+// CHECK-LABEL: spv.func @integer64
+// CHECK-SAME: i32
+// CHECK-SAME: si32
+// CHECK-SAME: ui32
+func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }
+
+} // end module
+
+// -----
+
+// Check that non-32-bit integer types are kept untouched if the corresponding
+// capabilities are available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int8, Int16, Int64], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @integer8
+// CHECK-SAME: i8
+// CHECK-SAME: si8
+// CHECK-SAME: ui8
+func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return }
+
+// CHECK-LABEL: spv.func @integer16
+// CHECK-SAME: i16
+// CHECK-SAME: si16
+// CHECK-SAME: ui16
+func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }
+
+// CHECK-LABEL: spv.func @integer64
+// CHECK-SAME: i64
+// CHECK-SAME: si64
+// CHECK-SAME: ui64
+func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }
+
+} // end module
+
+// -----
+
+// Check that weird bitwidths are not supported.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-NOT: spv.func @integer4
+func @integer4(%arg0: i4) { return }
+
+// CHECK-NOT: spv.func @integer128
+func @integer128(%arg0: i128) { return }
+
+// CHECK-NOT: spv.func @integer42
+func @integer42(%arg0: i42) { return }
+
+} // end module
+// -----
+
+//===----------------------------------------------------------------------===//
+// Index type
+//===----------------------------------------------------------------------===//
+
+// The index type is always converted into i32.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @index_type
+// CHECK-SAME: %{{.*}}: i32
+func @index_type(%arg0: index) { return }
+
+} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Float types
+//===----------------------------------------------------------------------===//
+
+// Check that non-32-bit float types are converted to 32-bit types if the
+// corresponding capabilities are not available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @float16
+// CHECK-SAME: f32
+func @float16(%arg0: f16) { return }
+
+// CHECK-LABEL: spv.func @float64
+// CHECK-SAME: f32
+func @float64(%arg0: f64) { return }
+
+} // end module
+
+// -----
+
+// Check that non-32-bit float types are kept untouched if the corresponding
+// capabilities are available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Float16, Float64], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @float16
+// CHECK-SAME: f16
+func @float16(%arg0: f16) { return }
+
+// CHECK-LABEL: spv.func @float64
+// CHECK-SAME: f64
+func @float64(%arg0: f64) { return }
+
+} // end module
+
+// -----
+
+// Check that bf16 is not supported.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-NOT: spv.func @bf16_type
+func @bf16_type(%arg0: bf16) { return }
+
+} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Vector types
+//===----------------------------------------------------------------------===//
+
+// Check that capabilities for scalar types affect vector types too: no special
+// capabilities available means turning element types into 32-bit ones.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @int_vector
+// CHECK-SAME: vector<2xi32>
+// CHECK-SAME: vector<3xsi32>
+// CHECK-SAME: vector<4xui32>
+func @int_vector(
+  %arg0: vector<2xi8>,
+  %arg1: vector<3xsi16>,
+  %arg2: vector<4xui64>
+) { return }
+
+// CHECK-LABEL: spv.func @float_vector
+// CHECK-SAME: vector<2xf32>
+// CHECK-SAME: vector<3xf32>
+func @float_vector(
+  %arg0: vector<2xf16>,
+  %arg1: vector<3xf64>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that capabilities for scalar types affect vector types too: having
+// special capabilities means keeping vector types untouched.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @int_vector
+// CHECK-SAME: vector<2xi8>
+// CHECK-SAME: vector<3xsi16>
+// CHECK-SAME: vector<4xui64>
+func @int_vector(
+  %arg0: vector<2xi8>,
+  %arg1: vector<3xsi16>,
+  %arg2: vector<4xui64>
+) { return }
+
+// CHECK-LABEL: spv.func @float_vector
+// CHECK-SAME: vector<2xf16>
+// CHECK-SAME: vector<3xf64>
+func @float_vector(
+  %arg0: vector<2xf16>,
+  %arg1: vector<3xf64>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that 1- or > 4-element vectors are not supported.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-NOT: spv.func @one_element_vector
+func @one_element_vector(%arg0: vector<1xi32>) { return }
+
+// CHECK-NOT: spv.func @large_vector
+func @large_vector(%arg0: vector<1024xi32>) { return }
+
+} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// MemRef types
+//===----------------------------------------------------------------------===//
+
+// Check that using non-32-bit scalar types in interface storage classes
+// requires special capability and extension: convert them to 32-bit if not
+// satisfied.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i32 [4]> [0]>, StorageBuffer>
+func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_8bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x si32 [4]> [0]>, Uniform>
+func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_8bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x ui32 [4]> [0]>, PushConstant>
+func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i32 [4]> [0]>, StorageBuffer>
+func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x si32 [4]> [0]>, Uniform>
+func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x ui32 [4]> [0]>, PushConstant>
+func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f32 [4]> [0]>, Input>
+func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f32 [4]> [0]>, Output>
+func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
+
+} // end module
+
+// -----
+
+// Check that using non-32-bit scalar types in interface storage classes
+// requires special capabilities and extensions: such types are kept as-is
+// when the required capabilities and extensions are available.
+// Here the target provides the 8-/16-bit PushConstant storage capabilities,
+// so memrefs in memory space 7 (PushConstant) keep their i8/i16/f16 element
+// types instead of being widened to 32-bit.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [StoragePushConstant8, StoragePushConstant16],
+             [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @memref_8bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8 [1]> [0]>, PushConstant>
+func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, PushConstant>
+func @memref_16bit_PushConstant(
+  %arg0: memref<16xi16, 7>,
+  %arg1: memref<16xf16, 7>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that using non-32-bit scalar types in interface storage classes
+// requires special capabilities and extensions: such types are kept as-is
+// when the required capabilities and extensions are available.
+// Here the target provides the 8-/16-bit StorageBuffer access capabilities,
+// so memrefs in memory space 0 (StorageBuffer) keep their i8/i16/f16 element
+// types instead of being widened to 32-bit.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [StorageBuffer8BitAccess, StorageBuffer16BitAccess],
+             [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8 [1]> [0]>, StorageBuffer>
+func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, StorageBuffer>
+func @memref_16bit_StorageBuffer(
+  %arg0: memref<16xi16, 0>,
+  %arg1: memref<16xf16, 0>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that using non-32-bit scalar types in interface storage classes
+// requires special capabilities and extensions: such types are kept as-is
+// when the required capabilities and extensions are available.
+// Here the target provides UniformAndStorageBuffer8BitAccess and
+// StorageUniform16, so memrefs in memory space 4 (Uniform) keep their
+// i8/i16/f16 element types instead of being widened to 32-bit.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16],
+             [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @memref_8bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8 [1]> [0]>, Uniform>
+func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, Uniform>
+func @memref_16bit_Uniform(
+  %arg0: memref<16xi16, 4>,
+  %arg1: memref<16xf16, 4>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that using non-32-bit scalar types in interface storage classes
+// requires special capabilities and extensions: such types are kept as-is
+// when the required capabilities and extensions are available.
+// Here the target provides StorageInputOutput16 (16-bit only), so 16-bit
+// memrefs in memory spaces 9 (Input) and 10 (Output) keep their f16/i16
+// element types.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [StorageInputOutput16], [SPV_KHR_16bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @memref_16bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, Input>
+func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, Output>
+func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
+
+} // end module
+
+// -----
+
+// Check that memref offset and strides affect the array size.
+// The expected element counts below equal offset + max_i(shape[i] * stride[i]),
+// e.g. 16x4 with strides [1, 22] gives max(16 * 1, 4 * 22) = 88 elements.
+// The target enables StorageBuffer16BitAccess, so f16 elements are kept.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [StorageBuffer16BitAccess], [SPV_KHR_16bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @memref_offset_strides
+func @memref_offset_strides(
+// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<72 x f32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<256 x f32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<88 x f32 [4]> [0]>, StorageBuffer>
+  %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>,  // tightly packed; row major
+  %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>,  // offset 8
+  %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row
+  %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major
+  %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 6 after each col
+
+// Same five layouts as above, with f16 elements (2-byte array stride).
+// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<72 x f16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<256 x f16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<88 x f16 [2]> [0]>, StorageBuffer>
+  %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>,
+  %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>,
+  %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>,
+  %arg8: memref<16x4xf16, offset: 0, strides: [1, 16]>,
+  %arg9: memref<16x4xf16, offset: 0, strides: [1, 22]>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that dynamic shapes are not supported: functions whose signatures use
+// unranked or dynamically-shaped memrefs keep their original memref argument
+// types, i.e. the signature is not converted.
+// NOTE(review): the `func @...` label patterns below would also match an
+// `spv.func @...` as a substring; it is the memref argument type match that
+// shows the signature was left unconverted.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @unranked_memref
+// CHECK-SAME: memref<*xi32>
+func @unranked_memref(%arg0: memref<*xi32>) { return }
+
+// CHECK-LABEL: func @dynamic_dim_memref
+// CHECK-SAME: memref<8x?xi32>
+func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return }
+
+} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Tensor types
+//===----------------------------------------------------------------------===//
+
+// Check that tensor element types are kept untouched with proper capabilities.
+// Each 8x4 tensor becomes a 32-element spv.array whose stride is the element
+// size in bytes (8 for 64-bit, 4 for 32-bit, 2 for 16-bit, 1 for 8-bit).
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @int_tensor_types
+// CHECK-SAME: !spv.array<32 x i64 [8]>
+// CHECK-SAME: !spv.array<32 x i32 [4]>
+// CHECK-SAME: !spv.array<32 x i16 [2]>
+// CHECK-SAME: !spv.array<32 x i8 [1]>
+func @int_tensor_types(
+  %arg0: tensor<8x4xi64>,
+  %arg1: tensor<8x4xi32>,
+  %arg2: tensor<8x4xi16>,
+  %arg3: tensor<8x4xi8>
+) { return }
+
+// CHECK-LABEL: spv.func @float_tensor_types
+// CHECK-SAME: !spv.array<32 x f64 [8]>
+// CHECK-SAME: !spv.array<32 x f32 [4]>
+// CHECK-SAME: !spv.array<32 x f16 [2]>
+func @float_tensor_types(
+  %arg0: tensor<8x4xf64>,
+  %arg1: tensor<8x4xf32>,
+  %arg2: tensor<8x4xf16>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that tensor element types are changed to 32-bit without capabilities.
+// With an empty capability list, both narrower (8/16-bit) and wider (64-bit)
+// element types become i32/f32 with a 4-byte array stride.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: spv.func @int_tensor_types
+// CHECK-SAME: !spv.array<32 x i32 [4]>
+// CHECK-SAME: !spv.array<32 x i32 [4]>
+// CHECK-SAME: !spv.array<32 x i32 [4]>
+// CHECK-SAME: !spv.array<32 x i32 [4]>
+func @int_tensor_types(
+  %arg0: tensor<8x4xi64>,
+  %arg1: tensor<8x4xi32>,
+  %arg2: tensor<8x4xi16>,
+  %arg3: tensor<8x4xi8>
+) { return }
+
+// CHECK-LABEL: spv.func @float_tensor_types
+// CHECK-SAME: !spv.array<32 x f32 [4]>
+// CHECK-SAME: !spv.array<32 x f32 [4]>
+// CHECK-SAME: !spv.array<32 x f32 [4]>
+func @float_tensor_types(
+  %arg0: tensor<8x4xf64>,
+  %arg1: tensor<8x4xf32>,
+  %arg2: tensor<8x4xf16>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that dynamic shapes are not supported: functions whose signatures use
+// unranked or dynamically-shaped tensors keep their original tensor argument
+// types, i.e. the signature is not converted.
+// NOTE(review): the `func @...` label patterns below would also match an
+// `spv.func @...` as a substring; it is the tensor argument type match that
+// shows the signature was left unconverted.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @unranked_tensor
+// CHECK-SAME: tensor<*xi32>
+func @unranked_tensor(%arg0: tensor<*xi32>) { return }
+
+// CHECK-LABEL: func @dynamic_dim_tensor
+// CHECK-SAME: tensor<8x?xi32>
+func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
+
+} // end module


        


More information about the Mlir-commits mailing list