[Mlir-commits] [mlir] [mlir][spirv] Add support for VectorAnyINTEL capability (PR #68034)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 2 13:41:37 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

<details>
<summary>Changes</summary>

Allow vectors of any length in the range [2, 2^32-1].
The VectorAnyINTEL capability (part of the "SPV_INTEL_vector_compute" extension) relaxes the SPIR-V constraint that limits vector sizes to 2, 3, and 4.

---

Patch is 44.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68034.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+8-3) 
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+70) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+5-2) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+16-6) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+102-24) 
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+2-2) 
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+40) 
- (modified) mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir (+15-2) 
- (modified) mlir/test/Conversion/GPUToSPIRV/reductions.mlir (+32-32) 
- (modified) mlir/test/Dialect/SPIRV/IR/bit-ops.mlir (+3-3) 
- (modified) mlir/test/Dialect/SPIRV/IR/gl-ops.mlir (+1-1) 
- (modified) mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir (+2-2) 
- (modified) mlir/test/Dialect/SPIRV/IR/logical-ops.mlir (+1-1) 
- (modified) mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir (+22-20) 
- (modified) mlir/test/Target/SPIRV/arithmetic-ops.mlir (+3-3) 
- (modified) mlir/test/Target/SPIRV/ocl-ops.mlir (+6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..c458a500eb367f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+// Remove the vector size restriction.
+// Although the vector size can be up to (2^64-1), uint64,
+// 2^32-1 (UINT32_MAX) is a more realistic number; it should serve the purpose
+// for all practical cases.
+// Also, unsigned is used for the number of elements for composite types.
+def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
                                        [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
     "Joint Matrix">;
 
 class SPIRV_ScalarOrVectorOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+    AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+    AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
                SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..703122547df7493 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Whether the number of elements of a vector is from the given
+// `allowedRanges` list, the list has two values, start and end
+// of the range (inclusive).
+class IsVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsFixedVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsScalableVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsScalableVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list.
+class VectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list.
+class FixedVectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list.
+class ScalableVectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsScalableVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
+        VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
+        "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<
+      And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
+      FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
+      "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class ScalableVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<
+      And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfLengthRange<allowedRanges>.predicate]>,
+      ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfLengthRange<allowedRanges>.summary,
+      "::mlir::VectorType">;
+
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..be85d3c330a887a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
       return Type();
     }
-    if (t.getNumElements() > 4) {
+    // Number of elements should be between [2 - 2^32 -1],
+    // since getNumElements() returns an unsigned, the upper limit check is
+    // unnecessary.
+    if (t.getNumElements() < 2) {
       parser.emitError(
-          typeLoc, "vector length has to be less than or equal to 4 but found ")
+          typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
           << t.getNumElements();
       return Type();
     }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f965d..9d39d99b4148253 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
 }
 
 bool CompositeType::isValid(VectorType type) {
-  return type.getRank() == 1 &&
-         llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
-         llvm::isa<ScalarType>(type.getElementType());
+  // Number of elements should be between [2 - 2^32 -1],
+  // since getNumElements() returns an unsigned, the upper limit check is
+  // unnecessary.
+  return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
+         type.getNumElements() >= 2;
 }
 
 Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
         if (vecSize == 8 || vecSize == 16) {
-          static const Capability caps[] = {Capability::Vector16};
-          ArrayRef<Capability> ref(caps, std::size(caps));
-          capabilities.push_back(ref);
+          static constexpr Capability caps[] = {Capability::Vector16,
+                                                Capability::VectorAnyINTEL};
+          capabilities.push_back(caps);
+        }
+        // VectorAnyINTEL capability removes the vector size restriction and
+        // allows the vector size to be up to (2^32-1).
+        // Vector16 capability allows the vector size to be 8 and 16
+        SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
+        if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
+          static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
+          capabilities.push_back(caps);
         }
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
 template <typename LabelT>
 static LogicalResult checkExtensionRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+    const ArrayRef<spirv::Extension> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+          return llvm::is_contained(ors, elidedExt);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
 template <typename LabelT>
 static LogicalResult checkCapabilityRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+    const ArrayRef<spirv::Capability> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+          return llvm::is_contained(ors, elidedCap);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
   return success();
 }
 
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Check capability and extension requirements.
+/// Checks that `capCandidates`, `extCandidates`, and capability
+/// (`capCandidates`) inferred extension requirements are possible to be
+/// satisfied with the given `targetEnv`.
+/// It also provides a way to relax requirements for certain capabilities and
+/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`); this is to
+/// allow passes to relax certain requirements based on an option (e.g.,
+/// relaxing bitwidth requirement, see `convertScalarType()`,
+/// `convertVectorType()`).
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+    const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+    const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+    const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+  SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+  llvm::append_range(updatedExtCandidates, extCandidates);
+
+  if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+                                         elidedCapCandidates)))
+    return failure();
+  // Add capability-inferred extensions to the list of extension requirements;
+  // only consider capabilities that are already available in the `targetEnv`.
+
+  // WARNING: Some capabilities are part of both the core SPIR-V
+  // specification and an extension (e.g., 'Groups' capability is part of both
+  // core specification and SPV_AMD_shader_ballot extension, hence we should
+  // relax the capability inferred extension for these cases).
+  static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+  ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
+                                                     std::size(multiModalCaps));
+
+  for (auto cap : targetEnv.getAttr().getCapabilities()) {
+    if (llvm::any_of(multiModalCapsArrayRef,
+                     [&cap](spirv::Capability mMCap) { return cap == mMCap; }))
+      continue;
+    std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
+    if (ext)
+      updatedExtCandidates.push_back(*ext);
+  }
+  if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+                                        elidedExtCandidates)))
+    return failure();
+  return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
   switch (storageClass) {
   case spirv::StorageClass::PhysicalStorageBuffer:
@@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
   type.getCapabilities(capabilities, storageClass);
 
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions)))
     return type;
 
   // Otherwise we need to adjust the type, which really means adjusting the
@@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
   cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
 
+  // If the bit-width related capabilities and extensions are not met
+  // for lower bit-width (<32-bit), convert it to 32-bit
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
+  if (!elementType)
+    return nullptr;
+  type = VectorType::get(type.getShape(), elementType);
+
+  SmallVector<spirv::Capability, 4> elidedCaps;
+  SmallVector<spirv::Extension, 4> elidedExts;
+
+  // Relax the bitwidth requirements for capabilities and extensions
+  if (options.emulateLT32BitScalarTypes) {
+    elidedCaps.push_back(spirv::Capability::Int8);
+    elidedCaps.push_back(spirv::Capability::Int16);
+    elidedCaps.push_back(spirv::Capability::Float16);
+  }
+  // For capabilities whose requirements were relaxed, relax requirements for
+  // the extensions that were inferred by those capabilities (e.g., elidedCaps)
+  for (spirv::Capability cap : elidedCaps) {
+    std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+    if (ext)
+      llvm::append_range(elidedExts, *ext);
+  }
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
     return type;
 
-  auto elementType =
-      convertScalarType(targetEnv, options, scalarType, storageClass);
-  if (elementType)
-    return VectorType::get(type.getShape(), elementType);
   return nullptr;
 }
 
@@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
   SmallVector<ArrayRef<spirv::Capability>, 2> caps;
   scalarType.getExtensions(exts);
   scalarType.getCapabilities(caps);
-  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
-      failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+  if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+                                                     exts))) {
     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
     return castOp.getResult(0);
   }
@@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
   for (Type valueType : valueTypes) {
-    typeExtensions.clear();
-    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
-    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
-                                          typeExtensions)))
-      return false;
-
     typeCapabilities.clear();
     cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
-    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
-                                           typeCapabilities)))
+    typeExtensions.clear();
+    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+    // Checking for capability and extension requirements along with
+    // capability-inferred extensions.
+    // If a capability is present, the extension that
+    // supports it should also be present; this reduces the burden of adding
+    // extension requirements that may or may not be added in
+    // CompositeType::getExtensions().
+    if (failed(checkCapabilityAndExtensionRequirements(
+            op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
       return false;
   }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d8570..d61ace8d6876b87 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -11,9 +11,9 @@ module attributes {
     #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
 } {
 
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
   // expected-error at +1 {{failed to legalize operation 'arith.subi'}}
-  %1 = arith.subi %arg0, %arg0: vector<5xi32>
+  %1 = arith.subi %arg0, %arg1: vector<5xi32>
   return
 }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397d..6ceeade486efd68 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
 }
 
 } // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// VectorAnyINTEL support
+//===----------------------------------------------------------------------===//
+
+// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
+// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @any_vector
+func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
+  // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
+  %0 = arith.subi %arg0, %arg1: vector<16xi32>
+  return
+}
+
+// CHECK-LABEL: @max_vector
+func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
+  // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
+  %0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
+  return
+}
+
+
+// Check float vector types of any size.
+// CHECK-LABEL: @float_vector58
+func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
+  %0 = arith.addf %arg0, %arg0: vector<5xf16>
+  // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
+  %1 = arith.mulf %a...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/68034


More information about the Mlir-commits mailing list