[Mlir-commits] [mlir] 25cbfa0 - [mlir][spirv] Allow mixed type cooperative matrix muladd
Thomas Raoux
llvmlistbot at llvm.org
Thu Jun 18 13:07:22 PDT 2020
Author: Thomas Raoux
Date: 2020-06-18T13:05:09-07:00
New Revision: 25cbfa0788846c7ec06affb9c0e0d4a87b510c02
URL: https://github.com/llvm/llvm-project/commit/25cbfa0788846c7ec06affb9c0e0d4a87b510c02
DIFF: https://github.com/llvm/llvm-project/commit/25cbfa0788846c7ec06affb9c0e0d4a87b510c02.diff
LOG: [mlir][spirv] Allow mixed type cooperative matrix muladd
muladd can have different element types for lhs/rhs and for acc/destination.
Change the verifier accordingly and update the test to use a supported example.
Differential Revision: https://reviews.llvm.org/D82042
Added:
Modified:
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index efe685873bd7..87456f000edc 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2753,8 +2753,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
typeR.getScope() != typeB.getScope() ||
typeR.getScope() != typeC.getScope())
return op.emitOpError("matrix scope must match");
- if (typeR.getElementType() != typeA.getElementType() ||
- typeR.getElementType() != typeB.getElementType() ||
+ if (typeA.getElementType() != typeB.getElementType() ||
typeR.getElementType() != typeC.getElementType())
return op.emitOpError("matrix element type must match");
return success();
diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index 51c709067f6f..a2dafaddfa21 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -37,9 +37,9 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
}
// CHECK-LABEL: @cooperative_matrix_muladd
-spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
- %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x32xi8, Subgroup>, %b : !spv.coopmatrix<32x8xi8, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+ %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}
More information about the Mlir-commits
mailing list