[Mlir-commits] [mlir] b05b897 - [mlir][gpu][spirv] Verify elementwise op type as mulf when converting to spirv.MatrixTimesScalar

Thomas Raoux llvmlistbot at llvm.org
Wed Dec 14 19:16:17 PST 2022


Author: Quinn Dawkins
Date: 2022-12-15T03:15:04Z
New Revision: b05b8970d8b35e0ffcdab1a77e8be836c0aaae70

URL: https://github.com/llvm/llvm-project/commit/b05b8970d8b35e0ffcdab1a77e8be836c0aaae70
DIFF: https://github.com/llvm/llvm-project/commit/b05b8970d8b35e0ffcdab1a77e8be836c0aaae70.diff

LOG: [mlir][gpu][spirv] Verify elementwise op type as mulf when converting to spirv.MatrixTimesScalar

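The conversion of gpu.subgroup_mma_elementwise with a splat operand produced by gpu.subgroup_mma_constant_matrix to spirv.MatrixTimesScalar didn't check that the elementwise op type was a multiplication (mulf), and thus would incorrectly convert other elementwise scalar operations (e.g. addf) as well.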

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D140081

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 81e7d5c294371..f7e135620133c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -208,6 +208,9 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
         return failure();
     }
 
+    if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
+      return failure();
+
     // Use the original operands to check whether one of the operands is a splat
     // scalar value.
     Value lhs = elementwiseOp.getOperands().front();

diff  --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index b544ae5a12380..829107f2625be 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -164,6 +164,27 @@ module attributes {
       %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
       %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
+    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME:    %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       gpu.return
     }
   }


        


More information about the Mlir-commits mailing list