[Mlir-commits] [mlir] 263d2fc - Fix math.cbrt with vector and f16 arguments.

Johannes Reifferscheid llvmlistbot at llvm.org
Tue Jan 10 12:32:18 PST 2023


Author: Johannes Reifferscheid
Date: 2023-01-10T21:32:12+01:00
New Revision: 263d2fce557acaf82f53b377d0d4c0de16630d8b

URL: https://github.com/llvm/llvm-project/commit/263d2fce557acaf82f53b377d0d4c0de16630d8b
DIFF: https://github.com/llvm/llvm-project/commit/263d2fce557acaf82f53b377d0d4c0de16630d8b.diff

LOG: Fix math.cbrt with vector and f16 arguments.

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D141421

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 8a8adb5924666..c48686e877eac 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -154,19 +154,20 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
 void mlir::populateMathToLibmConversionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     llvm::Optional<PatternBenefit> log1pBenefit) {
-  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
-               VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
-               VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
-               VecOpToScalarOp<math::RoundEvenOp>,
+  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
+               VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
+               VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
+               VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
                VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
                VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
       patterns.getContext(), benefit);
-  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
-               PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
-               PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
-               PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
-               PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>,
-               PromoteOpToF32<math::TruncOp>>(patterns.getContext(), benefit);
+  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
+               PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
+               PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
+               PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
+               PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
+               PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(
+      patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
                                                  "atan", benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index b0459d8bfcead..eb375059df1e7 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -246,13 +246,22 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
 // CHECK-LABEL: func @cbrt_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
-func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64)  {
-  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
+func.func @cbrt_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16,
+                       %float_vec: vector<2xf32>) -> (f32, f64, f16, bf16, vector<2xf32>)  {
+  // CHECK: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
   %float_result = math.cbrt %float : f32
-  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
+  // CHECK: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
   %double_result = math.cbrt %double : f64
+  // Just check that these lower successfully:
+  // CHECK: call @cbrtf
+  %half_result = math.cbrt %half : f16
+  // CHECK: call @cbrtf
+  %bfloat_result = math.cbrt %bfloat : bf16
+  // CHECK: call @cbrtf
+  %vec_result = math.cbrt %float_vec : vector<2xf32>
   // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
-  return %float_result, %double_result : f32, f64
+  return %float_result, %double_result, %half_result, %bfloat_result, %vec_result
+    : f32, f64, f16, bf16, vector<2xf32>
 }
 
 // CHECK-LABEL: func @cos_caller


        


More information about the Mlir-commits mailing list