[Mlir-commits] [mlir] d88cc07 - [mlir][gpuTonvvm] Remove hardcoded values in MMAType to llvm struct

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 2 08:13:00 PDT 2021


Author: thomasraoux
Date: 2021-11-02T08:12:27-07:00
New Revision: d88cc079434da6a2d18d5bb51643671195aa7ce1

URL: https://github.com/llvm/llvm-project/commit/d88cc079434da6a2d18d5bb51643671195aa7ce1
DIFF: https://github.com/llvm/llvm-project/commit/d88cc079434da6a2d18d5bb51643671195aa7ce1.diff

LOG: [mlir][gpuTonvvm] Remove hardcoded values in MMAType to llvm struct

Also relax the types allowed in GPU wmma ops

Differential Revision: https://reviews.llvm.org/D112969

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 267a362b372ce..229ccd0e95da8 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -9,6 +9,7 @@
 #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include <memory>
 
 namespace mlir {
@@ -22,8 +23,11 @@ class OperationPass;
 
 namespace gpu {
 class GPUModuleOp;
+class MMAMatrixType;
 }
 
+LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
+
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 3e62833934623..3f1ad84278cb0 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1068,8 +1068,8 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[F16]>>:$opA,
-                  Arg<MMAMatrixOf<[F16]>>:$opB,
+  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
+                  Arg<MMAMatrixOf<[F16, F32]>>:$opB,
                   Arg<MMAMatrixOf<[F16, F32]>>:$opC);
 
   let results = (outs GPU_MMAMatrix:$res);

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 3ac7ee4e2d204..d0589f1fc35bb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -134,34 +134,8 @@ struct LowerGpuOpsToNVVMOpsPass
 
     // Lowering for MMAMatrixType.
     converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      // The number of items in structToReturn are dependent on the the dataType
-      // and the MMA operand that this operation is associated with.
-      llvm::DenseMap<StringRef, int64_t> numElemsPerThreadF16,
-          numElemsPerThreadF32;
-      numElemsPerThreadF16["AOp"] = 8;
-      numElemsPerThreadF16["BOp"] = 8;
-      numElemsPerThreadF16["COp"] = 4;
-      numElemsPerThreadF32["AOp"] = 8;
-      numElemsPerThreadF32["BOp"] = 8;
-      numElemsPerThreadF32["COp"] = 8;
-      Type structToReturn;
-      if (type.getElementType().isF16()) {
-        // Number of f16's in 32-bit.
-        unsigned vecSize = 2;
-        Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
-        unsigned size = numElemsPerThreadF16[type.getOperand()];
-        SmallVector<Type> elements(size, vec);
-        structToReturn =
-            LLVM::LLVMStructType::getLiteral(&getContext(), elements);
-      } else if (type.getElementType().isF32()) {
-        unsigned size = numElemsPerThreadF32[type.getOperand()];
-        SmallVector<Type> elements(size, FloatType::getF32(&getContext()));
-        structToReturn =
-            LLVM::LLVMStructType::getLiteral(&getContext(), elements);
-      }
-      return structToReturn;
+      return convertMMAToLLVMType(type);
     });
-
     RewritePatternSet patterns(m.getContext());
     RewritePatternSet llvmPatterns(m.getContext());
 

diff  --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 878d0cf22fd8f..b0bf94b7f8066 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -59,16 +60,6 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   llvm_unreachable("Unsupported type");
 }
 
-/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
-  NVVM::MMAFrag frag = convertOperand(type.getOperand());
-  NVVM::MMATypes eltType = getElementType(type);
-  std::pair<Type, unsigned> typeInfo =
-      inferMMAType(eltType, frag, type.getContext());
-  return LLVM::LLVMStructType::getLiteral(
-      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
-}
-
 /// This class implements the conversion of GPU MMA loadOp to wmma.load op
 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
 /// emits code that is necessary to store the data in the destination memref
@@ -433,6 +424,17 @@ struct WmmaElementwiseOpToNVVMLowering
 } // anonymous namespace
 
 namespace mlir {
+
+/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
+LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
+  NVVM::MMAFrag frag = convertOperand(type.getOperand());
+  NVVM::MMATypes eltType = getElementType(type);
+  std::pair<Type, unsigned> typeInfo =
+      inferMMAType(eltType, frag, type.getContext());
+  return LLVM::LLVMStructType::getLiteral(
+      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
+}
+
 void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns) {
   patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,


        


More information about the Mlir-commits mailing list