[Mlir-commits] [mlir] 263d2fc - Fix math.cbrt with vector and f16 arguments.
Johannes Reifferscheid
llvmlistbot at llvm.org
Tue Jan 10 12:32:18 PST 2023
Author: Johannes Reifferscheid
Date: 2023-01-10T21:32:12+01:00
New Revision: 263d2fce557acaf82f53b377d0d4c0de16630d8b
URL: https://github.com/llvm/llvm-project/commit/263d2fce557acaf82f53b377d0d4c0de16630d8b
DIFF: https://github.com/llvm/llvm-project/commit/263d2fce557acaf82f53b377d0d4c0de16630d8b.diff
LOG: Fix math.cbrt with vector and f16 arguments.
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D141421
Added:
Modified:
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 8a8adb5924666..c48686e877eac 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -154,19 +154,20 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
void mlir::populateMathToLibmConversionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
llvm::Optional<PatternBenefit> log1pBenefit) {
- patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
- VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
- VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
- VecOpToScalarOp<math::RoundEvenOp>,
+ patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
+ VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
+ VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
+ VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
patterns.getContext(), benefit);
- patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
- PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
- PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
- PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
- PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>,
- PromoteOpToF32<math::TruncOp>>(patterns.getContext(), benefit);
+ patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
+ PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
+ PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
+ PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
+ PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
+ PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(
+ patterns.getContext(), benefit);
patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
"atan", benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index b0459d8bfcead..eb375059df1e7 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -246,13 +246,22 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK-LABEL: func @cbrt_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
-func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64) {
- // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
+func.func @cbrt_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16,
+ %float_vec: vector<2xf32>) -> (f32, f64, f16, bf16, vector<2xf32>) {
+ // CHECK: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
%float_result = math.cbrt %float : f32
- // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
+ // CHECK: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
%double_result = math.cbrt %double : f64
+ // Just check that these lower successfully:
+ // CHECK: call @cbrtf
+ %half_result = math.cbrt %half : f16
+ // CHECK: call @cbrtf
+ %bfloat_result = math.cbrt %bfloat : bf16
+ // CHECK: call @cbrtf
+ %vec_result = math.cbrt %float_vec : vector<2xf32>
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
- return %float_result, %double_result : f32, f64
+ return %float_result, %double_result, %half_result, %bfloat_result, %vec_result
+ : f32, f64, f16, bf16, vector<2xf32>
}
// CHECK-LABEL: func @cos_caller
More information about the Mlir-commits mailing list