[Mlir-commits] [mlir] Add Lowerings for GPU WMMA F16/F32 ops to ROCDL dialect (PR #69357)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Oct 23 08:03:29 PDT 2023
================
@@ -0,0 +1,512 @@
+//===--------- WmmaOpsToROCDL.cpp - GPU WMMA ops to ROCDL lowering --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// ROCDL Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Checks that every value in `operands` has an LLVM-compatible type. The
+/// types are expected to be converted by the `LLVMTypeConverter` before the op
+/// is actually lowered. If the type of any operand is not already converted, it
+/// hints at a missing type conversion, and a match failure is reported on `op`.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter) {
+ if (!llvm::all_of(operands, [](Value value) {
+ return LLVM::isCompatibleType(value.getType());
+ })) {
+ return rewriter.notifyMatchFailure(
+ op, "cannot convert if operands aren't of LLVM type.");
+ }
+
+ return success();
+}
+
+/// Maps `operandName` ("AOp", "BOp" or "COp") to the matching `ROCDLWMMAFrag` (a, b or c); other names are unreachable.
+static ROCDL::ROCDLWMMAFrag convertOperand(StringRef operandName) {
+ if (operandName.equals("AOp"))
+ return ROCDL::ROCDLWMMAFrag::a;
+ if (operandName.equals("BOp"))
+ return ROCDL::ROCDLWMMAFrag::b;
+ if (operandName.equals("COp"))
+ return ROCDL::ROCDLWMMAFrag::c;
+ llvm_unreachable("Unknown operand name");
+}
+
+/// Generate load ops for `AOp` or `BOp`. `dataPtr` is the base address starting
+/// from which values will be loaded. `laneId` is the lane ID of the thread loading the
+/// values. `vecType` is the vector type of the values that will be loaded. The
+/// loaded values are returned in `loadedValues`. The address for loading the
+/// values is generated in the following manner:
+///
+/// wrappedLaneId = laneId % 16
+/// for i in vectorSize {
+/// loadedValues[i] = dataPtr + ((wrappedLaneId * leadingDim) + i);
+/// }
+static void generateAbLoadOpsVecFirst(Location loc, Value dataPtr, Value laneId,
+ Value leadingDim, VectorType vecType,
+ PatternRewriter &rewriter,
+ Value &loadedValues) {
+ // We wrap the laneId to 16 because of matrix replication in RDNA 3.
+ Value wrapSize = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+ /*value=*/16);
+ mlir::TypedAttr x;
+ Value wrappedLaneId = rewriter.create<LLVM::SRemOp>(loc, laneId, wrapSize);
+ loadedValues = rewriter.create<LLVM::UndefOp>(loc, vecType);
+ Value laneIdLdm =
+ rewriter.create<LLVM::MulOp>(loc, wrappedLaneId, leadingDim);
+ for (unsigned i = 0; i < vecType.getNumElements(); ++i) {
+ Value iter = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+ /*value=*/i);
+ Value curInx = rewriter.create<LLVM::AddOp>(loc, laneIdLdm, iter);
----------------
krzysz00 wrote:
Would it make sense to move this part out of the loop, so the GEP in the loop just adds a constant?
https://github.com/llvm/llvm-project/pull/69357
More information about the Mlir-commits
mailing list