[Mlir-commits] [mlir] 2c8afe1 - [mlir][gpu] Add support for f16 when lowering to nvvm intrinsics

Stephan Herhut llvmlistbot at llvm.org
Tue Jun 9 10:34:36 PDT 2020


Author: Stephan Herhut
Date: 2020-06-09T19:33:45+02:00
New Revision: 2c8afe1298e5f471a5736757b1cd2a708dd91ec9

URL: https://github.com/llvm/llvm-project/commit/2c8afe1298e5f471a5736757b1cd2a708dd91ec9
DIFF: https://github.com/llvm/llvm-project/commit/2c8afe1298e5f471a5736757b1cd2a708dd91ec9.diff

LOG: [mlir][gpu] Add support for f16 when lowering to nvvm intrinsics

Summary:
The NVVM target only provides implementations for tanh etc. on f32 and
f64 operands. To also support f16, we now insert operations to extend to f32
and truncate back to f16 around the intrinsic call.

Differential Revision: https://reviews.llvm.org/D81473

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index c7bbb6d989c8..58b5f1dbc975 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -20,6 +20,9 @@ namespace mlir {
 /// depending on the element type that Op operates upon. The function
 /// declaration is added in case it was not added before.
 ///
+/// If the input values are of f16 type, each value is first cast to f32, the
+/// function is called, and the result is then truncated back to f16.
+///
 /// Example with NVVM:
 ///   %exp_f32 = std.exp %arg_f32 : f32
 ///
@@ -44,21 +47,48 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
-                              .template cast<LLVM::LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, operands);
-    StringRef funcName = getFunctionName(resultType);
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected op with same operand and result types");
+
+    SmallVector<Value, 1> castedOperands;
+    for (Value operand : operands)
+      castedOperands.push_back(maybeCast(operand, rewriter));
+
+    LLVMType resultType =
+        castedOperands.front().getType().cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
     if (funcName.empty())
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
-    rewriter.replaceOp(op, {callOp.getResult(0)});
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+        castedOperands);
+
+    if (resultType == operands.front().getType()) {
+      rewriter.replaceOp(op, {callOp.getResult(0)});
+      return success();
+    }
+
+    Value truncated = rewriter.create<LLVM::FPTruncOp>(
+        op->getLoc(), operands.front().getType(), callOp.getResult(0));
+    rewriter.replaceOp(op, {truncated});
     return success();
   }
 
 private:
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    if (!type.isHalfTy())
+      return operand;
+
+    return rewriter.create<LLVM::FPExtOp>(
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand);
+  }
+
   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
                                  ArrayRef<Value> operands) const {
     using LLVM::LLVMType;

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index f05c9af1b30f..925615c0674e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -219,12 +219,16 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = std.tanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return %result32, %result64 : f32, f64
+    std.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 


        


More information about the Mlir-commits mailing list