[Mlir-commits] [mlir] [mlir] Add lowering of absi and fpowi to libdevice (PR #123644)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 20 08:38:19 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
More concise version of #<!-- -->123422.
---
Full diff: https://github.com/llvm/llvm-project/pull/123644.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+54-28)
- (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+19)
- (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+24)
``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 46fd182346b3b7..9f7ceb11752bab 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -16,10 +16,16 @@
namespace mlir {
-/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` or `f16Func` depending on the element type and the
-/// fastMathFlag of that Op. The function declaration is added in case it was
-/// not added before.
+namespace {
+/// Detection trait for the `getFastmath` instance method.
+template <typename T>
+using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());
+} // namespace
+
+/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
+/// `f32ApproxFunc` or `f16Func` or `i32Func` depending on the element type and
+/// the fastMathFlag of that Op, if present. The function declaration is added
+/// in case it was not added before.
///
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
/// value is first casted to f32, the function called and then the result casted
@@ -39,14 +45,22 @@ namespace mlir {
///
/// will be transformed into
/// llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
+///
+/// Final example with NVVM:
+/// %pow_f32 = math.fpowi %arg_f32, %arg_i32
+///
+/// will be transformed into
+/// llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
StringRef f32Func, StringRef f64Func,
- StringRef f32ApproxFunc, StringRef f16Func)
+ StringRef f32ApproxFunc, StringRef f16Func,
+ StringRef i32Func = "")
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
- f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
+ f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
+ i32Func(i32Func) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -76,9 +90,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
- StringRef funcName =
- getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
- op.getFastmath());
+ StringRef funcName = getFunctionName(
+ cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
if (funcName.empty())
return failure();
@@ -91,6 +104,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return success();
}
+ assert(callOp.getResult().getType().isF32() &&
+ "only f32 types are supposed to be truncated back");
Value truncated = rewriter.create<LLVM::FPTruncOp>(
op->getLoc(), adaptor.getOperands().front().getType(),
callOp.getResult());
@@ -98,7 +113,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return success();
}
-private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
Type type = operand.getType();
if (!isa<Float16Type, BFloat16Type>(type))
@@ -117,38 +131,50 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
- StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
- if (isa<Float16Type>(type))
- return f16Func;
- if (isa<Float32Type>(type)) {
- if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
- !f32ApproxFunc.empty())
- return f32ApproxFunc;
- else
- return f32Func;
- }
- if (isa<Float64Type>(type))
- return f64Func;
- return "";
- }
-
LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
Operation *op) const {
using LLVM::LLVMFuncOp;
auto funcAttr = StringAttr::get(op->getContext(), funcName);
- Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
if (funcOp)
- return cast<LLVMFuncOp>(*funcOp);
+ return funcOp;
- mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
+ auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
}
+ StringRef getFunctionName(Type type, SourceOp op) const {
+ bool useApprox = false;
+ if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
+ arith::FastMathFlags flag = op.getFastmath();
+ useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
+ !f32ApproxFunc.empty();
+ }
+
+ if (isa<Float16Type>(type))
+ return f16Func;
+ if (isa<Float32Type>(type)) {
+ if (useApprox)
+ return f32ApproxFunc;
+ return f32Func;
+ }
+ if (isa<Float64Type>(type))
+ return f64Func;
+
+ if (type.isInteger(32))
+ return i32Func;
+ return "";
+ }
+
const std::string f32Func;
const std::string f64Func;
const std::string f32ApproxFunc;
const std::string f16Func;
+ const std::string i32Func;
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2768929f460e2e..11363a0d60ebfa 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -446,6 +446,22 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ StringRef i32Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func);
+}
+
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ StringRef f32Func, StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "");
+}
+
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<GPUSubgroupReduceOpLowering>(converter);
@@ -509,6 +525,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
"__nv_fmod");
+ populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
@@ -555,6 +572,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
"__nv_log2", "__nv_fast_log2f");
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
"__nv_fast_powf");
+ populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
+ "__nv_powi");
populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
"__nv_round");
populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index f52dd6c0d0ce30..94c0f9e34c29ce 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1033,3 +1033,27 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+
+gpu.module @test_module_52 {
+ // CHECK: llvm.func @__nv_abs(i32) -> i32
+ // CHECK-LABEL: func @gpu_abs
+ func.func @gpu_abs(%arg_i32 : i32) -> (i32) {
+ %result32 = math.absi %arg_i32 : i32
+ // CHECK: llvm.call @__nv_abs(%{{.*}}) : (i32) -> i32
+ func.return %result32 : i32
+ }
+}
+
+gpu.module @test_module_53 {
+ // CHECK: llvm.func @__nv_powif(f32, i32) -> f32
+ // CHECK: llvm.func @__nv_powi(f64, i32) -> f64
+ // CHECK-LABEL: func @gpu_powi
+ func.func @gpu_powi(%arg_f32 : f32, %arg_f64 : f64, %arg_i32 : i32) -> (f32, f64) {
+ %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
+ // CHECK: llvm.call @__nv_powif(%{{.*}}, %{{.*}}) : (f32, i32) -> f32
+ %result64 = math.fpowi %arg_f64, %arg_i32 : f64, i32
+ // CHECK: llvm.call @__nv_powi(%{{.*}}, %{{.*}}) : (f64, i32) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/123644
More information about the Mlir-commits
mailing list