[Mlir-commits] [mlir] Add Lowerings for GPU WMMA F16/F32 ops to ROCDL dialect (PR #69357)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Oct 23 08:03:29 PDT 2023


================
@@ -0,0 +1,180 @@
+//===-------- WmmaOpsToAMDGPU.cpp - GPU WMMA ops to AMD GPU lowering ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// AMD GPU Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Verifies that every value in `operands` has a vector type.
+///
+/// Returns success when all operands are vectors (including the case of no
+/// operands at all). Otherwise a match failure is reported for `op` through
+/// `rewriter` and the resulting failure is returned.
+static LogicalResult areAllVectorTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+  for (Value operand : operands) {
+    // Bail out on the first operand that is not a vector.
+    if (!isa<mlir::VectorType>(operand.getType()))
+      return rewriter.notifyMatchFailure(
+          op, "cannot convert if operands aren't of Vector type.");
+  }
+
+  return success();
+}
+
+/// Create a WMMA compute intrinsic performing the multiply-add operation:
+///
+///  `cOp` = `aOp` * `bOp` + `cOp`
+///
+/// and return the generated op in `computeOp`.
+static LogicalResult createWMMAComputeIntrinsic(Value aOp, Value bOp, Value cOp,
+                                                Location loc, bool opSelect,
+                                                PatternRewriter &rewriter,
+                                                Value &computeOp) {
+  Type aType = aOp.getType();
+  Type bType = bOp.getType();
+  Type cType = cOp.getType();
+
+  // All the intrinsics present currently operate on vector types.
+  auto checkVecType = [](Value value, StringRef op) {
+    if (!isa<VectorType>(value.getType())) {
+      return mlir::emitError(value.getLoc(), op + "should be of vector type");
+    }
+    return InFlightDiagnostic();
----------------
krzysz00 wrote:

`return success();`

https://github.com/llvm/llvm-project/pull/69357


More information about the Mlir-commits mailing list