[Mlir-commits] [mlir] 2c7827d - [mlir][spirv] Add GPU subgroup MMA to spirv.MatrixTimesScalar

Lei Zhang llvmlistbot at llvm.org
Mon Dec 5 14:30:59 PST 2022


Author: Lei Zhang
Date: 2022-12-05T22:30:50Z
New Revision: 2c7827da4f5bf758a1659a0c4a2d0e7944827c42

URL: https://github.com/llvm/llvm-project/commit/2c7827da4f5bf758a1659a0c4a2d0e7944827c42
DIFF: https://github.com/llvm/llvm-project/commit/2c7827da4f5bf758a1659a0c4a2d0e7944827c42.diff

LOG: [mlir][spirv] Add GPU subgroup MMA to spirv.MatrixTimesScalar

Along the way, make the default pattern fail instead of crashing
when an elementwise op is not supported yet.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D139280
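
For illustration, here is a minimal sketch of the new lowering, adapted from
the test added below; the %A/%scalar names and the 16x16xf16 types come from
that test rather than from any additional API:

    // Before: elementwise multiply where one operand is a splat constant matrix.
    %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
    %C = gpu.subgroup_mma_elementwise mulf %A, %B
           : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
           -> !gpu.mma_matrix<16x16xf16, "COp">

    // After: the splat scalar is extracted and the multiply becomes a
    // matrix-scalar product on the cooperative matrix type.
    %C = spirv.MatrixTimesScalar %A, %scalar : !spirv.coopmatrix<16x16xf16, Subgroup>, f16

All other elementwise ops keep going through the default pattern, which now
returns failure() for unsupported ops instead of hitting llvm_unreachable.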

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 43d7f6237671..81e7d5c29437 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,42 +24,47 @@
 
 using namespace mlir;
 
-// See SPV_NV_cooperative_matrix for supported element wise ops.
-static void createElementWiseOp(ConversionPatternRewriter &builder,
+/// Creates a SPIR-V op to replace the given GPU subgroup MMA elementwise op
+/// when the elementwise op is directly supported on the cooperative matrix
+/// type. Returns false otherwise.
+///
+/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                 gpu::SubgroupMmaElementwiseOp op,
                                 spirv::CooperativeMatrixNVType coopType,
                                 ValueRange operands) {
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::ADDI:
     builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBF:
     builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::SUBI:
     builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVF:
     builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVS:
     builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::DIVU:
     builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATEF:
     builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
-    return;
+    return true;
   case gpu::MMAElementwiseOp::NEGATES:
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
-    return;
+    return true;
   default:
-    llvm_unreachable("unknown op");
+    break;
   }
+  return false;
 }
 
 namespace {
@@ -163,13 +168,14 @@ struct WmmaConstantOpToSPIRVLowering
   }
 };
 
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
-struct WmmaElementwiseOpToSPIRVLowering
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering
     : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // All operands should be of cooperative matrix types.
@@ -178,9 +184,58 @@ struct WmmaElementwiseOpToSPIRVLowering
         return failure();
     }
     auto coopType = convertMMAToSPIRVType(
-        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
-    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
-                        adaptor.getOperands());
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
+                                       adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the matrix-times-scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+
+    // Use the original operands to check whether one of the operands is a splat
+    // scalar value.
+    Value lhs = elementwiseOp.getOperands().front();
+    Value rhs = elementwiseOp.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return failure();
+
+    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
+    Value scalar = nullptr;
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc)
+      return failure();
+    assert(cc.getConstituents().size() == 1);
+    scalar = cc.getConstituents().front();
+
+    auto coopType = convertMMAToSPIRVType(
+        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        elementwiseOp, coopType, ValueRange{matrix, scalar});
     return success();
   }
 };
@@ -198,8 +253,11 @@ mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
 
 void mlir::populateGpuWMMAToSPIRVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
   patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
                WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVLowering>(converter,
-                                                 patterns.getContext());
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  // Give the following patterns higher benefit to prevail over the default one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }

diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index c4dc7458bc31..b544ae5a1238 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -4,10 +4,8 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
     gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
@@ -27,7 +25,6 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
     // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
@@ -50,11 +47,9 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
@@ -74,7 +69,6 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
     // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
@@ -98,12 +92,10 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
@@ -120,7 +112,6 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
     // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
     gpu.func @gpu_wmma_constant_op() kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
@@ -140,11 +131,10 @@ module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
-    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
@@ -157,3 +147,24 @@ module attributes {
     }
   }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME:    %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.return
+    }
+  }
+}
