[Mlir-commits] [mlir] [mlir][spirv] Support coop matrix in `spirv.CompositeConstruct` (PR #66399)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 14 09:46:42 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv
            
<details>
<summary>Changes</summary>
Also improve the documentation (code and website).
--
Full diff: https://github.com/llvm/llvm-project/pull/66399.diff

3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td (+9-1) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+20-16) 
- (modified) mlir/test/Dialect/SPIRV/IR/composite-ops.mlir (+28-6) 


<pre>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index b8307b488af6fa5..8216814d9f99598 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op&lt;&quot;CompositeConstruct&quot;, [Pure]&gt; {
     #### Example:
 
     ```mlir
-    %0 = spirv.CompositeConstruct %1, %2, %3 : vector&lt;3xf32&gt;
+    %a = spirv.CompositeConstruct %1, %2, %3 : (f32, f32, f32) -&gt; vector&lt;3xf32&gt;
+    %b = spirv.CompositeConstruct %a, %1 : (vector&lt;3xf32&gt;, f32) -&gt; vector&lt;4xf32&gt;
+
+    %c = spirv.CompositeConstruct %1 :
+      (f32) -&gt; !spirv.coopmatrix&lt;4x4xf32, Subgroup, MatrixA&gt;
+
+    %d = spirv.CompositeConstruct %a, %4, %5 :
+      (vector&lt;3xf32&gt;, !spirv.array&lt;4xf32&gt;, !spirv.struct&lt;(f32)&gt;) -&gt;
+        !spirv.struct&lt;(vector&lt;3xf32&gt;, !spirv.array&lt;4xf32&gt;, !spirv.struct&lt;(f32)&gt;)&gt;
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1f07b0b9e85bff6..3906bf74ea72235 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -29,6 +29,7 @@
 #include &quot;mlir/IR/Operation.h&quot;
 #include &quot;mlir/IR/TypeUtilities.h&quot;
 #include &quot;mlir/Interfaces/FunctionImplementation.h&quot;
+#include &quot;mlir/Support/LogicalResult.h&quot;
 #include &quot;llvm/ADT/APFloat.h&quot;
 #include &quot;llvm/ADT/APInt.h&quot;
 #include &quot;llvm/ADT/ArrayRef.h&quot;
@@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::CompositeConstructOp::verify() {
-  auto cType = llvm::cast&lt;spirv::CompositeType&gt;(getType());
   operand_range constituents = this-&gt;getConstituents();
 
-  if (auto coopType = llvm::dyn_cast&lt;spirv::CooperativeMatrixNVType&gt;(cType)) {
-    if (constituents.size() != 1)
-      return emitOpError(&quot;has incorrect number of operands: expected &quot;)
-             &lt;&lt; &quot;1, but provided &quot; &lt;&lt; constituents.size();
-    if (coopType.getElementType() != constituents.front().getType())
-      return emitOpError(&quot;operand type mismatch: expected operand type &quot;)
-             &lt;&lt; coopType.getElementType() &lt;&lt; &quot;, but provided &quot;
-             &lt;&lt; constituents.front().getType();
-    return success();
-  }
+  // There are 4 cases with varying verification rules:
+  // 1. Cooperative Matrices (1 constituent)
+  // 2. Structs (1 constituent for each member)
+  // 3. Arrays (1 constituent for each array element)
+  // 4. Vectors (1 constituent (sub-)element for each vector element)
+
+  auto coopElementType =
+      llvm::TypeSwitch&lt;Type, Type&gt;(getType())
+          .Case&lt;spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+                spirv::JointMatrixINTELType&gt;(
+              [](auto coopType) { return coopType.getElementType(); })
+          .Default([](Type) { return nullptr; });
 
-  if (auto jointType = llvm::dyn_cast&lt;spirv::JointMatrixINTELType&gt;(cType)) {
+  // Case 1. -- matrices.
+  if (coopElementType) {
     if (constituents.size() != 1)
       return emitOpError(&quot;has incorrect number of operands: expected &quot;)
              &lt;&lt; &quot;1, but provided &quot; &lt;&lt; constituents.size();
-    if (jointType.getElementType() != constituents.front().getType())
+    if (coopElementType != constituents.front().getType())
       return emitOpError(&quot;operand type mismatch: expected operand type &quot;)
-             &lt;&lt; jointType.getElementType() &lt;&lt; &quot;, but provided &quot;
+             &lt;&lt; coopElementType &lt;&lt; &quot;, but provided &quot;
              &lt;&lt; constituents.front().getType();
     return success();
   }
 
+  // Case 2./3./4. -- number of constituents matches the number of elements.
+  auto cType = llvm::cast&lt;spirv::CompositeType&gt;(getType());
   if (constituents.size() == cType.getNumElements()) {
     for (auto index : llvm::seq&lt;uint32_t&gt;(0, constituents.size())) {
       if (constituents[index].getType() != cType.getElementType(index)) {
@@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
     return success();
   }
 
-  // If not constructing a cooperative matrix type, then we must be constructing
-  // a vector type.
+  // Case 4. -- check that all constituents add up to the expected vector type.
   auto resultType = llvm::dyn_cast&lt;VectorType&gt;(cType);
   if (!resultType)
     return emitOpError(
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index ce7f6bc6118b316..2891513961d5e2a 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -4,22 +4,20 @@
 // spirv.CompositeConstruct
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @composite_construct_vector
 func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -&gt; vector&lt;3xf32&gt; {
   // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -&gt; vector&lt;3xf32&gt;
   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -&gt; vector&lt;3xf32&gt;
   return %0: vector&lt;3xf32&gt;
 }
 
-// -----
-
+// CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector&lt;3xf32&gt;, %arg1: !spirv.array&lt;4xf32&gt;, %arg2 : !spirv.struct&lt;(f32)&gt;) -&gt; !spirv.struct&lt;(vector&lt;3xf32&gt;, !spirv.array&lt;4xf32&gt;, !spirv.struct&lt;(f32)&gt;)&gt; {
   // CHECK: spirv.CompositeConstruct
   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector&lt;3xf32&gt;, !spirv.array&lt;4xf32&gt;, !spirv.struct&lt;(f32)&gt;) -&gt; !spirv.struct&lt;(vector&lt;3xf32&gt;, !spirv.array&lt;4xf32&gt;, !spirv.struct&lt;(f32)&gt;)&gt;
   return %0: !spirv.struct&lt;(vector&lt;3xf32&gt;, !spirv.array&lt;4xf32&gt;, !spirv.struct&lt;(f32)&gt;)&gt;
 }
 
-// -----
-
 // CHECK-LABEL: func @composite_construct_mixed_scalar_vector
 func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector&lt;2xf32&gt;) -&gt; vector&lt;4xf32&gt; {
   // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector&lt;2xf32&gt;, f32) -&gt; vector&lt;4xf32&gt;
@@ -27,9 +25,15 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
   return %0: vector&lt;4xf32&gt;
 }
 
-// -----
+// CHECK-LABEL: func @composite_construct_coopmatrix_khr
+func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -&gt; !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt; {
+  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -&gt; !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt;
+  %0 = spirv.CompositeConstruct %arg0 : (f32) -&gt; !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt;
+  return %0: !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt;
+}
 
-func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -&gt; !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt; {
+// CHECK-LABEL: func @composite_construct_coopmatrix_nv
+func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -&gt; !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt; {
   // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -&gt; !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt;
   %0 = spirv.CompositeConstruct %arg0 : (f32) -&gt; !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt;
   return %0: !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt;
@@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
 
 // -----
 
+func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -&gt;
+  !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt; {
+  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -&gt; !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt;
+  return %0: !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixA&gt;
+}
+
+// -----
+
+func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) -&gt;
+  !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixB&gt; {
+  // expected-error @+1 {{operand type mismatch: expected operand type &#x27;f32&#x27;, but provided &#x27;i32&#x27;}}
+  %0 = spirv.CompositeConstruct %arg0 : (i32) -&gt; !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixB&gt;
+  return %0: !spirv.coopmatrix&lt;8x16xf32, Subgroup, MatrixB&gt;
+}
+
+// -----
+
 func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -&gt; !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt; {
   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
   %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -&gt; !spirv.NV.coopmatrix&lt;8x16xf32, Subgroup&gt;
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66399


More information about the Mlir-commits mailing list