[Mlir-commits] [mlir] [mlir][gpu] Create splat constant in MMA elementwise min/max operation (PR #88393)

Adam Siemieniuk llvmlistbot at llvm.org
Tue May 14 08:24:15 PDT 2024


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/88393

>From 367fbe3cd4e52e42f09b3cf936fbf5023ef8e405 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 11 Apr 2024 15:33:18 +0200
Subject: [PATCH] [mlir][gpu] Create splat constant in MMA elementwise min/max
 operation

This PR ensures that a splat NaN constant is created in MMA elementwise
min/max when the operands are of vector type.

The change fixes a related runtime error caused by the scalar TypedAttr
not matching the vector type of the constant.
---
 .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp    |  9 +--
 .../GPUToNVVM/wmma-ops-to-nvvm.mlir           | 55 +++++++++++++++++--
 2 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 775dd1e609037..359280f2e92b7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -310,10 +310,11 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
   Value isNan = builder.create<LLVM::FCmpOp>(
       loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
-  Value nan = builder.create<LLVM::ConstantOp>(
-      loc, lhs.getType(),
-      builder.getFloatAttr(floatType,
-                           APFloat::getQNaN(floatType.getFloatSemantics())));
+  auto qnan = APFloat::getQNaN(floatType.getFloatSemantics());
+  TypedAttr nanAttr = builder.getFloatAttr(floatType, qnan);
+  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
+    nanAttr = SplatElementsAttr::get(vecType, qnan);
+  Value nan = builder.create<LLVM::ConstantOp>(loc, lhs.getType(), nanAttr);
   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
 }
 
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index 9dec666bf4b3d..bdcfd9a22073c 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -287,7 +287,7 @@ gpu.module @test_module {
 
 gpu.module @test_module {
 
-// CHECK-LABEL: func @gpu_wmma_elementwise
+// CHECK-LABEL: func @gpu_wmma_elementwise_f16
 //       CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -312,7 +312,7 @@ gpu.module @test_module {
 //       CHECK: %[[CMP0:.*]] = llvm.fcmp "ogt" %[[A0]], %[[B0]] : vector<2xf16>
 //       CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[A0]], %[[B0]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[CMP1:.*]] = llvm.fcmp "uno" %[[A0]], %[[B0]] : vector<2xf16>
-//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
+//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(dense<0x7E00> : vector<2xf16>) : vector<2xf16>
 //       CHECK: %[[C0:.*]] = llvm.select %[[CMP1]], %[[NAN]], %[[SEL0]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -320,7 +320,7 @@ gpu.module @test_module {
 //       CHECK: %[[CMP2:.*]] = llvm.fcmp "ogt" %[[A1]], %[[B1]] : vector<2xf16>
 //       CHECK: %[[SEL1:.*]] = llvm.select %[[CMP2]], %[[A1]], %[[B1]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[CMP3:.*]] = llvm.fcmp "uno" %[[A1]], %[[B1]] : vector<2xf16>
-//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
+//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(dense<0x7E00> : vector<2xf16>) : vector<2xf16>
 //       CHECK: %[[C1:.*]] = llvm.select %[[CMP3]], %[[NAN]], %[[SEL1]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -328,7 +328,7 @@ gpu.module @test_module {
 //       CHECK: %[[CMP4:.*]] = llvm.fcmp "ogt" %[[A2]], %[[B2]] : vector<2xf16>
 //       CHECK: %[[SEL2:.*]] = llvm.select %[[CMP4]], %[[A2]], %[[B2]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[CMP5:.*]] = llvm.fcmp "uno" %[[A2]], %[[B2]] : vector<2xf16>
-//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
+//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(dense<0x7E00> : vector<2xf16>) : vector<2xf16>
 //       CHECK: %[[C2:.*]] = llvm.select %[[CMP5]], %[[NAN]], %[[SEL2]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 //       CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -336,14 +336,57 @@ gpu.module @test_module {
 //       CHECK: %[[CMP6:.*]] = llvm.fcmp "ogt" %[[A3]], %[[B3]] : vector<2xf16>
 //       CHECK: %[[SEL3:.*]] = llvm.select %[[CMP6]], %[[A3]], %[[B3]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[CMP7:.*]] = llvm.fcmp "uno" %[[A3]], %[[B3]] : vector<2xf16>
-//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7E00 : f16) : vector<2xf16>
+//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(dense<0x7E00> : vector<2xf16>) : vector<2xf16>
 //       CHECK: %[[C3:.*]] = llvm.select %[[CMP7]], %[[NAN]], %[[SEL3]] : vector<2xi1>, vector<2xf16>
 //       CHECK: %[[M5:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 
 //       CHECK: llvm.return %[[M5]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  func.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">)  ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+  func.func @gpu_wmma_elementwise_f16(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">)  ->(!gpu.mma_matrix<16x16xf16, "COp">) {
     %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
     %D = gpu.subgroup_mma_elementwise maxf %C, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
     return %D : !gpu.mma_matrix<16x16xf16, "COp">
   }
 }
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_elementwise_f32
+//       CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]]  : f32
+//       CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+
+//       CHECK: %[[A7:.*]] = llvm.extractvalue %{{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[B7:.*]] = llvm.extractvalue %{{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[C7:.*]] = llvm.fadd %[[A7]], %[[B7]]  : f32
+//       CHECK: %[[M7:.*]] = llvm.insertvalue %[[C7]], {{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+
+//       CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[CMP0:.*]] = llvm.fcmp "ogt" %[[A0]], %[[B0]] : f32
+//       CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[A0]], %[[B0]] : i1, f32
+//       CHECK: %[[CMP1:.*]] = llvm.fcmp "uno" %[[A0]], %[[B0]] : f32
+//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+//       CHECK: %[[C0:.*]] = llvm.select %[[CMP1]], %[[NAN]], %[[SEL0]] : i1, f32
+//       CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+
+//       CHECK: %[[A7:.*]] = llvm.extractvalue %{{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[B7:.*]] = llvm.extractvalue %{{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+//       CHECK: %[[CMP14:.*]] = llvm.fcmp "ogt" %[[A7]], %[[B7]] : f32
+//       CHECK: %[[SEL7:.*]] = llvm.select %[[CMP14]], %[[A7]], %[[B7]] : i1, f32
+//       CHECK: %[[CMP15:.*]] = llvm.fcmp "uno" %[[A7]], %[[B7]] : f32
+//       CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+//       CHECK: %[[C7:.*]] = llvm.select %[[CMP15]], %[[NAN]], %[[SEL7]] : i1, f32
+//       CHECK: %[[M8:.*]] = llvm.insertvalue %[[C7]], {{.*}}[7] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+
+//       CHECK: llvm.return %[[M8]] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  func.func @gpu_wmma_elementwise_f32(%A : !gpu.mma_matrix<16x16xf32, "COp">, %B : !gpu.mma_matrix<16x16xf32, "COp">)  ->(!gpu.mma_matrix<16x16xf32, "COp">) {
+    %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    %D = gpu.subgroup_mma_elementwise maxf %C, %B : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    return %D : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}



More information about the Mlir-commits mailing list