[Mlir-commits] [mlir] [MLIR][ROCDL] Add math.clampf -> rocdl.fmed3 conversion (PR #163520)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Thu Oct 16 02:42:06 PDT 2025
https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/163520
>From f5cf0218677a5019f78bfe451fcd343b19beb4c8 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Mon, 13 Oct 2025 12:55:36 -0700
Subject: [PATCH 01/12] [MLIR][ROCDL] Added math.clampf -> rocdl.fmed3
conversion
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../mlir/Conversion/MathToROCDL/MathToROCDL.h | 4 +-
mlir/include/mlir/Conversion/Passes.td | 8 +
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 2 +-
.../Conversion/MathToROCDL/MathToROCDL.cpp | 54 +-
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 941 +++++++++++++-----
5 files changed, 745 insertions(+), 264 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e7966ccc..770f257d89bd5 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>
@@ -20,7 +21,13 @@ class Pass;
/// Populate the given list with patterns that convert from Math to ROCDL calls.
+///
+/// `chipset` enables target-specific lowerings: `math.clampf` is only lowered
+/// to `rocdl.fmed3` for gfx9 and newer. The default-constructed chipset
+/// (gfx000) disables those lowerings, so callers written before this
+/// parameter existed keep their previous behavior.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+                                           RewritePatternSet &patterns,
+                                           amdgpu::Chipset chipset = {});
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3c18ecc753d0f..c3fd397e258ae 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -755,6 +755,15 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
"func::FuncDialect",
"vector::VectorDialect",
];
+  // NOTE(review): the hunk header places this inside ConvertMathToLibmPass,
+  // but the option is read by ConvertMathToROCDLPass::runOnOperation; confirm
+  // it is attached to the ConvertMathToROCDL definition.
+  let options = [
+    Option<"chipset", "chipset", "std::string",
+           /*default=*/"\"gfx000\"",
+           "Chipset that these operations will run on (e.g. gfx942); "
+           "math.clampf is only lowered to rocdl.fmed3 for gfx9 and newer">
+  ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index b215211e131d4..c03f3a5d3889c 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
- populateMathToROCDLConversionPatterns(converter, patterns);
+ populateMathToROCDLConversionPatterns(converter, patterns, chipset); // Enables chipset-gated lowerings.
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3ff4f6e..ceb3d22c6bd59 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -42,8 +43,67 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
+namespace {
+/// Lowers `math.clampf` to `rocdl.fmed3`, i.e. med3(value, min, max).
+///
+/// The pattern only applies when the configured chipset is gfx9 or newer and
+/// the element type is f16 or f32; otherwise it reports a match failure and
+/// leaves the op for other patterns.
+struct ClampFOpConversion final
+    : public ConvertOpToLLVMPattern<math::ClampFOp> {
+  ClampFOpConversion(const LLVMTypeConverter &converter,
+                     amdgpu::Chipset chipset)
+      : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // V_MED3_F16/F32 only exist on gfx9+ architectures.
+    if (chipset.majorVersion < 9) {
+      return rewriter.notifyMatchFailure(
+          op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
+               "): V_MED3_F16 / V_MED3_F32 not supported."));
+    }
+
+    // Only the f16/f32 forms (V_MED3_F16/F32) are targeted; other float
+    // element types (e.g. f64) are left untouched.
+    Type elementType = op.getType();
+    if (auto vectorType = dyn_cast<VectorType>(elementType))
+      elementType = vectorType.getElementType();
+    if (!elementType.isF16() && !elementType.isF32())
+      return rewriter.notifyMatchFailure(
+          op, "rocdl.fmed3 lowering only supports f16 and f32 element types");
+
+    // Build the replacement from the converted result type and the adaptor's
+    // operands, which carry the values as remapped by the conversion driver,
+    // rather than from the original op's operands.
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType || !isa<FloatType, VectorType>(resultType))
+      return rewriter.notifyMatchFailure(
+          op, "result type does not convert to a float scalar or 1-D vector");
+
+    rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(
+        op, resultType, adaptor.getValue(), adaptor.getMin(),
+        adaptor.getMax());
+    return success();
+  }
+
+private:
+  /// Target chipset, used to gate the lowering on hardware support.
+  const amdgpu::Chipset chipset;
+};
+} // namespace
+
+/// Registers patterns whose applicability depends on the target chipset.
+static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
+                                        RewritePatternSet &patterns,
+                                        amdgpu::Chipset chipset) {
+  patterns.add<ClampFOpConversion>(converter, chipset);
+}
+
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    amdgpu::Chipset chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -118,27 +150,40 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
+
+  addChipsetDependentPatterns(converter, patterns, chipset);
}
namespace {
-struct ConvertMathToROCDLPass
- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
- ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+    : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+  // Inherit the generated base-class constructors.
+  using impl::ConvertMathToROCDLBase<
+      ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
void runOnOperation() override;
};
} // namespace
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
+  // Diagnose an unrecognised chipset string instead of dereferencing a failed
+  // FailureOr when building the patterns below.
+  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+  if (failed(maybeChipset)) {
+    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+    return signalPassFailure();
+  }
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+  populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
ConversionTarget target(getContext());
- target.addLegalDialect<BuiltinDialect, func::FuncDialect,
- vector::VectorDialect, LLVM::LLVMDialect>();
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect,
+                         ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index dbff23339d8b3..29851e2de5cb2 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,18 +1,23 @@
-// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file \
+// RUN:   -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' \
+// RUN:   | FileCheck %s --check-prefixes=CHECK,PRE9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file \
+// RUN:   -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' \
+// RUN:   | FileCheck %s --check-prefixes=CHECK,POST9
module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
// CHECK-LABEL: func @arith_remf
  func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
    %result16 = arith.remf %arg_f16, %arg_f16 : f16
    // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
    %result32 = arith.remf %arg_f32, %arg_f32 : f32
    // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
    %result64 = arith.remf %arg_f64, %arg_f64 : f64
    // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -23,14 +45,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
// CHECK-LABEL: func @math_acos
- func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.acos %arg_f16 : f16
- // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.acos %arg_f32 : f32
- // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.acos %arg_f64 : f64
- // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.acos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -41,14 +77,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
// CHECK-LABEL: func @math_acosh
- func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.acosh %arg_f16 : f16
- // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.acosh %arg_f32 : f32
- // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.acosh %arg_f64 : f64
- // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.acosh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -59,14 +109,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
// CHECK-LABEL: func @math_asin
- func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.asin %arg_f16 : f16
- // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.asin %arg_f32 : f32
- // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.asin %arg_f64 : f64
- // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.asin %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -77,14 +141,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
// CHECK-LABEL: func @math_asinh
- func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.asinh %arg_f16 : f16
- // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.asinh %arg_f32 : f32
- // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.asinh %arg_f64 : f64
- // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.asinh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -95,14 +173,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
// CHECK-LABEL: func @math_atan
- func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.atan %arg_f16 : f16
- // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.atan %arg_f32 : f32
- // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.atan %arg_f64 : f64
- // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atan %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.atan %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.atan %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -113,14 +205,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
// CHECK-LABEL: func @math_atanh
- func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.atanh %arg_f16 : f16
- // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.atanh %arg_f32 : f32
- // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.atanh %arg_f64 : f64
- // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atanh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.atanh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.atanh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -131,14 +237,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
// CHECK-LABEL: func @math_atan2
- func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.atan2 %arg_f16, %arg_f16 : f16
- // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
- %result32 = math.atan2 %arg_f32, %arg_f32 : f32
- // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
- %result64 = math.atan2 %arg_f64, %arg_f64 : f64
- // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atan2 %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
+    %result32 = math.atan2 %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = math.atan2 %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -149,14 +272,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
// CHECK-LABEL: func @math_cbrt
- func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.cbrt %arg_f16 : f16
- // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.cbrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.cbrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cbrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.cbrt %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.cbrt %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -167,14 +304,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
// CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
// CHECK-LABEL: func @math_ceil
- func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.ceil %arg_f16 : f16
- // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.ceil %arg_f32 : f32
- // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.ceil %arg_f64 : f64
- // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.ceil %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.ceil %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.ceil %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -185,14 +336,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
// CHECK-LABEL: func @math_cos
- func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.cos %arg_f16 : f16
- // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.cos %arg_f32 : f32
- // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.cos %arg_f64 : f64
- // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.cos %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.cos %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -203,14 +368,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64
// CHECK-LABEL: func @math_cosh
- func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.cosh %arg_f16 : f16
- // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.cosh %arg_f32 : f32
- // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.cosh %arg_f64 : f64
- // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cosh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.cosh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.cosh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -221,14 +400,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64
// CHECK-LABEL: func @math_sinh
- func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.sinh %arg_f16 : f16
- // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.sinh %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sinh %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.sinh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.sinh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.sinh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -238,12 +431,12 @@ module @test_module {
// CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @math_exp
- func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
- %result16 = math.exp %arg_f16 : f16
- // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result64 : f16, f64
+  func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+    %result16 = math.exp %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
+    %result64 = math.exp %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result64 : f16, f64
}
}
@@ -254,14 +453,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
// CHECK-LABEL: func @math_exp2
- func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.exp2 %arg_f16 : f16
- // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.exp2 %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp2 %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.exp2 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.exp2 %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.exp2 %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -272,14 +485,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
// CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
// CHECK-LABEL: func @math_expm1
- func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.expm1 %arg_f16 : f16
- // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.expm1 %arg_f32 : f32
- // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.expm1 %arg_f64 : f64
- // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.expm1 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.expm1 %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.expm1 %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -290,14 +517,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
// CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
// CHECK-LABEL: func @math_floor
- func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.floor %arg_f16 : f16
- // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.floor %arg_f32 : f32
- // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.floor %arg_f64 : f64
- // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.floor %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.floor %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.floor %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -307,12 +548,12 @@ module @test_module {
// CHECK: llvm.func @__ocml_log_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @math_log
- func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
- %result16 = math.log %arg_f16 : f16
- // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
- %result64 = math.log %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result64 : f16, f64
+  func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+    %result16 = math.log %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
+    %result64 = math.log %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result64 : f16, f64
}
}
@@ -323,14 +570,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
// CHECK-LABEL: func @math_log10
- func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.log10 %arg_f16 : f16
- // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.log10 %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.log10 %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log10 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.log10 %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.log10 %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -341,14 +602,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
// CHECK-LABEL: func @math_log1p
- func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.log1p %arg_f16 : f16
- // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.log1p %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.log1p %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log1p %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.log1p %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.log1p %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -359,14 +634,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
// CHECK-LABEL: func @math_powf
- func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.powf %arg_f16, %arg_f16 : f16
- // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
- %result32 = math.powf %arg_f32, %arg_f32 : f32
- // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
- %result64 = math.powf %arg_f64, %arg_f64 : f64
- // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.powf %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
+    %result32 = math.powf %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = math.powf %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -377,14 +669,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
// CHECK-LABEL: func @math_rsqrt
- func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.rsqrt %arg_f16 : f16
- // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.rsqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.rsqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+  func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.rsqrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
+    %result32 = math.rsqrt %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.rsqrt %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -395,14 +701,28 @@ module @test_module {
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @math_sin
- func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.sin %arg_f16 : f16
- // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.sin %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sin %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ func.func @math_sin(% arg_f16
+ : f16, % arg_f32
+ : f32, % arg_f64
+ : f64)
+ ->(f16, f32, f64) {
+ % result16 = math.sin %
+ arg_f16
+ : f16
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ %
+ result32 = math.sin %
+ arg_f32
+ : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %
+ result64 = math.sin %
+ arg_f64
+ : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+ func.return %
+ result16,
+ % result32, % result64 : f16, f32, f64
}
}
@@ -413,14 +733,28 @@ module @test_module {
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
// CHECK-LABEL: func @math_tanh
- func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.tanh %arg_f16 : f16
- // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.tanh %arg_f32 : f32
- // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.tanh %arg_f64 : f64
- // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ func.func @math_tanh(% arg_f16
+ : f16, % arg_f32
+ : f32, % arg_f64
+ : f64)
+ ->(f16, f32, f64) {
+ % result16 = math.tanh %
+ arg_f16
+ : f16
+ // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
+ %
+ result32 = math.tanh %
+ arg_f32
+ : f32
+ // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
+ %
+ result64 = math.tanh %
+ arg_f64
+ : f64
+ // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
+ func.return %
+ result16,
+ % result32, % result64 : f16, f32, f64
}
}
@@ -431,14 +765,28 @@ module @test_module {
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
// CHECK-LABEL: func @math_tan
- func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.tan %arg_f16 : f16
- // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.tan %arg_f32 : f32
- // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.tan %arg_f64 : f64
- // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ func.func @math_tan(% arg_f16
+ : f16, % arg_f32
+ : f32, % arg_f64
+ : f64)
+ ->(f16, f32, f64) {
+ % result16 = math.tan %
+ arg_f16
+ : f16
+ // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
+ %
+ result32 = math.tan %
+ arg_f32
+ : f32
+ // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
+ %
+ result64 = math.tan %
+ arg_f64
+ : f64
+ // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
+ func.return %
+ result16,
+ % result32, % result64 : f16, f32, f64
}
}
@@ -449,14 +797,28 @@ module @test_module {
// CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
// CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
// CHECK-LABEL: func @math_erf
- func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.erf %arg_f16 : f16
- // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.erf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.erf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ func.func @math_erf(% arg_f16
+ : f16, % arg_f32
+ : f32, % arg_f64
+ : f64)
+ ->(f16, f32, f64) {
+ % result16 = math.erf %
+ arg_f16
+ : f16
+ // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
+ %
+ result32 = math.erf %
+ arg_f32
+ : f32
+ // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
+ %
+ result64 = math.erf %
+ arg_f64
+ : f64
+ // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
+ func.return %
+ result16,
+ % result32, % result64 : f16, f32, f64
}
}
@@ -467,14 +829,28 @@ module @test_module {
// CHECK: llvm.func @__ocml_erfc_f32(f32) -> f32
// CHECK: llvm.func @__ocml_erfc_f64(f64) -> f64
// CHECK-LABEL: func @math_erfc
- func.func @math_erfc(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
- %result16 = math.erfc %arg_f16 : f16
- // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16
- %result32 = math.erfc %arg_f32 : f32
- // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.erfc %arg_f64 : f64
- // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ func.func @math_erfc(% arg_f16
+ : f16, % arg_f32
+ : f32, % arg_f64
+ : f64)
+ ->(f16, f32, f64) {
+ % result16 = math.erfc %
+ arg_f16
+ : f16
+ // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16
+ %
+ result32 = math.erfc %
+ arg_f32
+ : f32
+ // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32
+ %
+ result64 = math.erfc %
+ arg_f64
+ : f64
+ // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64
+ func.return %
+ result16,
+ % result32, % result64 : f16, f32, f64
}
}
@@ -485,18 +861,36 @@ module @test_module {
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @math_casting
- func.func @math_casting(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64, %arg_bf16 : bf16) -> (f16, f32, f64, bf16) {
- %resultf16 = math.sin %arg_f16 : f16
- // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
- %resultf32 = math.sin %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %resultf64 = math.sin %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- %resultbf16 = math.sin %arg_bf16 : bf16
- // CHECK: llvm.fpext %{{.*}} : bf16 to f32
- // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16
- func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16
+ func.func @math_casting(% arg_f16
+ : f16, % arg_f32
+ : f32, % arg_f64
+ : f64, % arg_bf16
+ : bf16)
+ ->(f16, f32, f64, bf16) {
+ % resultf16 = math.sin %
+ arg_f16
+ : f16
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ %
+ resultf32 = math.sin %
+ arg_f32
+ : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %
+ resultf64 = math.sin %
+ arg_f64
+ : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+ %
+ resultbf16 = math.sin %
+ arg_bf16
+ : bf16
+ // CHECK: llvm.fpext %{{.*}} : bf16 to f32
+ // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16
+ func.return %
+ resultf16,
+ % resultf32, % resultf64, % resultbf16 : f16, f32, f64, bf16
}
}
@@ -507,14 +901,22 @@ module @test_module {
// CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
// CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
// CHECK-LABEL: func @math_fpowi
- func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
+ func.func @math_fpowi(% arg0
+ : f16, % arg1
+ : f32, % arg2
+ : f64, % arg3
+ : i32)
+ ->(f16, f32, f64) {
// CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
- %0 = math.fpowi %arg0, %arg3 : f16, i32
- // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
- %1 = math.fpowi %arg1, %arg3 : f32, i32
- // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
- %2 = math.fpowi %arg2, %arg3 : f64, i32
- return %0, %1, %2 : f16, f32, f64
+ % 0 = math.fpowi % arg0, % arg3 : f16,
+ i32
+ // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+ % 1 = math.fpowi % arg1,
+ % arg3 : f32,
+ i32
+ // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+ % 2 = math.fpowi % arg2,
+ % arg3 : f64, i32 return % 0, % 1, % 2 : f16, f32, f64
}
}
@@ -523,13 +925,13 @@ module @test_module {
// Math operation not inside function
// Ensure it not crash
-module {
- "test.some_op_with_region"() ({
- ^bb0(%arg0: f64):
- // CHECK: math.atan
- %0 = math.atan %arg0 : f64
- "test.possible_terminator"() : () -> ()
- }) : () -> ()
+module{
+ "test.some_op_with_region"()({
+ ^bb0(% arg0:f64) :
+ // CHECK: math.atan
+ % 0 = math.atan % arg0:f64 "test.possible_terminator"() : ()->()
+ }) : ()
+ ->()
}
// -----
@@ -537,12 +939,11 @@ module {
module @test_module {
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_0d
- func.func @math_sin_vector_0d(%arg : vector<f16>) -> vector<f16> {
+ func.func @math_sin_vector_0d(% arg : vector<f16>)->vector<f16> {
// CHECK: llvm.extractelement {{.*}} : vector<1xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<1xf16>
- %result = math.sin %arg : vector<f16>
- func.return %result : vector<f16>
+ % result = math.sin % arg : vector<f16> func.return % result : vector<f16>
}
}
@@ -551,7 +952,7 @@ module @test_module {
module @test_module {
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_1d
- func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
+ func.func @math_sin_vector_1d(% arg : vector<4xf16>)->vector<4xf16> {
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
@@ -564,8 +965,8 @@ module @test_module {
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
- %result = math.sin %arg : vector<4xf16>
- func.return %result : vector<4xf16>
+ % result =
+ math.sin % arg : vector<4xf16> func.return % result : vector<4xf16>
}
}
@@ -574,11 +975,11 @@ module @test_module {
module @test_module {
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_2d
- func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
- // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
- // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ func.func @math_sin_vector_2d(% arg : vector<2x2xf16>)->vector<2x2xf16> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to
+ // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractvalue {{.*}} :
+ // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractelement {{.*}} :
+ // vector<2xf16> CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
@@ -591,8 +992,42 @@ module @test_module {
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
- // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
- %result = math.sin %arg : vector<2x2xf16>
- func.return %result : vector<2x2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ % result =
+ math.sin % arg : vector<2x2xf16> func.return % result : vector<2x2xf16>
}
}
+
+// -----
+
+// f16 clamp → rocdl.fmed3 on gfx9+
+func.func @clampf_f16(% x
+ : f16, % lo
+ : f16, % hi
+ : f16)
+ ->f16{ % r = math.clampf % x to[% lo, % hi] : f16 return % r : f16}
+
+// f32 clamp → rocdl.fmed3 on gfx9+
+func.func @clampf_f32(% x
+ : f32, % lo
+ : f32, % hi
+ : f32)
+ ->f32 {
+ % r = math.clampf % x to[% lo, % hi] : f32 return % r : f32
+}
+
+// POST9-LABEL: func.func @clampf_f16
+// POST9: rocdl.fmed3 {{.*}} : f16
+// POST9: return
+
+// POST9-LABEL: func.func @clampf_f32
+// POST9: rocdl.fmed3 {{.*}} : f32
+// POST9: return
+
+// PRE9-LABEL: func.func @clampf_f16
+// PRE9-NOT: rocdl.fmed3
+// PRE9: math.clampf {{.*}} : f16
+
+// PRE9-LABEL: func.func @clampf_f32
+// PRE9-NOT: rocdl.fmed3
+// PRE9: math.clampf {{.*}} : f32
>From 92bcb55d165dcf4407b045a38d98f01bd2a0c2bc Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Mon, 13 Oct 2025 13:00:35 -0700
Subject: [PATCH 02/12] Removed incorrect formatting
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 927 +++++-------------
1 file changed, 261 insertions(+), 666 deletions(-)
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 29851e2de5cb2..7244b0aac8e43 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,40 +1,19 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file
-// -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' |
-// FileCheck %s --check-prefix=PRE9 RUN: mlir-opt %s -allow-unregistered-dialect
-// -split-input-file
-// -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' |
-// FileCheck %s --check-prefix=POST9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9
module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
// CHECK-LABEL: func @arith_remf
- func.func @arith_remf(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = arith.remf % arg_f16,
- %
- arg_f16 : f16
- // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) :
- // (f16, f16) -> f16
- %
- result32 = arith.remf % arg_f32,
- %
- arg_f32 : f32
- // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) :
- // (f32, f32) -> f32
- %
- result64 = arith.remf % arg_f64,
- %
- arg_f64 : f64
- // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) :
- // (f64, f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = arith.remf %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
+ %result32 = arith.remf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = arith.remf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -45,28 +24,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
// CHECK-LABEL: func @math_acos
- func.func @math_acos(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.acos %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.acos %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.acos %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.acos %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -77,28 +42,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
// CHECK-LABEL: func @math_acosh
- func.func @math_acosh(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.acosh %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.acosh %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.acosh %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.acosh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -109,28 +60,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
// CHECK-LABEL: func @math_asin
- func.func @math_asin(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.asin %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.asin %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.asin %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.asin %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -141,28 +78,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
// CHECK-LABEL: func @math_asinh
- func.func @math_asinh(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.asinh %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.asinh %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.asinh %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.asinh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -173,28 +96,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
// CHECK-LABEL: func @math_atan
- func.func @math_atan(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.atan %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.atan %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.atan %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atan %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.atan %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.atan %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -205,28 +114,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
// CHECK-LABEL: func @math_atanh
- func.func @math_atanh(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.atanh %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.atanh %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.atanh %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atanh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.atanh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.atanh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -237,31 +132,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
// CHECK-LABEL: func @math_atan2
- func.func @math_atan2(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.atan2 % arg_f16,
- %
- arg_f16 : f16
- // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) :
- // (f16, f16) -> f16
- %
- result32 = math.atan2 % arg_f32,
- %
- arg_f32 : f32
- // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) :
- // (f32, f32) -> f32
- %
- result64 = math.atan2 % arg_f64,
- %
- arg_f64 : f64
- // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}})
- // : (f64, f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atan2 %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
+ %result32 = math.atan2 %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = math.atan2 %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -272,28 +150,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
// CHECK-LABEL: func @math_cbrt
- func.func @math_cbrt(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.cbrt %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.cbrt %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.cbrt %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cbrt %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.cbrt %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.cbrt %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -304,28 +168,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
// CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
// CHECK-LABEL: func @math_ceil
- func.func @math_ceil(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.ceil %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.ceil %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.ceil %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.ceil %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.ceil %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.ceil %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -336,28 +186,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
// CHECK-LABEL: func @math_cos
- func.func @math_cos(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.cos %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.cos %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.cos %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cos %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.cos %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.cos %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -368,28 +204,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64
// CHECK-LABEL: func @math_cosh
- func.func @math_cosh(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.cosh %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.cosh %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.cosh %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cosh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.cosh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.cosh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -400,28 +222,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64
// CHECK-LABEL: func @math_sinh
- func.func @math_sinh(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.sinh %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.sinh %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.sinh %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.sinh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.sinh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sinh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -431,18 +239,12 @@ module @test_module {
// CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @math_exp
- func.func @math_exp(% arg_f16 : f16, % arg_f64 : f64)->(f16, f64) {
- % result16 =
- math.exp %
- arg_f16 : f16
- // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
- %
- result64 = math.exp %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return % result16,
- % result64 : f16, f64
+ func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+ %result16 = math.exp %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
+ %result64 = math.exp %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result64 : f16, f64
}
}
@@ -453,28 +255,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
// CHECK-LABEL: func @math_exp2
- func.func @math_exp2(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.exp2 %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.exp2 %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.exp2 %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.exp2 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.exp2 %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.exp2 %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -485,28 +273,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
// CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
// CHECK-LABEL: func @math_expm1
- func.func @math_expm1(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.expm1 %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.expm1 %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.expm1 %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.expm1 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.expm1 %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.expm1 %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -517,28 +291,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
// CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
// CHECK-LABEL: func @math_floor
- func.func @math_floor(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.floor %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.floor %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.floor %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.floor %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.floor %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.floor %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -548,18 +308,12 @@ module @test_module {
// CHECK: llvm.func @__ocml_log_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @math_log
- func.func @math_log(% arg_f16 : f16, % arg_f64 : f64)->(f16, f64) {
- % result16 =
- math.log %
- arg_f16 : f16
- // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
- %
- result64 = math.log %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return % result16,
- % result64 : f16, f64
+ func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+ %result16 = math.log %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
+ %result64 = math.log %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result64 : f16, f64
}
}
@@ -570,28 +324,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
// CHECK-LABEL: func @math_log10
- func.func @math_log10(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.log10 %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.log10 %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.log10 %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log10 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.log10 %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.log10 %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -602,28 +342,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
// CHECK-LABEL: func @math_log1p
- func.func @math_log1p(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.log1p %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.log1p %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.log1p %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log1p %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.log1p %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.log1p %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -634,31 +360,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
// CHECK-LABEL: func @math_powf
- func.func @math_powf(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.powf % arg_f16,
- %
- arg_f16 : f16
- // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) :
- // (f16, f16) -> f16
- %
- result32 = math.powf % arg_f32,
- %
- arg_f32 : f32
- // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) :
- // (f32, f32) -> f32
- %
- result64 = math.powf % arg_f64,
- %
- arg_f64 : f64
- // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) :
- // (f64, f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.powf %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
+ %result32 = math.powf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = math.powf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -669,28 +378,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
// CHECK-LABEL: func @math_rsqrt
- func.func @math_rsqrt(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.rsqrt %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.rsqrt %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.rsqrt %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.rsqrt %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.rsqrt %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.rsqrt %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -701,28 +396,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @math_sin
- func.func @math_sin(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.sin %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.sin %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.sin %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.sin %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.sin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -733,28 +414,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
// CHECK-LABEL: func @math_tanh
- func.func @math_tanh(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.tanh %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.tanh %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.tanh %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.tanh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.tanh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.tanh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -765,28 +432,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
// CHECK-LABEL: func @math_tan
- func.func @math_tan(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.tan %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.tan %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.tan %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.tan %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.tan %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.tan %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -797,28 +450,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
// CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
// CHECK-LABEL: func @math_erf
- func.func @math_erf(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.erf %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.erf %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.erf %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.erf %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.erf %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.erf %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -829,28 +468,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_erfc_f32(f32) -> f32
// CHECK: llvm.func @__ocml_erfc_f64(f64) -> f64
// CHECK-LABEL: func @math_erfc
- func.func @math_erfc(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64)
- ->(f16, f32, f64) {
- % result16 = math.erfc %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16
- %
- result32 = math.erfc %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32
- %
- result64 = math.erfc %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64
- func.return %
- result16,
- % result32, % result64 : f16, f32, f64
+ func.func @math_erfc(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.erfc %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_erfc_f16(%{{.*}}) : (f16) -> f16
+ %result32 = math.erfc %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_erfc_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.erfc %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_erfc_f64(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -861,36 +486,18 @@ module @test_module {
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @math_casting
- func.func @math_casting(% arg_f16
- : f16, % arg_f32
- : f32, % arg_f64
- : f64, % arg_bf16
- : bf16)
- ->(f16, f32, f64, bf16) {
- % resultf16 = math.sin %
- arg_f16
- : f16
- // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
- %
- resultf32 = math.sin %
- arg_f32
- : f32
- // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %
- resultf64 = math.sin %
- arg_f64
- : f64
- // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- %
- resultbf16 = math.sin %
- arg_bf16
- : bf16
- // CHECK: llvm.fpext %{{.*}} : bf16 to f32
- // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16
- func.return %
- resultf16,
- % resultf32, % resultf64, % resultbf16 : f16, f32, f64, bf16
+ func.func @math_casting(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64, %arg_bf16 : bf16) -> (f16, f32, f64, bf16) {
+ %resultf16 = math.sin %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ %resultf32 = math.sin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %resultf64 = math.sin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+ %resultbf16 = math.sin %arg_bf16 : bf16
+ // CHECK: llvm.fpext %{{.*}} : bf16 to f32
+ // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16
+ func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16
}
}
@@ -901,22 +508,14 @@ module @test_module {
// CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
// CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
// CHECK-LABEL: func @math_fpowi
- func.func @math_fpowi(% arg0
- : f16, % arg1
- : f32, % arg2
- : f64, % arg3
- : i32)
- ->(f16, f32, f64) {
+ func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
// CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
- % 0 = math.fpowi % arg0, % arg3 : f16,
- i32
- // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
- % 1 = math.fpowi % arg1,
- % arg3 : f32,
- i32
- // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
- % 2 = math.fpowi % arg2,
- % arg3 : f64, i32 return % 0, % 1, % 2 : f16, f32, f64
+ %0 = math.fpowi %arg0, %arg3 : f16, i32
+ // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+ %1 = math.fpowi %arg1, %arg3 : f32, i32
+ // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+ %2 = math.fpowi %arg2, %arg3 : f64, i32
+ return %0, %1, %2 : f16, f32, f64
}
}
@@ -925,13 +524,13 @@ module @test_module {
// Math operation not inside function
// Ensure it not crash
-module{
- "test.some_op_with_region"()({
- ^bb0(% arg0:f64) :
- // CHECK: math.atan
- % 0 = math.atan % arg0:f64 "test.possible_terminator"() : ()->()
- }) : ()
- ->()
+module {
+ "test.some_op_with_region"() ({
+ ^bb0(%arg0: f64):
+ // CHECK: math.atan
+ %0 = math.atan %arg0 : f64
+ "test.possible_terminator"() : () -> ()
+ }) : () -> ()
}
// -----
@@ -939,11 +538,12 @@ module{
module @test_module {
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_0d
- func.func @math_sin_vector_0d(% arg : vector<f16>)->vector<f16> {
+ func.func @math_sin_vector_0d(%arg : vector<f16>) -> vector<f16> {
// CHECK: llvm.extractelement {{.*}} : vector<1xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<1xf16>
- % result = math.sin % arg : vector<f16> func.return % result : vector<f16>
+ %result = math.sin %arg : vector<f16>
+ func.return %result : vector<f16>
}
}
@@ -952,7 +552,7 @@ module @test_module {
module @test_module {
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_1d
- func.func @math_sin_vector_1d(% arg : vector<4xf16>)->vector<4xf16> {
+ func.func @math_sin_vector_1d(%arg : vector<4xf16>) -> vector<4xf16> {
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
@@ -965,8 +565,8 @@ module @test_module {
// CHECK: llvm.extractelement {{.*}} : vector<4xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<4xf16>
- % result =
- math.sin % arg : vector<4xf16> func.return % result : vector<4xf16>
+ %result = math.sin %arg : vector<4xf16>
+ func.return %result : vector<4xf16>
}
}
@@ -975,11 +575,11 @@ module @test_module {
module @test_module {
// CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK-LABEL: func @math_sin_vector_2d
- func.func @math_sin_vector_2d(% arg : vector<2x2xf16>)->vector<2x2xf16> {
- // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to
- // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractvalue {{.*}} :
- // !llvm.array<2 x vector<2xf16>> CHECK: llvm.extractelement {{.*}} :
- // vector<2xf16> CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ func.func @math_sin_vector_2d(%arg : vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractelement {{.*}} : vector<2xf16>
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
@@ -992,28 +592,24 @@ module @test_module {
// CHECK: llvm.extractelement {{.*}} : vector<2xf16>
// CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
// CHECK: llvm.insertelement {{.*}} : vector<2xf16>
- // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
- % result =
- math.sin % arg : vector<2x2xf16> func.return % result : vector<2x2xf16>
+ // CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ %result = math.sin %arg : vector<2x2xf16>
+ func.return %result : vector<2x2xf16>
}
}
// -----
// f16 clamp → rocdl.fmed3 on gfx9+
-func.func @clampf_f16(% x
- : f16, % lo
- : f16, % hi
- : f16)
- ->f16{ % r = math.clampf % x to[% lo, % hi] : f16 return % r : f16}
+func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
+ %r = math.clampf %x to [%lo, %hi] : f16
+ return %r : f16
+}
// f32 clamp → rocdl.fmed3 on gfx9+
-func.func @clampf_f32(% x
- : f32, % lo
- : f32, % hi
- : f32)
- ->f32 {
- % r = math.clampf % x to[% lo, % hi] : f32 return % r : f32
+func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
+ %r = math.clampf %x to [%lo, %hi] : f32
+ return %r : f32
}
// POST9-LABEL: func.func @clampf_f16
@@ -1030,4 +626,3 @@ func.func @clampf_f32(% x
// PRE9-LABEL: func.func @clampf_f32
// PRE9-NOT: rocdl.fmed3
-// PRE9: math.clampf {{.*}} : f32
>From 49b08f9a4ce206e9768b3f341c49f6377c21d116 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Mon, 13 Oct 2025 13:12:01 -0700
Subject: [PATCH 03/12] Corrected pass option
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/include/mlir/Conversion/Passes.td | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index c3fd397e258ae..06bd82341acab 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -755,14 +755,6 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
"func::FuncDialect",
"vector::VectorDialect",
];
- let options = [
- Option<"chipset", "chipset", "std::string",
-
-
- /*default=*/"\"gfx000\"",
- "Chipset that these operations will run on">
- ];
-
}
//===----------------------------------------------------------------------===//
@@ -793,6 +785,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
+ let options = [Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"gfx000\"",
+ "Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
>From 636ef8d9581229c916f69819a0fc172a648124bb Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 14 Oct 2025 06:50:42 -0700
Subject: [PATCH 04/12] Addressed comments by Krzysztof: 1. Added lit tests for
1D and 2D vectors; 2. Added unrolling support for N-D inputs
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../Conversion/MathToROCDL/MathToROCDL.cpp | 18 ++++++
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 57 +++++++++++++++----
2 files changed, 64 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index ceb3d22c6bd59..d8e3c34399ad4 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -59,10 +60,27 @@ struct ClampFOpConversion final
op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
"): V_MED_F16 / V_MED3_F32 not supported."));
}
+ auto resultType = getTypeConverter()->convertType(op.getType());
+ // Handle multi-dimensional vectors (converted to LLVM arrays)
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
+ // Handle multi-dimensional vectors (converted to LLVM arrays)
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ typename math::ClampFOp::Adaptor adaptor(operands);
+ return rewriter.create<ROCDL::FMed3Op>(
+ op.getLoc(), llvm1DVectorTy, adaptor.getValue(),
+ adaptor.getMin(), adaptor.getMax());
+ },
+ rewriter);
+ }
+
+ // Handle 1D vectors and scalars directly
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
op.getMin(), op.getMax());
return success();
}
+
amdgpu::Chipset chipset;
};
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 7244b0aac8e43..55d48fa0d27f1 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -601,28 +601,63 @@ module @test_module {
// -----
// f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f16
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
%r = math.clampf %x to [%lo, %hi] : f16
return %r : f16
+ // POST9: rocdl.fmed3 {{.*}} : f16
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : f16
}
// f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f32
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
%r = math.clampf %x to [%lo, %hi] : f32
return %r : f32
+ // POST9: rocdl.fmed3 {{.*}} : f32
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : f32
}
-// POST9-LABEL: func.func @clampf_f16
-// POST9: rocdl.fmed3 {{.*}} : f16
-// POST9: return
+// -----
+
+// Vector f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f16
+func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2xf16>
+ return %r : vector<2xf16>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2xf16>
+}
+
+// -----
-// POST9-LABEL: func.func @clampf_f32
-// POST9: rocdl.fmed3 {{.*}} : f32
-// POST9: return
+// Vector f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f32
+func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2xf32>
+ return %r : vector<2xf32>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2xf32>
+}
-// PRE9-LABEL: func.func @clampf_f16
-// PRE9-NOT: rocdl.fmed3
-// PRE9: math.clampf {{.*}} : f16
+// -----
-// PRE9-LABEL: func.func @clampf_f32
-// PRE9-NOT: rocdl.fmed3
+// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
+// CHECK-LABEL: func.func @clampf_vector_2d_f16
+func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
+ %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
+ return %r : vector<2x2xf16>
+ // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+ // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+ // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+ // PRE9-NOT: rocdl.fmed3
+ // PRE9: math.clampf {{.*}} : vector<2x2xf16>
+}
>From 767c0aca77c01c7b15dc5750335d2284514d1c70 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 14 Oct 2025 07:07:22 -0700
Subject: [PATCH 05/12] Set chipset default value to empty
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/include/mlir/Conversion/Passes.td | 6 +++++-
.../lib/Conversion/MathToROCDL/MathToROCDL.cpp | 18 +++++++++++++-----
2 files changed, 18 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06bd82341acab..78a6df3ad8755 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.
+
+ The chipset option specifies the target AMDGPU architecture. If the chipset
+ is empty, none of the chipset-dependent patterns are added and the pass
+ will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
@@ -786,7 +790,7 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"vector::VectorDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
- /*default=*/"\"gfx000\"",
+ /*default=*/"\"\"",
"Chipset that these operations will run on">];
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index d8e3c34399ad4..aef768e225d67 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -84,9 +84,9 @@ struct ClampFOpConversion final
amdgpu::Chipset chipset;
};
-static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- amdgpu::Chipset chipset) {
+void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ amdgpu::Chipset chipset) {
patterns.add<ClampFOpConversion>(converter, chipset);
}
@@ -183,12 +183,20 @@ struct ConvertMathToROCDLPass final
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
- FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
+
+ // Only populate chipset-dependent patterns if chipset is specified
+ if (!chipset.empty()) {
+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset)) {
+ return signalPassFailure();
+ }
+ populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
+ }
+
ConversionTarget target(getContext());
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
>From 5b197b75892dc6b9e4e66694e9c9c6649c9ac1b7 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 14 Oct 2025 07:57:01 -0700
Subject: [PATCH 06/12] Pattern should only apply to f16/f32 types; added a
lit test rejecting bf16
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 16 ++++++++++++++--
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 10 ++++++++++
2 files changed, 24 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index aef768e225d67..38704157ba565 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -54,13 +54,25 @@ struct ClampFOpConversion final
LogicalResult
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // V_MED3_F16/F32 only exists in gfx9+ artchitectures
+ // Only f16 and f32 types are supported by fmed3
+ Type opTy = op.getType();
+ auto resultType = getTypeConverter()->convertType(opTy);
+
+ if (auto vectorType = dyn_cast<VectorType>(opTy)) {
+ opTy = vectorType.getElementType();
+ }
+
+ if (!opTy.isF16() && !opTy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "fmed3 only supports f16 and f32 types");
+ }
+
+ // V_MED3_F16/F32 only exists in gfx9+ architectures
if (chipset.majorVersion < 9) {
return rewriter.notifyMatchFailure(
op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
"): V_MED_F16 / V_MED3_F32 not supported."));
}
- auto resultType = getTypeConverter()->convertType(op.getType());
// Handle multi-dimensional vectors (converted to LLVM arrays)
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
// Handle multi-dimensional vectors (converted to LLVM arrays)
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 55d48fa0d27f1..959230ae6cd49 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -661,3 +661,13 @@ func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi:
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2x2xf16>
}
+
+// -----
+// CHECK-LABEL: func.func @clampf_bf16
+func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
+ %r = math.clampf %x to [%lo, %hi] : bf16
+ return %r : bf16
+ // CHECK: math.clampf {{.*}} : bf16
+ // CHECK-NOT: rocdl.fmed3
+}
+
>From f25ec273391ffbb791fc95e19e01b6882203bbd2 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 14 Oct 2025 07:57:39 -0700
Subject: [PATCH 07/12] Formatting lit test
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 959230ae6cd49..455f886839604 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -670,4 +670,3 @@ func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
// CHECK: math.clampf {{.*}} : bf16
// CHECK-NOT: rocdl.fmed3
}
-
>From 9b50b9d903f54030c0d82298062a845ffc742e4f Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 14 Oct 2025 08:15:02 -0700
Subject: [PATCH 08/12] Moved GFX9+ condition into
addChipsetDependentPatterns
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 38704157ba565..0fb670020b964 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -67,12 +67,6 @@ struct ClampFOpConversion final
op, "fmed3 only supports f16 and f32 types");
}
- // V_MED3_F16/F32 only exists in gfx9+ architectures
- if (chipset.majorVersion < 9) {
- return rewriter.notifyMatchFailure(
- op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
- "): V_MED_F16 / V_MED3_F32 not supported."));
- }
// Handle multi-dimensional vectors (converted to LLVM arrays)
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
// Handle multi-dimensional vectors (converted to LLVM arrays)
@@ -100,7 +94,10 @@ void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset) {
- patterns.add<ClampFOpConversion>(converter, chipset);
+ // V_MED3_F16/F32 only exists in gfx9+ architectures
+ if (chipset.majorVersion >= 9) {
+ patterns.add<ClampFOpConversion>(converter, chipset);
+ }
}
void mlir::populateMathToROCDLConversionPatterns(
>From 61af07cd931011944101f840087daa01bf8d2020 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 14 Oct 2025 10:05:47 -0700
Subject: [PATCH 09/12] Added a valid default value for the pass's chipset option
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/include/mlir/Conversion/Passes.td | 2 +-
mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 14 +++++++-------
2 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 33936da0190cc..a2eb335faac6c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -790,7 +790,7 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"vector::VectorDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
- /*default=*/"\"\"",
+ /*default=*/"\"gfx000\"",
"Chipset that these operations will run on">];
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 0fb670020b964..4ba7eab64a785 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -62,7 +62,7 @@ struct ClampFOpConversion final
opTy = vectorType.getElementType();
}
- if (!opTy.isF16() && !opTy.isF32()) {
+ if (!isa<Float16Type, Float32Type>(opTy)) {
return rewriter.notifyMatchFailure(
op, "fmed3 only supports f16 and f32 types");
}
@@ -74,9 +74,9 @@ struct ClampFOpConversion final
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename math::ClampFOp::Adaptor adaptor(operands);
- return rewriter.create<ROCDL::FMed3Op>(
- op.getLoc(), llvm1DVectorTy, adaptor.getValue(),
- adaptor.getMin(), adaptor.getMax());
+ return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getValue(), adaptor.getMin(),
+ adaptor.getMax());
},
rewriter);
}
@@ -90,9 +90,9 @@ struct ClampFOpConversion final
amdgpu::Chipset chipset;
};
-void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- amdgpu::Chipset chipset) {
+static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ amdgpu::Chipset chipset) {
// V_MED3_F16/F32 only exists in gfx9+ architectures
if (chipset.majorVersion >= 9) {
>From 0a3eab91a90e8639eb44ce87036315d7251aa201 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Wed, 15 Oct 2025 01:04:03 -0700
Subject: [PATCH 10/12] Removed populateMathToROCDLConversionPatterns from
condition block
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 4ba7eab64a785..11d9e1c8d0296 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -198,13 +198,14 @@ void ConvertMathToROCDLPass::runOnOperation() {
LLVMTypeConverter converter(ctx, options);
// Only populate chipset-dependent patterns if chipset is specified
+ FailureOr<amdgpu::Chipset> maybeChipset;
if (!chipset.empty()) {
- FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+ maybeChipset = amdgpu::Chipset::parse(chipset);
if (failed(maybeChipset)) {
return signalPassFailure();
}
- populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
}
+ populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
ConversionTarget target(getContext());
target
>From b73b30f8072fe34ec0bd660fd953a7427ce59036 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <31160700+keshavvinayak01 at users.noreply.github.com>
Date: Wed, 15 Oct 2025 13:40:15 +0530
Subject: [PATCH 11/12] Update MathToROCDL.cpp
---
mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 11d9e1c8d0296..a2bff066ef21a 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -95,6 +95,7 @@ static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
amdgpu::Chipset chipset) {
// V_MED3_F16/F32 only exists in gfx9+ architectures
+ // Only populate chipset-dependent patterns if chipset is specified
if (chipset.majorVersion >= 9) {
patterns.add<ClampFOpConversion>(converter, chipset);
}
@@ -197,7 +198,6 @@ void ConvertMathToROCDLPass::runOnOperation() {
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- // Only populate chipset-dependent patterns if chipset is specified
FailureOr<amdgpu::Chipset> maybeChipset;
if (!chipset.empty()) {
maybeChipset = amdgpu::Chipset::parse(chipset);
>From 684e19e4d9c007f2cbfe335d7c193ffd6f39aa0d Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 16 Oct 2025 02:15:19 -0700
Subject: [PATCH 12/12] Addressed comments; Removed addChipsetDependentPatterns;
Added required MLIRAMDGPUUtils to LINK_LIBS
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../mlir/Conversion/MathToROCDL/MathToROCDL.h | 8 ++--
mlir/include/mlir/Conversion/Passes.td | 2 +-
.../lib/Conversion/MathToROCDL/CMakeLists.txt | 1 +
.../Conversion/MathToROCDL/MathToROCDL.cpp | 40 ++++++-------------
4 files changed, 20 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 770f257d89bd5..9cf030975f203 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -20,9 +20,11 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- amdgpu::Chipset chipset);
+/// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt` or
+/// pre-gfx9, chipset-dependent patterns (e.g. math.clampf -> rocdl.fmed3) are skipped.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a2eb335faac6c..33936da0190cc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -790,7 +790,7 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"vector::VectorDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
- /*default=*/"\"gfx000\"",
+ /*default=*/"\"\"",
"Chipset that these operations will run on">];
}
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
index 2771955aa9493..8cc3fde827830 100644
--- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
Core
LINK_LIBS PUBLIC
+ MLIRAMDGPUUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 11d9e1c8d0296..6ff39b3505ab6 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -47,9 +47,6 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
struct ClampFOpConversion final
: public ConvertOpToLLVMPattern<math::ClampFOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- ClampFOpConversion(const LLVMTypeConverter &converter,
- amdgpu::Chipset chipset)
- : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
LogicalResult
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
@@ -58,18 +55,15 @@ struct ClampFOpConversion final
Type opTy = op.getType();
auto resultType = getTypeConverter()->convertType(opTy);
- if (auto vectorType = dyn_cast<VectorType>(opTy)) {
+ if (auto vectorType = dyn_cast<VectorType>(opTy))
opTy = vectorType.getElementType();
- }
- if (!isa<Float16Type, Float32Type>(opTy)) {
+ if (!isa<Float16Type, Float32Type>(opTy))
return rewriter.notifyMatchFailure(
op, "fmed3 only supports f16 and f32 types");
- }
// Handle multi-dimensional vectors (converted to LLVM arrays)
- if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
- // Handle multi-dimensional vectors (converted to LLVM arrays)
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
@@ -79,30 +73,17 @@ struct ClampFOpConversion final
adaptor.getMax());
},
rewriter);
- }
// Handle 1D vectors and scalars directly
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
op.getMin(), op.getMax());
return success();
}
-
- amdgpu::Chipset chipset;
};
-static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- amdgpu::Chipset chipset) {
-
- // V_MED3_F16/F32 only exists in gfx9+ architectures
- if (chipset.majorVersion >= 9) {
- patterns.add<ClampFOpConversion>(converter, chipset);
- }
-}
-
void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- amdgpu::Chipset chipset) {
+ std::optional<amdgpu::Chipset> chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -178,7 +159,11 @@ void mlir::populateMathToROCDLConversionPatterns(
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
- addChipsetDependentPatterns(converter, patterns, chipset);
+ if (chipset.has_value() && chipset->majorVersion >= 9) {
+ patterns.add<ClampFOpConversion>(converter);
+ } else {
+ LDBG() << "Chipset dependent patterns were not added";
+ }
}
struct ConvertMathToROCDLPass final
@@ -201,11 +186,12 @@ void ConvertMathToROCDLPass::runOnOperation() {
FailureOr<amdgpu::Chipset> maybeChipset;
if (!chipset.empty()) {
maybeChipset = amdgpu::Chipset::parse(chipset);
- if (failed(maybeChipset)) {
+ if (failed(maybeChipset))
return signalPassFailure();
- }
}
- populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
+ populateMathToROCDLConversionPatterns(converter, patterns,
+ succeeded(maybeChipset) ? *maybeChipset
+ : std::nullopt);
ConversionTarget target(getContext());
target
More information about the Mlir-commits
mailing list