[Mlir-commits] [mlir] [mlir][gpu] Support extf before contract when converting to MMA ops (PR #91988)

Lei Zhang llvmlistbot at llvm.org
Mon May 13 12:09:15 PDT 2024


https://github.com/antiagainst updated https://github.com/llvm/llvm-project/pull/91988

From 6406abc20bf43d264f731af892fcdc958ee2cc7e Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Mon, 13 May 2024 09:05:44 -0700
Subject: [PATCH 1/2] [mlir][gpu] Support extf before contract when converting
 to MMA ops

This commit allows `inferFragType` to see through arith.ext op
users before reaching the contract op when figuring out the
fragment type.
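
For illustration, a minimal input that now converts, abridged from the
test added in this patch (value names and the elided maps are just for
the example):

  %A = vector.transfer_read %arg0[%c0, %c0], %cst : memref<16x16xf16>, vector<16x16xf16>
  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
  // %B and %C are read and extended the same way.
  %D = vector.contract {indexing_maps = [...], iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
         %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>

Here the transfer_read's only user is the arith.extf rather than the
vector.contract itself, so inferFragType follows the extf to reach the
contract and pick "AOp"/"BOp"/"COp" accordingly.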
---
 .../Conversion/VectorToGPU/VectorToGPU.cpp    | 14 +++++++---
 .../VectorToGPU/vector-to-mma-ops.mlir        | 27 +++++++++++++++++++
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 782cc92f83fee..ad7408bb06fc1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,6 +515,13 @@ struct CombineTransferReadOpTranspose final
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
 static const char *inferFragType(Operation *op) {
+  // We can have arith.ext ops before reaching contract ops. See through them.
+  if (op->hasOneUse()) {
+    Operation *extOp = *op->user_begin();
+    if (isa<arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp>(extOp))
+      return inferFragType(extOp);
+  }
+
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -560,13 +567,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   if (op->hasOneUse()) {
     auto *user = *op->user_begin();
     // Infer the signedness of the mma type from the integer extend.
-    bool isSignedExtend = isa<arith::ExtSIOp>(user);
-    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
       elType = IntegerType::get(
           op.getContext(), cast<IntegerType>(elType).getWidth(),
-          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
+                                    : IntegerType::Unsigned);
       mappingResult = user->getResult(0);
-      fragType = inferFragType(user);
     }
   }
   gpu::MMAMatrixType type =
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 962ed7de584a2..8526ff1392599 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
 }
 
 // -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+//       CHECK:    %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//       CHECK:    %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:    %[[AE:.+]] = gpu.subgroup_mma_elementwise  extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+//       CHECK:    %[[CE:.+]] = gpu.subgroup_mma_elementwise  extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+//       CHECK:    %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//       CHECK:    %[[BE:.+]] = gpu.subgroup_mma_elementwise  extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+//       CHECK:    gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+  %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                        %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}

From 269e0082ac34c6929d8fa33d797f69a2bf1d78df Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Mon, 13 May 2024 12:06:23 -0700
Subject: [PATCH 2/2] Address comments

---
 mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index ad7408bb06fc1..332f0a2eecfcf 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,11 +515,12 @@ struct CombineTransferReadOpTranspose final
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
 static const char *inferFragType(Operation *op) {
-  // We can have arith.ext ops before reaching contract ops. See through them.
+  // We can have arith.ext ops before reaching contract ops. See through them
+  // and other kinds of elementwise ops.
   if (op->hasOneUse()) {
-    Operation *extOp = *op->user_begin();
-    if (isa<arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp>(extOp))
-      return inferFragType(extOp);
+    Operation *userOp = *op->user_begin();
+    if (userOp->hasTrait<OpTrait::Elementwise>())
+      return inferFragType(userOp);
   }
 
   for (Operation *users : op->getUsers()) {


