[Mlir-commits] [mlir] bcfc0a9 - [MLIR][GPU] Replace fdiv on fp16 with promoted (fp32) multiplication with reciprocal plus one (conditional) Newton iteration.

Christian Sigg llvmlistbot at llvm.org
Fri Jun 3 23:03:38 PDT 2022


Author: Christian Sigg
Date: 2022-06-04T08:03:29+02:00
New Revision: bcfc0a9051014437b55ab932d9aca5ecdca6776b

URL: https://github.com/llvm/llvm-project/commit/bcfc0a9051014437b55ab932d9aca5ecdca6776b
DIFF: https://github.com/llvm/llvm-project/commit/bcfc0a9051014437b55ab932d9aca5ecdca6776b.diff

LOG: [MLIR][GPU] Replace fdiv on fp16 with promoted (fp32) multiplication with reciprocal plus one (conditional) Newton iteration.

This is correct for all values, i.e. it produces the same results as promoting the division to fp32 in the NVPTX backend. But it is faster (~10% on average, sometimes more) because:

- it performs fewer Newton iterations
- it avoids the slow path for e.g. denormals
- it allows reuse of the reciprocal for multiple divisions by the same divisor

Test program:
```
#include <stdio.h>
#include "cuda_fp16.h"

// A variant of CUDA's own __hdiv that is faster than hdiv_promote below and
// doesn't suffer from the perf cliff of div.rn.f32 with 'special' values.
//
// Computes a / b as follows:
// - Widen both operands to fp32.
// - Multiply the dividend by an approximate reciprocal of the divisor
//   (PTX rcp.approx.ftz.f32).
// - Apply one Newton-Raphson correction step. The step is skipped when the
//   approximate quotient's fp32 exponent field is all zeros (zero/subnormal)
//   or all ones (infinity/NaN); in those cases the approximation is kept.
// - Convert the fp32 result back to fp16 with __float2half.
//
// The .ftz (flush-to-zero) qualifier does not alter the divisor: every fp16
// value widened to fp32 is either zero or an fp32 normal number, because
// fp16's smallest subnormal (2^-24) lies far above fp32's subnormal range.
__device__ half hdiv_newton(half a, half b) {
  float fa = __half2float(a);
  float fb = __half2float(b);

  float rcp;
  asm("{rcp.approx.ftz.f32 %0, %1;\n}" : "=f"(rcp) : "f"(fb));

  float result = fa * rcp;
  // Read the float's bit pattern with the CUDA intrinsic instead of through a
  // reinterpret_cast'ed reference, which would violate strict aliasing.
  unsigned exponent = __float_as_uint(result) & 0x7f800000u;
  if (exponent != 0u && exponent != 0x7f800000u) {
    // Residual of the current estimate, fa - fb * result, with one rounding.
    float err = __fmaf_rn(-fb, result, fa);
    // Newton step: result += rcp * err.
    result = __fmaf_rn(rcp, err, result);
  }

  return __float2half(result);
}

// Plain promoted division: widen both fp16 operands to fp32, divide with the
// ordinary fp32 '/' operator, and convert the quotient back to fp16.
// CheckCorrectness uses this as the expected result for hdiv_newton.
// Surprisingly, this is faster than CUDA's own __hdiv.
__device__ half hdiv_promote(half a, half b) {
  const float numerator = __half2float(a);
  const float denominator = __half2float(b);
  return __float2half(numerator / denominator);
}

// Approximate division: widen both fp16 operands to fp32, divide them with
// PTX div.approx.ftz.f32, and convert the quotient back to fp16.
// This is an approximation that is accurate up to 1 ulp.
__device__ half hdiv_approx(half a, half b) {
  float quotient;
  asm("{div.approx.ftz.f32 %0, %1, %2;\n}"
      : "=f"(quotient)
      : "f"(__half2float(a)), "f"(__half2float(b)));
  return __float2half(quotient);
}

// Exhaustively compares hdiv_newton against hdiv_promote, bit for bit.
//
// Launch layout: a 1-D grid with at least 65536 threads (e.g. <<<256, 256>>>).
// - Thread i uses fp16 bit pattern i as the dividend.
// - Each thread loops over all 65536 fp16 bit patterns as the divisor.
// - Threads with i >= 65536 have no bit pattern to check and exit at once.
// - Every mismatch is reported with device printf.
__global__ void CheckCorrectness() {
  unsigned i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= 65536u)
    return;
  // Bit-level conversions use the cuda_fp16 intrinsics rather than
  // reinterpret_cast punning between unrelated object types.
  half x = __ushort_as_half(static_cast<unsigned short>(i));
  for (unsigned j = 0; j < 65536u; ++j) {
    half y = __ushort_as_half(static_cast<unsigned short>(j));
    half d1 = hdiv_newton(x, y);
    half d2 = hdiv_promote(x, y);
    unsigned short s1 = __half_as_ushort(d1);
    unsigned short s2 = __half_as_ushort(d2);
    if (s1 != s2) {
      // Argument types match the conversions: %u <- unsigned, %hu <- unsigned
      // short (promoted), %f <- float (promoted to double).
      printf("%f (%u) / %f (%u), got %f (%hu), expected: %f (%hu)\n",
             __half2float(x), i, __half2float(y), j, __half2float(d1), s1,
             __half2float(d2), s2);
      //__trap();
    }
  }
}

// Device-global sink: each Profile* kernel stores its final value here,
// presumably so the compiler cannot discard the division loop as dead code.
__device__ half dst;

// Benchmark kernel for the built-in half operator/.
// Performs 10,000,000 serially dependent divisions of the running value by
// itself, with loop unrolling disabled, then stores the result in dst.
__global__ void ProfileBuiltin(half x) {
  half value = x;
  #pragma unroll 1
  for (int remaining = 10000000; remaining != 0; --remaining)
    value = value / value;
  dst = value;
}

// Benchmark kernel for hdiv_promote.
// Performs 10,000,000 serially dependent divisions of the running value by
// itself, with loop unrolling disabled, then stores the result in dst.
__global__ void ProfilePromote(half x) {
  const int iterations = 10000000;
  #pragma unroll 1
  for (int n = 0; n != iterations; ++n)
    x = hdiv_promote(x, x);
  dst = x;
}

// Benchmark kernel for hdiv_newton.
// Performs 10,000,000 serially dependent divisions of the running value by
// itself, with loop unrolling disabled, then stores the result in dst.
__global__ void ProfileNewton(half x) {
  half value = x;
  #pragma unroll 1
  for (int step = 1; step <= 10000000; ++step) {
    value = hdiv_newton(value, value);
  }
  dst = value;
}

// Benchmark kernel for hdiv_approx.
// Performs 10,000,000 serially dependent divisions of the running value by
// itself, with loop unrolling disabled, then stores the result in dst.
__global__ void ProfileApprox(half x) {
  const int kIterations = 10000000;
  #pragma unroll 1
  for (int k = kIterations; k > 0; --k) {
    x = hdiv_approx(x, x);
  }
  dst = x;
}

// Host driver: runs the exhaustive correctness check, then four single-thread
// division benchmarks starting from 1.0. The times in the trailing comments
// were presumably measured with an external profiler; this program does not
// time the kernels itself.
//
// Launch-configuration errors may not be reported by cudaDeviceSynchronize,
// so each launch is followed by a cudaGetLastError check. Returns 0 only if
// every launch succeeded and the final synchronization reported no error.
int main() {
  // Reports (and clears) any error recorded by the most recent launch.
  auto launched = [](const char *kernel) {
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
      printf("%s launch failed: %s\n", kernel, cudaGetErrorString(err));
    return err == cudaSuccess;
  };

  bool ok = true;
  // 256 * 256 = 65536 threads: one per fp16 bit pattern of the dividend.
  CheckCorrectness<<<256, 256>>>();
  ok = launched("CheckCorrectness") && ok;

  half one = __float2half(1.0f);
  ProfileBuiltin<<<1, 1>>>(one);  // 1.001s
  ok = launched("ProfileBuiltin") && ok;
  ProfilePromote<<<1, 1>>>(one);  // 0.560s
  ok = launched("ProfilePromote") && ok;
  ProfileNewton<<<1, 1>>>(one);   // 0.508s
  ok = launched("ProfileNewton") && ok;
  ProfileApprox<<<1, 1>>>(one);   // 0.304s
  ok = launched("ProfileApprox") && ok;

  // Surfaces asynchronous errors raised while the kernels were executing.
  cudaError_t status = cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(status));
  return (ok && status == cudaSuccess) ? 0 : 1;
}
```

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D126158

Added: 
    mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h
    mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
    mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
    mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
    mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f19500e1957c7..20cb2e47343c0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
 
-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                   int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>;
 
 
 //===----------------------------------------------------------------------===//
 // NVVM special register op definitions
 //===----------------------------------------------------------------------===//
 
-class NVVM_SpecialRegisterOp<string mnemonic,
-    list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -92,6 +92,16 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h
new file mode 100644
index 0000000000000..af0c4ea4e568c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h
@@ -0,0 +1,25 @@
+//===- OptimizeForNVVM.h - Optimize LLVM IR for NVVM -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+namespace NVVM {
+
+/// Creates a pass that optimizes LLVM IR for the NVVM target.
+std::unique_ptr<Pass> createOptimizeForTargetPass();
+
+} // namespace NVVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
index 868a0e5635105..39948557b55a6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index 0dc193e794f52..060822603bc20 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -16,4 +16,9 @@ def LLVMLegalizeForExport : Pass<"llvm-legalize-for-export"> {
   let constructor = "mlir::LLVM::createLegalizeForExportPass()";
 }
 
+def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> {
+  let summary = "Optimize NVVM IR";
+  let constructor = "mlir::NVVM::createOptimizeForTargetPass()";
+}
+
 #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index 3e1342dcf2c9c..e27d83e4426db 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRLLVMIRTransforms
   LegalizeForExport.cpp
+  OptimizeForNVVM.cpp
 
   DEPENDS
   MLIRLLVMPassIncGen

diff  --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
new file mode 100644
index 0000000000000..d269aa82ecec5
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -0,0 +1,97 @@
+//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
+// (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow path
+// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
+// by the same divisor.
+struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
+  using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(LLVM::FDivOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct NVVMOptimizeForTarget
+    : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<NVVM::NVVMDialect>();
+  }
+};
+} // namespace
+
+LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
+                                            PatternRewriter &rewriter) const {
+  if (!op.getType().isF16())
+    return rewriter.notifyMatchFailure(op, "not f16");
+  Location loc = op.getLoc();
+
+  Type f32Type = rewriter.getF32Type();
+  Type i32Type = rewriter.getI32Type();
+
+  // Extend lhs and rhs to fp32.
+  Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
+  Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
+
+  // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+  Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+  Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+  // Refine the approximation with one Newton iteration:
+  // float refined = approx + (lhs - approx * rhs) * rcp;
+  Value err = rewriter.create<LLVM::FMAOp>(
+      loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
+  Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+  // Use refined value if approx is normal (exponent neither all 0 or all 1).
+  Value mask = rewriter.create<LLVM::ConstantOp>(
+      loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+  Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+  Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+  Value zero = rewriter.create<LLVM::ConstantOp>(
+      loc, i32Type, rewriter.getUI32IntegerAttr(0));
+  Value pred = rewriter.create<LLVM::OrOp>(
+      loc,
+      rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
+      rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
+  Value result =
+      rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+  // Replace with trucation back to fp16.
+  rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+  return success();
+}
+
+void NVVMOptimizeForTarget::runOnOperation() {
+  MLIRContext *ctx = getOperation()->getContext();
+  RewritePatternSet patterns(ctx);
+  patterns.add<ExpandDivF16>(ctx);
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    return signalPassFailure();
+}
+
+std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
+  return std::make_unique<NVVMOptimizeForTarget>();
+}

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 9b28841c3c781..c978d773d9591 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -29,6 +29,13 @@ func.func @nvvm_special_regs() -> i32 {
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+func.func @nvvm_rcp(%arg0: f32) -> f32 {
+  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
+  %0 = nvvm.rcp.approx.ftz.f %arg0 : f32
+  llvm.return %0 : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 func.func @llvm_nvvm_barrier0() {
   // CHECK: nvvm.barrier0

diff  --git a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
new file mode 100644
index 0000000000000..e1cfd0c44f89b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -llvm-optimize-for-nvvm-target | FileCheck %s
+
+// CHECK-LABEL: llvm.func @fdiv_fp16
+llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
+  // CHECK-DAG: %[[c0:.*]]      = llvm.mlir.constant(0 : ui32) : i32
+  // CHECK-DAG: %[[mask:.*]]    = llvm.mlir.constant(2139095040 : ui32) : i32
+  // CHECK-DAG: %[[lhs:.*]]     = llvm.fpext %arg0 : f16 to f32
+  // CHECK-DAG: %[[rhs:.*]]     = llvm.fpext %arg1 : f16 to f32
+  // CHECK-DAG: %[[rcp:.*]]     = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
+  // CHECK-DAG: %[[approx:.*]]  = llvm.fmul %[[lhs]], %[[rcp]] : f32
+  // CHECK-DAG: %[[neg:.*]]     = llvm.fneg %[[rhs]] : f32
+  // CHECK-DAG: %[[err:.*]]     = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
+  // CHECK-DAG: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+  // CHECK-DAG: %[[cast:.*]]    = llvm.bitcast %[[approx]] : f32 to i32
+  // CHECK-DAG: %[[exp:.*]]     = llvm.and %[[cast]], %[[mask]] : i32
+  // CHECK-DAG: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+  // CHECK-DAG: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+  // CHECK-DAG: %[[pred:.*]]    = llvm.or %[[is_zero]], %[[is_mask]] : i1
+  // CHECK-DAG: %[[select:.*]]  = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+  // CHECK-DAG: %[[result:.*]]  = llvm.fptrunc %[[select]] : f32 to f16
+  %result = llvm.fdiv %arg0, %arg1 : f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 53af04140c38d..a66560d0e0da8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -33,6 +33,13 @@ llvm.func @nvvm_special_regs() -> i32 {
   llvm.return %1 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+llvm.func @nvvm_rcp(%0: f32) -> f32 {
+  // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
+  %1 = nvvm.rcp.approx.ftz.f %0 : f32
+  llvm.return %1 : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 llvm.func @llvm_nvvm_barrier0() {
   // CHECK: call void @llvm.nvvm.barrier0()

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 48264e5126614..3afa3101fa48a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3386,7 +3386,9 @@ cc_library(
         ":IR",
         ":LLVMDialect",
         ":LLVMPassIncGen",
+        ":NVVMDialect",
         ":Pass",
+        ":Transforms",
     ],
 )
 


        


More information about the Mlir-commits mailing list