[Mlir-commits] [mlir] Add Lowerings for GPU WMMA F16/F32 ops to ROCDL dialect (PR #69357)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Oct 23 08:03:29 PDT 2023


================
@@ -0,0 +1,512 @@
+//===--------- WmmaOpsToROCDL.cpp - GPU WMMA ops to ROCDL lowering --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// ROCDL Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Verifies that every value in `operands` already has an LLVM-compatible
+/// type. Operand types are expected to have been rewritten by the
+/// `LLVMTypeConverter` before the op itself is lowered, so encountering a
+/// non-LLVM type hints at a missing type conversion. In that case a match
+/// failure is reported on `op` through `rewriter`; otherwise success is
+/// returned.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  auto hasNonLLVMType = [](Value operand) {
+    return !LLVM::isCompatibleType(operand.getType());
+  };
+
+  if (llvm::any_of(operands, hasNonLLVMType))
+    return rewriter.notifyMatchFailure(
+        op, "cannot convert if operands aren't of LLVM type.");
+
+  return success();
+}
+
+/// Maps a WMMA operand name to the corresponding `ROCDL::ROCDLWMMAFrag`
+/// enumerator: "AOp" -> `a`, "BOp" -> `b`, "COp" -> `c`. Any other name is
+/// treated as a programming error (`llvm_unreachable`).
+static ROCDL::ROCDLWMMAFrag convertOperand(StringRef operandName) {
+  if (operandName == "AOp")
+    return ROCDL::ROCDLWMMAFrag::a;
+  if (operandName == "BOp")
+    return ROCDL::ROCDLWMMAFrag::b;
+  if (operandName == "COp")
+    return ROCDL::ROCDLWMMAFrag::c;
+  llvm_unreachable("Unknown operand name");
+}
+
+/// Generate load ops for `AOp` or `BOp`. `dataPtr` is the base address starting
+/// from which values will be loaded. `laneId` lane ID of the thread loading the
+/// values. `vecType` is the vector type of the values that will be loaded. The
+/// loaded values are returned in `loadedValues`. The address for loading the
+/// values is generated in the following manner:
+///
+/// wrappedLaneId = laneId % 16
+/// for i in vectorSize {
+///   loadedValues[i] = dataPtr + ((wrappedLaneId * leadingDim) + i);
+/// }
+static void generateAbLoadOpsVecFirst(Location loc, Value dataPtr, Value laneId,
+                                      Value leadingDim, VectorType vecType,
+                                      PatternRewriter &rewriter,
+                                      Value &loadedValues) {
+  // We wrap the laneId to 16 because of matrix replication in RDNA 3.
+  Value wrapSize = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/16);
+  mlir::TypedAttr x;
+  Value wrappedLaneId = rewriter.create<LLVM::SRemOp>(loc, laneId, wrapSize);
+  loadedValues = rewriter.create<LLVM::UndefOp>(loc, vecType);
+  Value laneIdLdm =
+      rewriter.create<LLVM::MulOp>(loc, wrappedLaneId, leadingDim);
+  for (unsigned i = 0; i < vecType.getNumElements(); ++i) {
+    Value iter = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   /*value=*/i);
+    Value curInx = rewriter.create<LLVM::AddOp>(loc, laneIdLdm, iter);
----------------
krzysz00 wrote:

Would it make sense to move this part out of the loop, so the GEP in the loop just adds a constant?

https://github.com/llvm/llvm-project/pull/69357


More information about the Mlir-commits mailing list