[Mlir-commits] [mlir] 998a3a3 - Add a math.cbrt instruction and lowering to libm.

Johannes Reifferscheid llvmlistbot at llvm.org
Mon Jan 2 23:44:19 PST 2023


Author: Johannes Reifferscheid
Date: 2023-01-03T08:44:12+01:00
New Revision: 998a3a38948c9d220ddc759b8a6eee987e3ad320

URL: https://github.com/llvm/llvm-project/commit/998a3a38948c9d220ddc759b8a6eee987e3ad320
DIFF: https://github.com/llvm/llvm-project/commit/998a3a38948c9d220ddc759b8a6eee987e3ad320.diff

LOG: Add a math.cbrt op and its lowering to libm.

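There is currently no way to compute accurate cube roots in the math dialect.
Emulating them with powf(x, 1/3.0) is too inaccurate in some cases: 1/3 is not
exactly representable in floating point, and pow is a domain error (NaN) for a
negative base with a non-integer exponent, whereas cbrt(-8) == -2.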

Reviewed By: akuegel

Differential Revision: https://reviews.llvm.org/D140842

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
    mlir/test/Dialect/Math/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 3f2a8d7cb464..f8e9fd601304 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -196,6 +196,28 @@ def Math_Atan2Op : Math_FloatBinaryOp<"atan2">{
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// CbrtOp
+//===----------------------------------------------------------------------===//
+
+def Math_CbrtOp : Math_FloatUnaryOp<"cbrt"> {
+  let summary = "cube root of the specified value";
+  let description = [{
+    The `cbrt` operation computes the cube root. It takes one operand of
+    floating point type (i.e., scalar, tensor or vector) and returns one result
+    of the same type. It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar cube root value.
+    %a = math.cbrt %b : f64
+    ```
+
+    Note: This op is not equivalent to powf(..., 1/3.0).
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CeilOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index d40666d6608c..8a8adb592466 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -171,6 +171,8 @@ void mlir::populateMathToLibmConversionPatterns(
                                                  "atan", benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
                                                   "atan2f", "atan2", benefit);
+  patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(patterns.getContext(), "cbrtf",
+                                                 "cbrt", benefit);
   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
                                                 "erf", benefit);
   patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index d911f8b1b8fb..b0459d8bfcea 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -8,6 +8,8 @@
 // CHECK-DAG: @expm1f(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @atan2(f64, f64) -> f64 attributes {llvm.readnone}
 // CHECK-DAG: @atan2f(f32, f32) -> f32 attributes {llvm.readnone}
+// CHECK-DAG: @cbrt(f64) -> f64 attributes {llvm.readnone}
+// CHECK-DAG: @cbrtf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @tan(f64) -> f64 attributes {llvm.readnone}
 // CHECK-DAG: @tanf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @tanh(f64) -> f64 attributes {llvm.readnone}
@@ -241,6 +243,18 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
   return %float_result, %double_result : f32, f64
 }
 
+// CHECK-LABEL: func @cbrt_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.cbrt %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.cbrt %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
 // CHECK-LABEL: func @cos_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 0f744f52d1c9..7e45d9bc6f74 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -26,6 +26,18 @@ func.func @atan2(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @cbrt(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @cbrt(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.cbrt %[[F]] : f32
+  %0 = math.cbrt %f : f32
+  // CHECK: %{{.*}} = math.cbrt %[[V]] : vector<4xf32>
+  %1 = math.cbrt %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.cbrt %[[T]] : tensor<4x4x?xf32>
+  %2 = math.cbrt %t : tensor<4x4x?xf32>
+  return
+}
+
 // CHECK-LABEL: func @cos(
 // CHECK-SAME:            %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
 func.func @cos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {


        


More information about the Mlir-commits mailing list