[Mlir-commits] [mlir] 25cbfa0 - [mlir][spirv] Allow mixed type cooperative matrix muladd

Thomas Raoux llvmlistbot at llvm.org
Thu Jun 18 13:07:22 PDT 2020


Author: Thomas Raoux
Date: 2020-06-18T13:05:09-07:00
New Revision: 25cbfa0788846c7ec06affb9c0e0d4a87b510c02

URL: https://github.com/llvm/llvm-project/commit/25cbfa0788846c7ec06affb9c0e0d4a87b510c02
DIFF: https://github.com/llvm/llvm-project/commit/25cbfa0788846c7ec06affb9c0e0d4a87b510c02.diff

LOG: [mlir][spirv] Allow mixed type cooperative matrix muladd

muladd can have different element types for lhs/rhs and acc/destination. Change
the verifier and update the test to use a supported example.

Differential Revision: https://reviews.llvm.org/D82042

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index efe685873bd7..87456f000edc 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2753,8 +2753,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
       typeR.getScope() != typeB.getScope() ||
       typeR.getScope() != typeC.getScope())
     return op.emitOpError("matrix scope must match");
-  if (typeR.getElementType() != typeA.getElementType() ||
-      typeR.getElementType() != typeB.getElementType() ||
+  if (typeA.getElementType() != typeB.getElementType() ||
       typeR.getElementType() != typeC.getElementType())
     return op.emitOpError("matrix element type must match");
   return success();

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index 51c709067f6f..a2dafaddfa21 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -37,9 +37,9 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
 }
 
 // CHECK-LABEL: @cooperative_matrix_muladd
-spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
-  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x32xi8, Subgroup>, %b : !spv.coopmatrix<32x8xi8, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
   spv.Return
 }
 


        


More information about the Mlir-commits mailing list