[Mlir-commits] [mlir] a011943 - [mlir][gpu] Support arith.extf in subgroup MMA elementwise ops

Lei Zhang llvmlistbot at llvm.org
Tue Aug 1 21:16:38 PDT 2023


Author: Lei Zhang
Date: 2023-08-01T21:12:37-07:00
New Revision: a01194377c8ae3178a4dd22dbc420caed9bba21d

URL: https://github.com/llvm/llvm-project/commit/a01194377c8ae3178a4dd22dbc420caed9bba21d
DIFF: https://github.com/llvm/llvm-project/commit/a01194377c8ae3178a4dd22dbc420caed9bba21d.diff

LOG: [mlir][gpu] Support arith.extf in subgroup MMA elementwise ops

This commit adds support for arith.extf in the supported list of
elementwise ops for subgroup MMA ops, and enables lowering to
SPIR-V.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D156847

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e3cd604fcc30ce..3b20ba2b46e351 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1472,6 +1472,7 @@ def GPU_ElementwiseOpDivS : I32EnumAttrCase<"DIVS", 9, "divs">;
 def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
 def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
 def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
+def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
 
 def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix", [
@@ -1487,7 +1488,8 @@ def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
     GPU_ElementwiseOpDivS,
     GPU_ElementwiseOpDivU,
     GPU_ElementwiseOpNEGF,
-    GPU_ElementwiseOpNEGS
+    GPU_ElementwiseOpNEGS,
+    GPU_ElementwiseOpEXTF
   ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index d64fa6ac4ece22..57e21530b9da76 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -61,6 +61,9 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   case gpu::MMAElementwiseOp::NEGATES:
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
     return true;
+  case gpu::MMAElementwiseOp::EXTF:
+    builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
+    return true;
   default:
     break;
   }

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 409e9365a9f207..2c18de1c5b662e 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -214,6 +214,8 @@ static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
   });
 }
 
+static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
+
 /// Return the MMA elementwise enum associated with `op` if it is supported.
 /// Return `std::nullopt` otherwise.
 static std::optional<gpu::MMAElementwiseOp>
@@ -242,6 +244,8 @@ convertElementwiseOpToMMA(Operation *op) {
     return gpu::MMAElementwiseOp::DIVU;
   if (isa<arith::NegFOp>(op))
     return gpu::MMAElementwiseOp::NEGATEF;
+  if (isa<arith::ExtFOp>(op))
+    return gpu::MMAElementwiseOp::EXTF;
   return std::nullopt;
 }
 
@@ -297,6 +301,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
     return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
+  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
+    return fpExtendSupportsMMAMatrixType(fpExtend);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -1203,8 +1209,17 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
       return rewriter.notifyMatchFailure(op, "no mapping");
     matrixOperands.push_back(it->second);
   }
+  auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
+  if (opType == gpu::MMAElementwiseOp::EXTF) {
+    // The floating point extension case has a different result type.
+    auto vectorType = op->getResultTypes()[0].cast<VectorType>();
+    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
+                                         vectorType.getElementType(),
+                                         resultType.getOperand());
+  }
+
   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
-      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
+      op->getLoc(), resultType, matrixOperands, opType);
   valueMapping[op->getResult(0)] = newOp;
   return success();
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index 12b6a2eb94268c..a53eca65fc9869 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -142,6 +142,8 @@ module attributes {
       %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK:  {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup>
+      %F = gpu.subgroup_mma_elementwise extf %E : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
       // CHECK: spirv.Return
       gpu.return
     }

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 08f7e12cf55d98..fa9fff2dad6649 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -437,4 +437,26 @@ func.func @matmul_mixed_signedness_int8(%arg0: memref<16x32xi8>, %arg1: memref<1
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x32xi32>, vector<16x32xi32> into vector<16x16xi32>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_write
+//       CHECK:    %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+//       CHECK:    %[[EXT:.+]] = gpu.subgroup_mma_elementwise  extf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+//       CHECK:    gpu.subgroup_mma_store_matrix %[[EXT]]
+func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %cast = arith.extf %D : vector<16x16xf16> to vector<16x16xf32>
+  vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}


        


More information about the Mlir-commits mailing list