[Mlir-commits] [mlir] 1938b61 - [mlir][spirv] Allow usage of vector size 8 and 16 with Vector16 capability

Artur Bialas llvmlistbot at llvm.org
Wed Nov 4 23:27:26 PST 2020


Author: Artur Bialas
Date: 2020-11-05T08:26:15+01:00
New Revision: 1938b61bda50f0117e6b8bbc02e42d065d59b4f9

URL: https://github.com/llvm/llvm-project/commit/1938b61bda50f0117e6b8bbc02e42d065d59b4f9
DIFF: https://github.com/llvm/llvm-project/commit/1938b61bda50f0117e6b8bbc02e42d065d59b4f9.diff

LOG: [mlir][spirv] Allow usage of vector size 8 and 16 with Vector16 capability

Per the SPIR-V spec, vector sizes 8 and 16 are allowed when the Vector16 capability is present.
This change relaxes the vector-size restriction to accept these sizes and reports Vector16 as a required capability when they are used.

Differential Revision: https://reviews.llvm.org/D90683

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
    mlir/test/Dialect/SPIRV/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 7390e2d70f6c..cc23969dea5d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3045,7 +3045,7 @@ def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
 def SPV_Int32 : TypeAlias<I32, "Int32">;
 def SPV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
+def SPV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
                                        [SPV_Bool, SPV_Integer, SPV_Float]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
@@ -3083,10 +3083,10 @@ class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
     "Cooperative Matrix">;
 
 class SPV_ScalarOrVectorOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
 
 class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>,
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
                SPV_CoopMatrixOfType<[type]>]>;
 
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 6dc5a36ba178..c38f2fdbe785 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -172,8 +172,19 @@ bool CompositeType::classof(Type type) {
 }
 
 bool CompositeType::isValid(VectorType type) {
-  return type.getRank() == 1 && type.getElementType().isa<ScalarType>() &&
-         type.getNumElements() >= 2 && type.getNumElements() <= 4;
+  // A SPIR-V vector is 1-D, has a scalar element type, and has 2, 3, 4, 8, or
+  // 16 components; 8 and 16 also need Vector16 (see getCapabilities()).
+  switch (type.getNumElements()) {
+  case 2:
+  case 3:
+  case 4:
+  case 8:
+  case 16:
+    break;
+  default:
+    return false;
+  }
+  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
 }
 
 Type CompositeType::getElementType(unsigned index) const {
@@ -233,6 +244,13 @@ void CompositeType::getCapabilities(
             StructType>(
           [&](auto type) { type.getCapabilities(capabilities, storage); })
       .Case<VectorType>([&](VectorType type) {
+        // Vectors with 8 or 16 components additionally require Vector16.
+        auto vecSize = type.getNumElements();
+        if (vecSize == 8 || vecSize == 16) {
+          static const Capability caps[] = {Capability::Vector16};
+          ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
+          capabilities.push_back(ref);
+        }
         return type.getElementType().cast<ScalarType>().getCapabilities(
             capabilities, storage);
       })

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
index 5130b4915096..2fd0476ad361 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
@@ -14,4 +14,12 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
     %0 = spv.OCL.s_abs %arg0 : i32
     spv.Return
   }
+
+  // NOTE(review): vector<16xf32> needs the Vector16 capability, which this
+  // module's vce triple does not list yet -- consider adding it.
+  spv.func @vector_size16(%arg0 : vector<16xf32>) "None" {
+    // CHECK: {{%.*}} = spv.OCL.fabs {{%.*}} : vector<16xf32>
+    %0 = spv.OCL.fabs %arg0 : vector<16xf32>
+    spv.Return
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index f0874f85e4f6..07d2d05aa741 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -121,6 +121,18 @@ spv.module Logical GLSL450 attributes {
   }
 }
 
+// Using 8- or 16-element vectors requires the Vector16 capability.
+// CHECK: requires #spv.vce<v1.0, [Vector16, Shader], []>
+spv.module Logical GLSL450 attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader, Vector16], []>, {}>
+} {
+  spv.func @iadd_v16_function(%val : vector<16xi32>) -> vector<16xi32> "None" {
+    %0 = spv.IAdd %val, %val : vector<16xi32>
+    spv.ReturnValue %0: vector<16xi32>
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Extension
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index affb6a004950..5722139511a1 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -843,7 +843,7 @@ func @logicalUnary(%arg0 : i1)
 
 func @logicalUnary(%arg0 : i32)
 {
-  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}}
   %0 = spv.LogicalNot %arg0 : i32
   return
 }


        


More information about the Mlir-commits mailing list