[Mlir-commits] [mlir] [mlir][spirv] Support coop matrix in `spirv.CompositeConstruct` (PR #66399)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 14 09:46:42 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
<details>
<summary>Changes</summary>
Also improve the documentation (code and website).
--
Full diff: https://github.com/llvm/llvm-project/pull/66399.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td (+9-1)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+20-16)
- (modified) mlir/test/Dialect/SPIRV/IR/composite-ops.mlir (+28-6)
<pre>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index b8307b488af6fa5..8216814d9f99598 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
#### Example:
```mlir
- %0 = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
+ %a = spirv.CompositeConstruct %1, %2, %3 : (f32, f32, f32) -> vector<3xf32>
+ %b = spirv.CompositeConstruct %a, %1 : (vector<3xf32>, f32) -> vector<4xf32>
+
+ %c = spirv.CompositeConstruct %1 :
+ (f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+
+ %d = spirv.CompositeConstruct %a, %4, %5 :
+ (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) ->
+ !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1f07b0b9e85bff6..3906bf74ea72235 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
@@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::CompositeConstructOp::verify() {
- auto cType = llvm::cast<spirv::CompositeType>(getType());
operand_range constituents = this->getConstituents();
- if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) {
- if (constituents.size() != 1)
- return emitOpError("has incorrect number of operands: expected ")
- << "1, but provided " << constituents.size();
- if (coopType.getElementType() != constituents.front().getType())
- return emitOpError("operand type mismatch: expected operand type ")
- << coopType.getElementType() << ", but provided "
- << constituents.front().getType();
- return success();
- }
+ // There are 4 cases with varying verification rules:
+ // 1. Cooperative Matrices (1 constituent)
+ // 2. Structs (1 constituent for each member)
+ // 3. Arrays (1 constituent for each array element)
+ // 4. Vectors (1 constituent (sub-)element for each vector element)
+
+ auto coopElementType =
+ llvm::TypeSwitch<Type, Type>(getType())
+ .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+ spirv::JointMatrixINTELType>(
+ [](auto coopType) { return coopType.getElementType(); })
+ .Default([](Type) { return nullptr; });
- if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) {
+ // Case 1. -- matrices.
+ if (coopElementType) {
if (constituents.size() != 1)
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
- if (jointType.getElementType() != constituents.front().getType())
+ if (coopElementType != constituents.front().getType())
return emitOpError("operand type mismatch: expected operand type ")
- << jointType.getElementType() << ", but provided "
+ << coopElementType << ", but provided "
<< constituents.front().getType();
return success();
}
+ // Case 2./3./4. -- number of constituents matches the number of elements.
+ auto cType = llvm::cast<spirv::CompositeType>(getType());
if (constituents.size() == cType.getNumElements()) {
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
if (constituents[index].getType() != cType.getElementType(index)) {
@@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
return success();
}
- // If not constructing a cooperative matrix type, then we must be constructing
- // a vector type.
+ // Case 4. -- check that all constituents add up to the expected vector type.
auto resultType = llvm::dyn_cast<VectorType>(cType);
if (!resultType)
return emitOpError(
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index ce7f6bc6118b316..2891513961d5e2a 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -4,22 +4,20 @@
// spirv.CompositeConstruct
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: func @composite_construct_vector
func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
// CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
return %0: vector<3xf32>
}
-// -----
-
+// CHECK-LABEL: func @composite_construct_struct
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
// CHECK: spirv.CompositeConstruct
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
}
-// -----
-
// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
// CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
@@ -27,9 +25,15 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
return %0: vector<4xf32>
}
-// -----
+// CHECK-LABEL: func @composite_construct_coopmatrix_khr
+func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
+ // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+ %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+ return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+}
-func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
+// CHECK-LABEL: func @composite_construct_coopmatrix_nv
+func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
@@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
// -----
+func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
+ !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
+ // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+ %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+ return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+}
+
+// -----
+
+func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
+ !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
+ // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
+ %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
+ return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
+}
+
+// -----
+
func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66399
More information about the Mlir-commits mailing list