[Mlir-commits] [mlir] b05b897 - [mlir][gpu][spirv] Verify elementwise op type as mulf when converting to spirv.MatrixTimesScalar
Thomas Raoux
llvmlistbot at llvm.org
Wed Dec 14 19:16:17 PST 2022
Author: Quinn Dawkins
Date: 2022-12-15T03:15:04Z
New Revision: b05b8970d8b35e0ffcdab1a77e8be836c0aaae70
URL: https://github.com/llvm/llvm-project/commit/b05b8970d8b35e0ffcdab1a77e8be836c0aaae70
DIFF: https://github.com/llvm/llvm-project/commit/b05b8970d8b35e0ffcdab1a77e8be836c0aaae70.diff
LOG: [mlir][gpu][spirv] Verify elementwise op type as mulf when converting to spirv.MatrixTimesScalar
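The lowering of gpu.subgroup_mma_elementwise ops that take a gpu.subgroup_mma_constant_matrix operand to spirv.MatrixTimesScalar did not check that the op type was a multiplication (mulf). As a result, it incorrectly converted other elementwise operations with a scalar splat operand (e.g. addf) into a matrix-times-scalar product.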
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D140081
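To illustrate the bug, here is a minimal, self-contained C++ model of the pattern's match logic before and after the fix. It is not the actual MLIR code: ElementwiseKind, MatchResult, matchScalarMulBeforeFix and matchScalarMulAfterFix are hypothetical names, and std::nullopt stands in for the pattern returning failure().

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for gpu::MMAElementwiseOp (only a subset of kinds).
enum class ElementwiseKind { ADDF, SUBF, MULF, DIVF };

// Replacement op name on success; std::nullopt models `return failure();`.
using MatchResult = std::optional<std::string>;

// Pre-fix behavior: any elementwise op with a splat operand matched, so
// addf(A, splat(s)) was wrongly rewritten as A * s.
MatchResult matchScalarMulBeforeFix(ElementwiseKind /*kind*/, bool hasSplatOperand) {
  if (!hasSplatOperand)
    return std::nullopt;
  return "spirv.MatrixTimesScalar";
}

// Post-fix behavior: the op kind must be MULF before the splat check runs.
MatchResult matchScalarMulAfterFix(ElementwiseKind kind, bool hasSplatOperand) {
  if (kind != ElementwiseKind::MULF)
    return std::nullopt;  // mirrors the newly added early `return failure();`
  if (!hasSplatOperand)
    return std::nullopt;
  return "spirv.MatrixTimesScalar";
}
```

With the guard in place, an addf with a splat operand no longer matches, which leaves it to other lowering patterns.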
Added:
Modified:
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 81e7d5c294371..f7e135620133c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -208,6 +208,9 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
return failure();
}
+ if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
+ return failure();
+
// Use the original operands to check whether one of the operands is a splat
// scalar value.
Value lhs = elementwiseOp.getOperands().front();
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index b544ae5a12380..829107f2625be 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -164,6 +164,27 @@ module attributes {
%C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
%D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
+ // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: %[[S:.+]]: f16
+ gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.return
}
}
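A rough sketch of why the new addf test expects spirv.FAdd: once the scalar-multiply pattern returns failure(), a generic elementwise lowering handles the op instead. This assumes the scalar-multiply pattern is tried before the generic one (modeled here as a higher benefit). Pattern, ElementwiseOp, applyFirstMatching and makePatterns are hypothetical names. Only the addf → spirv.FAdd and mulf-with-splat → spirv.MatrixTimesScalar outcomes come from the test expectations above; the mulf → spirv.FMul fallback is an assumption, and materializing the splat via spirv.CompositeConstruct is not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

enum class ElementwiseKind { ADDF, MULF };

struct ElementwiseOp {
  ElementwiseKind kind;
  bool hasSplatOperand;  // one operand comes from gpu.subgroup_mma_constant_matrix
};

// A pattern returns the replacement op name, or std::nullopt for failure().
struct Pattern {
  int benefit;
  std::function<std::optional<std::string>(const ElementwiseOp &)> match;
};

// Tries patterns from highest to lowest benefit; the first success wins.
std::optional<std::string> applyFirstMatching(std::vector<Pattern> patterns,
                                              const ElementwiseOp &op) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const Pattern &a, const Pattern &b) { return a.benefit > b.benefit; });
  for (const Pattern &p : patterns)
    if (auto result = p.match(op))
      return result;
  return std::nullopt;
}

std::vector<Pattern> makePatterns() {
  // Scalar-multiply pattern, including the kind check added by this commit.
  Pattern scalarMul{2, [](const ElementwiseOp &op) -> std::optional<std::string> {
    if (op.kind != ElementwiseKind::MULF || !op.hasSplatOperand)
      return std::nullopt;
    return "spirv.MatrixTimesScalar";
  }};
  // Generic elementwise lowering: one SPIR-V arithmetic op per kind.
  Pattern generic{1, [](const ElementwiseOp &op) -> std::optional<std::string> {
    switch (op.kind) {
      case ElementwiseKind::ADDF: return "spirv.FAdd";
      case ElementwiseKind::MULF: return "spirv.FMul";
    }
    return std::nullopt;
  }};
  return {generic, scalarMul};
}
```

Registration order does not matter in this model because patterns are sorted by benefit before matching.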