[Mlir-commits] [mlir] 0712eac - [mlir][spirv] Enable composite instructions for cooperative matrix type.

Thomas Raoux llvmlistbot at llvm.org
Thu May 21 12:20:46 PDT 2020


Author: Thomas Raoux
Date: 2020-05-21T12:19:55-07:00
New Revision: 0712eac76616a088f1f1183399049560e69c3506

URL: https://github.com/llvm/llvm-project/commit/0712eac76616a088f1f1183399049560e69c3506
DIFF: https://github.com/llvm/llvm-project/commit/0712eac76616a088f1f1183399049560e69c3506.diff

LOG: [mlir][spirv] Enable composite instructions for cooperative matrix type.

Enable insert/extract/construct composite ops as well as access chain for
cooperative matrix. ConstantComposite requires more changes and will be done in
a separate patch. Also fix the getNumElements function for coopMatrix per
feedback from Jeff Bolz. The number of elements is implementation dependent so
it cannot be known at compile time.

Differential Revision: https://reviews.llvm.org/D80321

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/composite-ops.mlir
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index a3a2c2bec43b..ead6c0341cd6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3027,10 +3027,12 @@ def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
 def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
 def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Composite :
-    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
+    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
+               SPV_AnyCooperativeMatrix]>;
 def SPV_Type : AnyTypeOf<[
     SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
-    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
+    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
+    SPV_AnyCooperativeMatrix
   ]>;
 
 def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 078fb5a67225..71eba72e5e84 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -134,10 +134,16 @@ class CompositeType : public SPIRVType {
   /// Returns true if the given vector type is valid for the SPIR-V dialect.
   static bool isValid(VectorType);
 
+  /// Return the number of elements of the type. This should only be called if
+  /// hasCompileTimeKnownNumElements is true.
   unsigned getNumElements() const;
 
   Type getElementType(unsigned) const;
 
+  /// Return true if the number of elements is known at compile time and is not
+  /// implementation dependent.
+  bool hasCompileTimeKnownNumElements() const;
+
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<spirv::StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
@@ -334,7 +340,7 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 
 // SPIR-V cooperative matrix type
 class CooperativeMatrixNVType
-    : public Type::TypeBase<CooperativeMatrixNVType, SPIRVType,
+    : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
                             detail::CooperativeMatrixTypeStorage> {
 public:
   using Base::Base;

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 630f09842ccd..4f48ef9d7d7c 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -418,7 +418,9 @@ getElementType(Type type, ArrayRef<int32_t> indices,
 
   for (auto index : indices) {
     if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
-      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+      if (cType.hasCompileTimeKnownNumElements() &&
+          (index < 0 ||
+           static_cast<uint64_t>(index) >= cType.getNumElements())) {
         emitErrorFn("index ") << index << " out of bounds for " << type;
         return nullptr;
       }
@@ -1098,7 +1100,8 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
            << type;
   }
 
-  if (operands.size() != cType.getNumElements()) {
+  if (cType.hasCompileTimeKnownNumElements() &&
+      operands.size() != cType.getNumElements()) {
     return parser.emitError(loc, "has incorrect number of operands: expected ")
            << cType.getNumElements() << ", but provided " << operands.size();
   }
@@ -1107,8 +1110,8 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
   // also be vectors with the same component type as the Result Type component
   // type".
   SmallVector<Type, 4> elementTypes;
-  elementTypes.reserve(cType.getNumElements());
-  for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) {
+  elementTypes.reserve(operands.size());
+  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
     elementTypes.push_back(cType.getElementType(index));
   }
   state.addTypes(type);
@@ -1124,13 +1127,19 @@ static void print(spirv::CompositeConstructOp compositeConstructOp,
 
 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
   auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
-
   SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
-  if (constituents.size() != cType.getNumElements()) {
-    return compositeConstructOp.emitError(
-               "has incorrect number of operands: expected ")
-           << cType.getNumElements() << ", but provided "
-           << constituents.size();
+
+  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
+    if (constituents.size() != 1)
+      return compositeConstructOp.emitError(
+                 "has incorrect number of operands: expected ")
+             << "1, but provided " << constituents.size();
+  } else {
+    if (constituents.size() != cType.getNumElements())
+      return compositeConstructOp.emitError(
+                 "has incorrect number of operands: expected ")
+             << cType.getNumElements() << ", but provided "
+             << constituents.size();
   }
 
   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index ce5a6c0c4fd9..49b39ec78435 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -196,8 +196,8 @@ unsigned CompositeType::getNumElements() const {
   case spirv::TypeKind::Array:
     return cast<ArrayType>().getNumElements();
   case spirv::TypeKind::CooperativeMatrix:
-    return cast<CooperativeMatrixNVType>().getRows() *
-           cast<CooperativeMatrixNVType>().getColumns();
+    llvm_unreachable(
+        "invalid to query number of elements of spirv::CooperativeMatrix type");
   case spirv::TypeKind::RuntimeArray:
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -210,6 +210,16 @@ unsigned CompositeType::getNumElements() const {
   }
 }
 
+bool CompositeType::hasCompileTimeKnownNumElements() const {
+  switch (getKind()) {
+  case TypeKind::CooperativeMatrix:
+  case TypeKind::RuntimeArray:
+    return false;
+  default:
+    return true;
+  }
+}
+
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     Optional<StorageClass> storage) {

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index 6fb58d859d1f..12f710ea1b46 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -91,4 +91,12 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
     %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
     spv.Return
   }
+
+  // CHECK-LABEL: @cooperative_matrix_access_chain
+  spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
+    %0 = spv.constant 0: i32
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
+    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
+    spv.ReturnValue %1 : !spv.ptr<f32, Function>
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
index 556bed823155..ca3f60311576 100644
--- a/mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -20,6 +20,14 @@ func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>,
 
 // -----
 
+func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
 func @composite_construct_empty_struct() -> !spv.struct<> {
   // CHECK: spv.CompositeConstruct : !spv.struct<>
   %0 = spv.CompositeConstruct : !spv.struct<>
@@ -52,6 +60,14 @@ func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f
 
 // -----
 
+func @composite_construct_coopmatrix(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+  %0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
@@ -80,6 +96,14 @@ func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
 
 // -----
 
+func @composite_extract_coopmatrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+  // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeExtract %arg0[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0 : f32
+}
+
+// -----
+
 func @composite_extract_no_ssa_operand() -> () {
   // expected-error @+1 {{expected SSA operand}}
   %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
@@ -200,6 +224,14 @@ func @composite_insert_struct(%arg0: !spv.struct<!spv.array<4xf32>, f32>, %arg1:
 
 // -----
 
+func @composite_insert_coopmatrix(%arg0: !spv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spv.coopmatrix<8x16xi32, Subgroup> {
+  // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xi32, Subgroup>
+}
+
+// -----
+
 func @composite_insert_no_indices(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> {
   // expected-error @+1 {{expected at least one index}}
   %0 = spv.CompositeInsert %arg1, %arg0[] : f32 into !spv.array<4xf32>

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index 0b05d8a587e5..e30352625da6 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -94,6 +94,16 @@ spv.func @cooperative_matrix_fdiv(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b :
 
 // -----
 
+// CHECK-LABEL: @cooperative_matrix_access_chain
+spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
+  %0 = spv.constant 0: i32
+  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
+  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
+  spv.ReturnValue %1 : !spv.ptr<f32, Function>
+}
+
+// -----
+
 spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<16x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
   // expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix size must match}}
   %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<16x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>


        


More information about the Mlir-commits mailing list