[Mlir-commits] [mlir] d88cc07 - [mlir][gpuTonvvm] Remove hardcoded values in MMAType to llvm struct
llvmlistbot at llvm.org
Tue Nov 2 08:13:00 PDT 2021
Author: thomasraoux
Date: 2021-11-02T08:12:27-07:00
New Revision: d88cc079434da6a2d18d5bb51643671195aa7ce1
URL: https://github.com/llvm/llvm-project/commit/d88cc079434da6a2d18d5bb51643671195aa7ce1
DIFF: https://github.com/llvm/llvm-project/commit/d88cc079434da6a2d18d5bb51643671195aa7ce1.diff
LOG: [mlir][gpuTonvvm] Remove hardcoded values in MMAType to llvm struct
Also relax the types allowed in GPU WMMA ops: the A and B operands of gpu.subgroup_mma_compute now accept F32 in addition to F16.
Differential Revision: https://reviews.llvm.org/D112969
Added:
Modified:
mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 267a362b372ce..229ccd0e95da8 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -9,6 +9,7 @@
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include <memory>
namespace mlir {
@@ -22,8 +23,11 @@ class OperationPass;
namespace gpu {
class GPUModuleOp;
+class MMAMatrixType;
}
+LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
+
/// Configure target to convert from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 3e62833934623..3f1ad84278cb0 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1068,8 +1068,8 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
```
}];
- let arguments = (ins Arg<MMAMatrixOf<[F16]>>:$opA,
- Arg<MMAMatrixOf<[F16]>>:$opB,
+ let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
+ Arg<MMAMatrixOf<[F16, F32]>>:$opB,
Arg<MMAMatrixOf<[F16, F32]>>:$opC);
let results = (outs GPU_MMAMatrix:$res);
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 3ac7ee4e2d204..d0589f1fc35bb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -134,34 +134,8 @@ struct LowerGpuOpsToNVVMOpsPass
// Lowering for MMAMatrixType.
converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
- // The number of items in structToReturn are dependent on the the dataType
- // and the MMA operand that this operation is associated with.
- llvm::DenseMap<StringRef, int64_t> numElemsPerThreadF16,
- numElemsPerThreadF32;
- numElemsPerThreadF16["AOp"] = 8;
- numElemsPerThreadF16["BOp"] = 8;
- numElemsPerThreadF16["COp"] = 4;
- numElemsPerThreadF32["AOp"] = 8;
- numElemsPerThreadF32["BOp"] = 8;
- numElemsPerThreadF32["COp"] = 8;
- Type structToReturn;
- if (type.getElementType().isF16()) {
- // Number of f16's in 32-bit.
- unsigned vecSize = 2;
- Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
- unsigned size = numElemsPerThreadF16[type.getOperand()];
- SmallVector<Type> elements(size, vec);
- structToReturn =
- LLVM::LLVMStructType::getLiteral(&getContext(), elements);
- } else if (type.getElementType().isF32()) {
- unsigned size = numElemsPerThreadF32[type.getOperand()];
- SmallVector<Type> elements(size, FloatType::getF32(&getContext()));
- structToReturn =
- LLVM::LLVMStructType::getLiteral(&getContext(), elements);
- }
- return structToReturn;
+ return convertMMAToLLVMType(type);
});
-
RewritePatternSet patterns(m.getContext());
RewritePatternSet llvmPatterns(m.getContext());
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 878d0cf22fd8f..b0bf94b7f8066 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -59,16 +60,6 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
llvm_unreachable("Unsupported type");
}
-/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
- NVVM::MMAFrag frag = convertOperand(type.getOperand());
- NVVM::MMATypes eltType = getElementType(type);
- std::pair<Type, unsigned> typeInfo =
- inferMMAType(eltType, frag, type.getContext());
- return LLVM::LLVMStructType::getLiteral(
- type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
-}
-
/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits code that is necessary to store the data in the destination memref
@@ -433,6 +424,17 @@ struct WmmaElementwiseOpToNVVMLowering
} // anonymous namespace
namespace mlir {
+
+/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
+LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
+ NVVM::MMAFrag frag = convertOperand(type.getOperand());
+ NVVM::MMATypes eltType = getElementType(type);
+ std::pair<Type, unsigned> typeInfo =
+ inferMMAType(eltType, frag, type.getContext());
+ return LLVM::LLVMStructType::getLiteral(
+ type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
+}
+
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
More information about the Mlir-commits mailing list