[Mlir-commits] [mlir] [mlir][spirv] Support coop matrix in `spirv.CompositeConstruct` (PR #66399)

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 14 09:46:55 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/66399:

>From f3790647ab40e93f942f72ef1e4ad838885a9f14 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 14 Sep 2023 12:44:01 -0400
Subject: [PATCH] [mlir][spirv] Support coop matrix in
 `spirv.CompositeConstruct`

Also improve the documentation (code and website).
---
 .../Dialect/SPIRV/IR/SPIRVCompositeOps.td     | 10 +++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 36 ++++++++++---------
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 34 ++++++++++++++----
 3 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index b8307b488af6fa5..33078b74cc3cf4f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
     #### Example:
 
     ```mlir
-    %0 = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
+    %a = spirv.CompositeConstruct %1, %2, %3 : (f32, f32, f32) -> vector<3xf32>
+    %b = spirv.CompositeConstruct %a, %1 : (vector<3xf32>, f32) -> vector<4xf32>
+
+    %c = spirv.CompositeConstruct %1 :
+      (f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+
+    %d = spirv.CompositeConstruct %a, %4, %5 :
+      (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) ->
+        !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1f07b0b9e85bff6..3906bf74ea72235 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::CompositeConstructOp::verify() {
-  auto cType = llvm::cast<spirv::CompositeType>(getType());
   operand_range constituents = this->getConstituents();
 
-  if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) {
-    if (constituents.size() != 1)
-      return emitOpError("has incorrect number of operands: expected ")
-             << "1, but provided " << constituents.size();
-    if (coopType.getElementType() != constituents.front().getType())
-      return emitOpError("operand type mismatch: expected operand type ")
-             << coopType.getElementType() << ", but provided "
-             << constituents.front().getType();
-    return success();
-  }
+  // There are 4 cases with varying verification rules:
+  // 1. Cooperative Matrices (1 constituent)
+  // 2. Structs (1 constituent for each member)
+  // 3. Arrays (1 constituent for each array element)
+  // 4. Vectors (1 constituent (sub-)element for each vector element)
+
+  auto coopElementType =
+      llvm::TypeSwitch<Type, Type>(getType())
+          .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+                spirv::JointMatrixINTELType>(
+              [](auto coopType) { return coopType.getElementType(); })
+          .Default([](Type) { return nullptr; });
 
-  if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) {
+  // Case 1. -- matrices.
+  if (coopElementType) {
     if (constituents.size() != 1)
       return emitOpError("has incorrect number of operands: expected ")
              << "1, but provided " << constituents.size();
-    if (jointType.getElementType() != constituents.front().getType())
+    if (coopElementType != constituents.front().getType())
       return emitOpError("operand type mismatch: expected operand type ")
-             << jointType.getElementType() << ", but provided "
+             << coopElementType << ", but provided "
              << constituents.front().getType();
     return success();
   }
 
+  // Case 2./3./4. -- number of constituents matches the number of elements.
+  auto cType = llvm::cast<spirv::CompositeType>(getType());
   if (constituents.size() == cType.getNumElements()) {
     for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
       if (constituents[index].getType() != cType.getElementType(index)) {
@@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
     return success();
   }
 
-  // If not constructing a cooperative matrix type, then we must be constructing
-  // a vector type.
+  // Case 4. -- check that all constituents add up to the expected vector type.
   auto resultType = llvm::dyn_cast<VectorType>(cType);
   if (!resultType)
     return emitOpError(
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index ce7f6bc6118b316..2891513961d5e2a 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -4,22 +4,20 @@
 // spirv.CompositeConstruct
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @composite_construct_vector
 func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
   // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
   return %0: vector<3xf32>
 }
 
-// -----
-
+// CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct
   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
   return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
 }
 
-// -----
-
 // CHECK-LABEL: func @composite_construct_mixed_scalar_vector
 func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
   // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
@@ -27,9 +25,15 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
   return %0: vector<4xf32>
 }
 
-// -----
+// CHECK-LABEL: func @composite_construct_coopmatrix_khr
+func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
+  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+}
 
-func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
+// CHECK-LABEL: func @composite_construct_coopmatrix_nv
+func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
   // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
@@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
 
 // -----
 
+func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
+  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
+  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+}
+
+// -----
+
+func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
+  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
+  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
+  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
+}
+
+// -----
+
 func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
   %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>



More information about the Mlir-commits mailing list