[Mlir-commits] [mlir] [mlir][spirv] Add support for VectorAnyINTEL capability (PR #68034)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Mon Oct 2 15:12:46 PDT 2023


https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/68034

>From 80c02b19ea578a1d18c6bcf394610f9f8d2068b0 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <Md.Abdullah.Shahneous.Bari at intel.com>
Date: Tue, 26 Sep 2023 14:45:05 -0700
Subject: [PATCH 1/4] [mlir] Add support for vector types whose number of
 elements are from a range of values

Add types and predicates for Vector, Fixed Vector, and Scalable Vector
whose number of elements is from a given `allowedRanges` list.
The list has two values, start and end of the range (inclusive).
---
 mlir/include/mlir/IR/CommonTypeConstraints.td | 70 +++++++++++++++++++
 1 file changed, 70 insertions(+)

diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..703122547df7493 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Whether the number of elements of a vector is from the given
+// `allowedRanges` list, the list has two values, start and end
+// of the range (inclusive).
+class IsVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsFixedVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsScalableVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsScalableVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list.
+class VectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list.
+class FixedVectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list.
+class ScalableVectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsScalableVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
+        VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
+        "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<
+      And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
+      FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
+      "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class ScalableVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<
+      And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfLengthRange<allowedRanges>.predicate]>,
+      ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfLengthRange<allowedRanges>.summary,
+      "::mlir::VectorType">;
+
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

>From 2509cb8bd03540fd49141ab0a9d59f8a0a8066bb Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <Md.Abdullah.Shahneous.Bari at intel.com>
Date: Mon, 2 Oct 2023 11:15:08 -0700
Subject: [PATCH 2/4] [mlir][spirv] Extend capabilities and extensions
 requirements checking

Allow a way to relax requirements for certain capabilities and
extensions (e.g., `elidedCandidates`).

Also add a combined check for capabilities and extensions in
`checkCapabilityAndExtensionRequirements`.
This function checks capabilities, extensions, and
capability-inferred extension requirements.
---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 67 +++++++++++++++++--
 1 file changed, 61 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..7bcd36da0c21ee6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
 template <typename LabelT>
 static LogicalResult checkExtensionRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+    const ArrayRef<spirv::Extension> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+          return llvm::is_contained(ors, elidedExt);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
 template <typename LabelT>
 static LogicalResult checkCapabilityRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+    const ArrayRef<spirv::Capability> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+          return llvm::is_contained(ors, elidedCap);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
   return success();
 }
 
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Checks capability and extension requirements.
+/// Checks that `capCandidates`, `extCandidates`, and the extension
+/// requirements inferred from the capabilities available in `targetEnv` can
+/// be satisfied with the given `targetEnv`.
+/// It also provides a way to relax requirements for certain capabilities and
+/// extensions (`elidedCapCandidates`, `elidedExtCandidates`). This is to
+/// allow passes to relax certain requirements based on an option (e.g.,
+/// relaxing the bitwidth requirement; see `convertScalarType()` and
+/// `convertVectorType()`).
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+    const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+    const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+    const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+  SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+  llvm::append_range(updatedExtCandidates, extCandidates);
+
+  if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+                                         elidedCapCandidates)))
+    return failure();
+  // Add capability-inferred extensions to the list of extension requirements;
+  // only consider the capabilities that are already available in `targetEnv`.
+
+  // WARNING: Some capabilities are part of both the core SPIR-V
+  // specification and an extension (e.g., 'Groups' capability is part of both
+  // core specification and SPV_AMD_shader_ballot extension, hence we should
+  // relax the capability inferred extension for these cases).
+  static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+  ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
+                                                     std::size(multiModalCaps));
+
+  for (auto cap : targetEnv.getAttr().getCapabilities()) {
+    if (llvm::any_of(multiModalCapsArrayRef,
+                     [&cap](spirv::Capability mMCap) { return cap == mMCap; }))
+      continue;
+    std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
+    if (ext)
+      updatedExtCandidates.push_back(*ext);
+  }
+  if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+                                        elidedExtCandidates)))
+    return failure();
+  return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
   switch (storageClass) {
   case spirv::StorageClass::PhysicalStorageBuffer:

>From ade9fec70ab1d773c6774524185a2a6f6f156dce Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <Md.Abdullah.Shahneous.Bari at intel.com>
Date: Mon, 2 Oct 2023 11:16:43 -0700
Subject: [PATCH 3/4] [mlir][spirv] Use combined-check for type related
 extension and capability requirements

Replace the separate extension and capability checks with the combined check
`checkCapabilityAndExtensionRequirements()`. This makes the code flow simpler.
It also adds the extra check for capability-inferred extensions.

Need for the capability-inferred extension check:
If a capability is a requirement, the respective extension that implements
it should also become an extension requirement. There was no support for
that check; as a result, the extension requirement had to be added separately.
This separate requirement addition causes problems when a feature is enabled by
multiple capabilities and one of the capabilities is part of an extension. E.g.,
a vector size of 16 can be enabled by both the "Vector16" and "VectorAnyINTEL"
capabilities; however, only "VectorAnyINTEL" has an extension requirement
("SPV_INTEL_vector_compute"). Since the processes of adding capability
and extension requirements are independent, there is no way to handle
cases like this. Therefore, for cases like this, add the capability
requirement first, then check the capability-inferred extensions.
---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 59 +++++++++++++------
 1 file changed, 41 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7bcd36da0c21ee6..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -285,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
   type.getCapabilities(capabilities, storageClass);
 
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions)))
     return type;
 
   // Otherwise we need to adjust the type, which really means adjusting the
@@ -397,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
   cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
 
+  // If the bit-width related capabilities and extensions are not met
+  // for lower bit-width (<32-bit), convert it to 32-bit
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
+  if (!elementType)
+    return nullptr;
+  type = VectorType::get(type.getShape(), elementType);
+
+  SmallVector<spirv::Capability, 4> elidedCaps;
+  SmallVector<spirv::Extension, 4> elidedExts;
+
+  // Relax the bitwidth requirements for capabilities and extensions
+  if (options.emulateLT32BitScalarTypes) {
+    elidedCaps.push_back(spirv::Capability::Int8);
+    elidedCaps.push_back(spirv::Capability::Int16);
+    elidedCaps.push_back(spirv::Capability::Float16);
+  }
+  // For capabilities whose requirements were relaxed, relax requirements for
+  // the extensions that were inferred by those capabilities (e.g., elidedCaps)
+  for (spirv::Capability cap : elidedCaps) {
+    std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+    if (ext)
+      llvm::append_range(elidedExts, *ext);
+  }
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
     return type;
 
-  auto elementType =
-      convertScalarType(targetEnv, options, scalarType, storageClass);
-  if (elementType)
-    return VectorType::get(type.getShape(), elementType);
   return nullptr;
 }
 
@@ -711,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
   SmallVector<ArrayRef<spirv::Capability>, 2> caps;
   scalarType.getExtensions(exts);
   scalarType.getCapabilities(caps);
-  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
-      failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+  if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+                                                     exts))) {
     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
     return castOp.getResult(0);
   }
@@ -1205,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
   for (Type valueType : valueTypes) {
-    typeExtensions.clear();
-    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
-    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
-                                          typeExtensions)))
-      return false;
-
     typeCapabilities.clear();
     cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
-    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
-                                           typeCapabilities)))
+    typeExtensions.clear();
+    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+    // Check capability and extension requirements along with
+    // capability-inferred extensions.
+    // If a capability is present, the extension that
+    // supports it should also be present. This reduces the burden of adding
+    // extension requirements that may or may not be added in
+    // CompositeType::getExtensions().
+    if (failed(checkCapabilityAndExtensionRequirements(
+            op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
       return false;
   }
 

>From f1b971b7bd199e8e1519c311d48096f4e747996b Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <Md.Abdullah.Shahneous.Bari at intel.com>
Date: Mon, 2 Oct 2023 11:25:52 -0700
Subject: [PATCH 4/4] [mlir][spirv] Add support for VectorAnyINTEL capability

Allow vectors of any length in the range [2, 2^32-1].
The VectorAnyINTEL capability (part of the "SPV_INTEL_vector_compute" extension)
relaxes the length constraint on SPIR-V vector sizes from 2, 3, and 4.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 11 +++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  7 +++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 22 +++++++---
 .../arith-to-spirv-unsupported.mlir           |  4 +-
 .../ArithToSPIRV/arith-to-spirv.mlir          | 40 ++++++++++++++++++
 .../FuncToSPIRV/types-to-spirv.mlir           | 17 +++++++-
 mlir/test/Dialect/SPIRV/IR/bit-ops.mlir       |  6 +--
 mlir/test/Dialect/SPIRV/IR/gl-ops.mlir        |  2 +-
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir |  4 +-
 mlir/test/Dialect/SPIRV/IR/logical-ops.mlir   |  2 +-
 mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir       | 42 ++++++++++---------
 mlir/test/Target/SPIRV/arithmetic-ops.mlir    |  6 +--
 mlir/test/Target/SPIRV/ocl-ops.mlir           |  6 +++
 13 files changed, 124 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..c458a500eb367f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+// Remove the vector size restriction.
+// Although the vector size can be up to 2^64-1 (uint64),
+// 2^32-1 (UINT32_MAX) is a more realistic bound; it should serve the purpose
+// for all practical cases.
+// Also, unsigned is used for the number of elements of composite types.
+def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
                                        [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
     "Joint Matrix">;
 
 class SPIRV_ScalarOrVectorOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+    AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+    AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
                SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..be85d3c330a887a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
       return Type();
     }
-    if (t.getNumElements() > 4) {
+    // Number of elements should be between [2 - 2^32 -1],
+    // since getNumElements() returns an unsigned, the upper limit check is
+    // unnecessary.
+    if (t.getNumElements() < 2) {
       parser.emitError(
-          typeLoc, "vector length has to be less than or equal to 4 but found ")
+          typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
           << t.getNumElements();
       return Type();
     }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f965d..9d39d99b4148253 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
 }
 
 bool CompositeType::isValid(VectorType type) {
-  return type.getRank() == 1 &&
-         llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
-         llvm::isa<ScalarType>(type.getElementType());
+  // Number of elements should be between [2 - 2^32 -1],
+  // since getNumElements() returns an unsigned, the upper limit check is
+  // unnecessary.
+  return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
+         type.getNumElements() >= 2;
 }
 
 Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
         if (vecSize == 8 || vecSize == 16) {
-          static const Capability caps[] = {Capability::Vector16};
-          ArrayRef<Capability> ref(caps, std::size(caps));
-          capabilities.push_back(ref);
+          static constexpr Capability caps[] = {Capability::Vector16,
+                                                Capability::VectorAnyINTEL};
+          capabilities.push_back(caps);
+        }
+        // VectorAnyINTEL capability removes the vector size restriction and
+        // allows the vector size to be up to (2^32-1).
+        // Vector16 capability allows the vector size to be 8 and 16
+        SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
+        if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
+          static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
+          capabilities.push_back(caps);
         }
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d8570..d61ace8d6876b87 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -11,9 +11,9 @@ module attributes {
     #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
 } {
 
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
   // expected-error at +1 {{failed to legalize operation 'arith.subi'}}
-  %1 = arith.subi %arg0, %arg0: vector<5xi32>
+  %1 = arith.subi %arg0, %arg1: vector<5xi32>
   return
 }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397d..6ceeade486efd68 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
 }
 
 } // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// VectorAnyINTEL support
+//===----------------------------------------------------------------------===//
+
+// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
+// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @any_vector
+func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
+  // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
+  %0 = arith.subi %arg0, %arg1: vector<16xi32>
+  return
+}
+
+// CHECK-LABEL: @max_vector
+func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
+  // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
+  %0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
+  return
+}
+
+
+// Check float vector types of any size.
+// CHECK-LABEL: @float_vector58
+func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
+  %0 = arith.addf %arg0, %arg0: vector<5xf16>
+  // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
+  %1 = arith.mulf %arg1, %arg1: vector<8xf64>
+  return
+}
+
+} // end module
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2e8..6f364c5b0875c8b 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -351,8 +351,21 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-NOT: spirv.func @large_vector
-func.func @large_vector(%arg0: vector<1024xi32>) { return }
+// CHECK-NOT: spirv.func @large_vector_unsupported
+func.func @large_vector_unsupported(%arg0: vector<1024xi32>) { return }
+
+} // end module
+
+
+// -----
+
+// Check that large vectors are supported with VectorAnyINTEL or VectorComputeINTEL.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
+} {
+
+// CHECK: spirv.func @large_any_vector
+func.func @large_any_vector(%arg0: vector<1024xi32>) { return }
 
 } // end module
 
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index 82a2316f6c784fb..88a8e507c1993a6 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
 // -----
 
 func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
   %0 = spirv.BitwiseOr %arg0, %arg1 : f16
   return %0 : f16
 }
@@ -163,7 +163,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
 // -----
 
 func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
   %0 = spirv.BitwiseXor %arg0, %arg1 : f16
   return %0 : f16
 }
@@ -272,7 +272,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
 // -----
 
 func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2-4294967295}}
   %0 = spirv.BitwiseAnd %arg0, %arg1 : f16
   return %0 : f16
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 3683e5b469b17b3..a95a6001fd20433 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () {
 // -----
 
 func.func @exp(%arg0 : vector<5xf32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+  // CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32
   %2 = spirv.GL.Exp %arg0 : vector<5xf32>
   return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 53a1015de75bcc8..6929ef9b21d0ebd 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
 // -----
 
 spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
-  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}}
   %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
   spirv.Return
 }
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
 // -----
 
 spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
-  // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2-4294967295, but got}}
   %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
   spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 7dc0bd99f54b3a6..fa4d9e253307d48 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -166,7 +166,7 @@ func.func @logicalUnary(%arg0 : i1)
 
 func.func @logicalUnary(%arg0 : i32)
 {
-  // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2-4294967295, but got 'i32'}}
   %0 = spirv.LogicalNot %arg0 : i32
   return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 29a4a46136156a9..24fe2f945841311 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -18,17 +18,17 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () {
 
 // -----
 
-func.func @exp(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
-  %2 = spirv.CL.exp %arg0 : i32
+func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () {
+  // CHECK: spirv.CL.exp {{%.*}} : vector<5xf32>
+  %2 = spirv.CL.exp %arg0 : vector<5xf32>
   return
 }
 
 // -----
 
-func.func @exp(%arg0 : vector<5xf32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
-  %2 = spirv.CL.exp %arg0 : vector<5xf32>
+func.func @exp(%arg0 : i32) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %2 = spirv.CL.exp %arg0 : i32
   return
 }
 
@@ -66,6 +66,14 @@ func.func @fabsvec(%arg0 : vector<3xf16>) -> () {
   return
 }
 
+// -----
+
+func.func @fabs_any_vec(%arg0 : vector<5xf32>) -> () {
+  // CHECK: spirv.CL.fabs {{%.*}} : vector<5xf32>
+  %2 = spirv.CL.fabs %arg0 : vector<5xf32>
+  return
+}
+
 func.func @fabsf64(%arg0 : f64) -> () {
   // CHECK: spirv.CL.fabs {{%.*}} : f64
   %2 = spirv.CL.fabs %arg0 : f64
@@ -82,14 +90,6 @@ func.func @fabs(%arg0 : i32) -> () {
 
 // -----
 
-func.func @fabs(%arg0 : vector<5xf32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
-  %2 = spirv.CL.fabs %arg0 : vector<5xf32>
-  return
-}
-
-// -----
-
 func.func @fabs(%arg0 : f32, %arg1 : f32) -> () {
   // expected-error @+1 {{expected ':'}}
   %2 = spirv.CL.fabs %arg0, %arg1 : i32
@@ -122,6 +122,14 @@ func.func @sabsvec(%arg0 : vector<3xi16>) -> () {
   return
 }
 
+// -----
+
+func.func @sabs_any_vec(%arg0 : vector<5xi32>) -> () {
+  // CHECK: spirv.CL.s_abs {{%.*}} : vector<5xi32>
+  %2 = spirv.CL.s_abs %arg0 : vector<5xi32>
+  return
+}
+
 func.func @sabsi64(%arg0 : i64) -> () {
   // CHECK: spirv.CL.s_abs {{%.*}} : i64
   %2 = spirv.CL.s_abs %arg0 : i64
@@ -142,13 +150,7 @@ func.func @sabs(%arg0 : f32) -> () {
   return
 }
 
-// -----
 
-func.func @sabs(%arg0 : vector<5xi32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
-  %2 = spirv.CL.s_abs %arg0 : vector<5xi32>
-  return
-}
 
 // -----
 
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd7f..90144afc6f3af73 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -6,9 +6,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.FMul %arg0, %arg1 : f32
     spirv.Return
   }
-  spirv.func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" {
-    // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<4xf32>
-    %0 = spirv.FAdd %arg0, %arg1 : vector<4xf32>
+  spirv.func @fadd(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : vector<5xf32>
+    %0 = spirv.FAdd %arg0, %arg1 : vector<5xf32>
     spirv.Return
   }
   spirv.func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" {
diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 9a2e4cf62e370be..31a7f616d648e86 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -39,6 +39,12 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
     spirv.Return
   }
 
+  spirv.func @vector_anysize(%arg0 : vector<5000xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.CL.fabs {{%.*}} : vector<5000xf32>
+    %0 = spirv.CL.fabs %arg0 : vector<5000xf32>
+    spirv.Return
+  }
+
   spirv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
     // CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
     %13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32



More information about the Mlir-commits mailing list