[Mlir-commits] [mlir] [mlir][spirv] Add support for VectorAnyINTEL capability (PR #68034)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 2 13:41:37 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
<details>
<summary>Changes</summary>
Allow vectors of any length in the range [2, 2^32-1].
The VectorAnyINTEL capability (part of the "SPV_INTEL_vector_compute" extension) relaxes the constraint that limits SPIR-V vector lengths to 2, 3, and 4.
---
Patch is 44.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68034.diff
16 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+8-3)
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+70)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+5-2)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+16-6)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+102-24)
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+2-2)
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+40)
- (modified) mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir (+15-2)
- (modified) mlir/test/Conversion/GPUToSPIRV/reductions.mlir (+32-32)
- (modified) mlir/test/Dialect/SPIRV/IR/bit-ops.mlir (+3-3)
- (modified) mlir/test/Dialect/SPIRV/IR/gl-ops.mlir (+1-1)
- (modified) mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir (+2-2)
- (modified) mlir/test/Dialect/SPIRV/IR/logical-ops.mlir (+1-1)
- (modified) mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir (+22-20)
- (modified) mlir/test/Target/SPIRV/arithmetic-ops.mlir (+3-3)
- (modified) mlir/test/Target/SPIRV/ocl-ops.mlir (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..c458a500eb367f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+// Remove the vector size restriction.
+// Although the vector size can be upto (2^64-1), uint64,
+// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
+// for all practical cases.
+// Also unsigned is used for the number elements for composite tyeps.
+def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
"Joint Matrix">;
class SPIRV_ScalarOrVectorOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+ AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..703122547df7493 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Whether the number of elements of a vector is from the given
+// `allowedRanges` list, the list has two values, start and end
+// of the range (inclusive).
+class IsVectorOfLengthRangePred<list<int> allowedRanges>
+ : And<[IsVectorTypePred,
+ And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
+ CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
+ : And<[IsFixedVectorTypePred,
+ And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+ CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsScalableVectorOfLengthRangePred<list<int> allowedRanges>
+ : And<[IsScalableVectorTypePred,
+ And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+ CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list.
+class VectorOfLengthRange<list<int> allowedRanges>
+ : Type<IsVectorOfLengthRangePred<allowedRanges>,
+ " of length " # !interleave(allowedRanges, "-"),
+ "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list.
+class FixedVectorOfLengthRange<list<int> allowedRanges>
+ : Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
+ " of length " # !interleave(allowedRanges, "-"),
+ "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list.
+class ScalableVectorOfLengthRange<list<int> allowedRanges>
+ : Type<IsScalableVectorOfLengthRangePred<allowedRanges>,
+ " of length " # !interleave(allowedRanges, "-"),
+ "::mlir::VectorType">;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+ : Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
+ VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
+ "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+ : Type<
+ And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
+ FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
+ "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class ScalableVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+ : Type<
+ And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfLengthRange<allowedRanges>.predicate]>,
+ ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfLengthRange<allowedRanges>.summary,
+ "::mlir::VectorType">;
+
+
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..be85d3c330a887a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
}
- if (t.getNumElements() > 4) {
+ // Number of elements should be between [2 - 2^32 -1],
+ // since getNumElements() returns an unsigned, the upper limit check is
+ // unnecessary.
+ if (t.getNumElements() < 2) {
parser.emitError(
- typeLoc, "vector length has to be less than or equal to 4 but found ")
+ typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
<< t.getNumElements();
return Type();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f965d..9d39d99b4148253 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
}
bool CompositeType::isValid(VectorType type) {
- return type.getRank() == 1 &&
- llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
- llvm::isa<ScalarType>(type.getElementType());
+ // Number of elements should be between [2 - 2^32 -1],
+ // since getNumElements() returns an unsigned, the upper limit check is
+ // unnecessary.
+ return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
+ type.getNumElements() >= 2;
}
Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
if (vecSize == 8 || vecSize == 16) {
- static const Capability caps[] = {Capability::Vector16};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
+ static constexpr Capability caps[] = {Capability::Vector16,
+ Capability::VectorAnyINTEL};
+ capabilities.push_back(caps);
+ }
+ // VectorAnyINTEL capability removes the vector size restriction and
+ // allows the vector size to be up to (2^32-1).
+ // Vector16 capability allows the vector size to be 8 and 16
+ SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
+ if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
+ static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
+ capabilities.push_back(caps);
}
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
- const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+ const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+ const ArrayRef<spirv::Extension> elidedCandidates = {}) {
for (const auto &ors : candidates) {
- if (targetEnv.allows(ors))
+ if (targetEnv.allows(ors) ||
+ llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+ return llvm::is_contained(ors, elidedExt);
+ }))
continue;
LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
- const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+ const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+ const ArrayRef<spirv::Capability> elidedCandidates = {}) {
for (const auto &ors : candidates) {
- if (targetEnv.allows(ors))
+ if (targetEnv.allows(ors) ||
+ llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+ return llvm::is_contained(ors, elidedCap);
+ }))
continue;
LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Check capabilities and extensions requirements
+/// Checks that `capCandidates`, `extCandidates`, and capability
+/// (`capCandidates`) infered extension requirements are possible to be
+/// satisfied with the given `targetEnv`.
+/// It also provides a way to relax requirements for certain capabilities and
+/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
+/// allow passes to relax certain requirements based on an option (e.g.,
+/// relaxing bitwidth requirement, see `convertScalarType()`,
+/// `ConvertVectorType()`).
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+ LabelT label, const spirv::TargetEnv &targetEnv,
+ const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+ const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+ const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+ const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+ SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+ llvm::append_range(updatedExtCandidates, extCandidates);
+
+ if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+ elidedCapCandidates)))
+ return failure();
+ // Add capablity infered extensions to the list of extension requirement list,
+ // only considers the capabilities that already available in the `targetEnv`.
+
+ // WARNING: Some capabilities are part of both the core SPIR-V
+ // specification and an extension (e.g., 'Groups' capability is part of both
+ // core specification and SPV_AMD_shader_ballot extension, hence we should
+ // relax the capability inferred extension for these cases).
+ static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+ ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
+ std::size(multiModalCaps));
+
+ for (auto cap : targetEnv.getAttr().getCapabilities()) {
+ if (llvm::any_of(multiModalCapsArrayRef,
+ [&cap](spirv::Capability mMCap) { return cap == mMCap; }))
+ continue;
+ std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
+ if (ext)
+ updatedExtCandidates.push_back(*ext);
+ }
+ if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+ elidedExtCandidates)))
+ return failure();
+ return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
@@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
type.getCapabilities(capabilities, storageClass);
// If all requirements are met, then we can accept this type as-is.
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
- succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+ if (succeeded(checkCapabilityAndExtensionRequirements(
+ type, targetEnv, capabilities, extensions)))
return type;
// Otherwise we need to adjust the type, which really means adjusting the
@@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
+ // If the bit-width related capabilities and extensions are not met
+ // for lower bit-width (<32-bit), convert it to 32-bit
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
+ if (!elementType)
+ return nullptr;
+ type = VectorType::get(type.getShape(), elementType);
+
+ SmallVector<spirv::Capability, 4> elidedCaps;
+ SmallVector<spirv::Extension, 4> elidedExts;
+
+ // Relax the bitwidth requirements for capabilities and extensions
+ if (options.emulateLT32BitScalarTypes) {
+ elidedCaps.push_back(spirv::Capability::Int8);
+ elidedCaps.push_back(spirv::Capability::Int16);
+ elidedCaps.push_back(spirv::Capability::Float16);
+ }
+ // For capabilities whose requirements were relaxed, relax requirements for
+ // the extensions that were infered by those capabilities (e.g., elidedCaps)
+ for (spirv::Capability cap : elidedCaps) {
+ std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+ if (ext)
+ llvm::append_range(elidedExts, *ext);
+ }
// If all requirements are met, then we can accept this type as-is.
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
- succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+ if (succeeded(checkCapabilityAndExtensionRequirements(
+ type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
return type;
- auto elementType =
- convertScalarType(targetEnv, options, scalarType, storageClass);
- if (elementType)
- return VectorType::get(type.getShape(), elementType);
return nullptr;
}
@@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
SmallVector<ArrayRef<spirv::Capability>, 2> caps;
scalarType.getExtensions(exts);
scalarType.getCapabilities(caps);
- if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
- failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+ if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+ exts))) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return castOp.getResult(0);
}
@@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
- typeExtensions.clear();
- cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
- if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
- typeExtensions)))
- return false;
-
typeCapabilities.clear();
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
- if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
- typeCapabilities)))
+ typeExtensions.clear();
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+ // Checking for capability and extension requirements along with capability
+ // infered extensions.
+ // If a capability is present, the extension that
+ // supports it should also be present, this reduces the burden of adding
+ // extension requirement that may or maynot be added in
+ // CompositeType::getExtensions().
+ if (failed(checkCapabilityAndExtensionRequirements(
+ op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
return false;
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d8570..d61ace8d6876b87 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -11,9 +11,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
// expected-error at +1 {{failed to legalize operation 'arith.subi'}}
- %1 = arith.subi %arg0, %arg0: vector<5xi32>
+ %1 = arith.subi %arg0, %arg1: vector<5xi32>
return
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397d..6ceeade486efd68 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
}
} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// VectorAnyINTEL support
+//===----------------------------------------------------------------------===//
+
+// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
+// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @any_vector
+func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
+ // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
+ %0 = arith.subi %arg0, %arg1: vector<16xi32>
+ return
+}
+
+// CHECK-LABEL: @max_vector
+func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
+ // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
+ %0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
+ return
+}
+
+
+// Check float vector types of any size.
+// CHECK-LABEL: @float_vector58
+func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
+ // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
+ %0 = arith.addf %arg0, %arg0: vector<5xf16>
+ // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
+ %1 = arith.mulf %a...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/68034
More information about the Mlir-commits
mailing list