[Mlir-commits] [mlir] 8fa2e67 - [mlir][complex] complex.arg op to calculate the angle of complex number

Alexander Belyaev llvmlistbot at llvm.org
Mon Jun 27 07:46:01 PDT 2022


Author: Lewuathe
Date: 2022-06-27T16:45:41+02:00
New Revision: 8fa2e67979e56db3cc511ff1af920b4fa02fb473

URL: https://github.com/llvm/llvm-project/commit/8fa2e67979e56db3cc511ff1af920b4fa02fb473
DIFF: https://github.com/llvm/llvm-project/commit/8fa2e67979e56db3cc511ff1af920b4fa02fb473.diff

LOG: [mlir][complex] complex.arg op to calculate the angle of complex number

Add complex.arg op which calculates the angle of a complex number (in the diff below the op is spelled `complex.angle`). The op name is inspired by the function carg in libm.

See: https://sourceware.org/newlib/libm.html#carg

Differential Revision: https://reviews.llvm.org/D128531

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 21797d32a22d..f98037c9a515 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -584,4 +584,25 @@ def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> {
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// AngleOp
+//===----------------------------------------------------------------------===//
+
+def AngleOp : ComplexUnaryOp<"angle",
+                           [TypesMatchWith<"complex element type matches result type",
+                                           "complex", "result",
+                                           "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "computes argument value of a complex number";
+  let description = [{
+    Computes the argument (angle) of a complex number, with a branch cut
+    along the negative real axis, as a value of the operand's element type.
+
+    Example:
+    ```mlir
+    %a = complex.angle %b : complex<f32>
+    ```
+  }];
+  let results = (outs AnyFloat:$result);
+}
+
 #endif // COMPLEX_OPS

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0a5124ada7a4..b104826b757f 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -1009,6 +1009,26 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
   }
 };
 
+/// Lowers `complex.angle` to `math.atan2(im(z), re(z))`.
+struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
+  using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The op's result type is the operand's element type (TypesMatchWith).
+    Value z = adaptor.getComplex();
+    auto elemType = op.getType();
+    auto location = op.getLoc();
+    Value realPart = rewriter.create<complex::ReOp>(location, elemType, z);
+    Value imagPart = rewriter.create<complex::ImOp>(location, elemType, z);
+
+    // atan2 takes (y, x): imaginary part first, real part second.
+    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imagPart, realPart);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -1016,6 +1036,7 @@ void mlir::populateComplexToStandardConversionPatterns(
   // clang-format off
   patterns.add<
       AbsOpConversion,
+      AngleOpConversion,
       Atan2OpConversion,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5b37899075a4..9aff4ecc80e4 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -694,3 +694,16 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
   %rsqrt = complex.rsqrt %arg : complex<f32>
   return %rsqrt : complex<f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_angle
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32
+// CHECK: return %[[RESULT]] : f32
+func.func @complex_angle(%z: complex<f32>) -> f32 {
+  %theta = complex.angle %z : complex<f32>  // expected: atan2(im, re)
+  return %theta : f32
+}

diff  --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index e2df73ed3c9b..a7e166906f4c 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -82,6 +82,27 @@ func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
   func.return %pow : complex<f32>
 }
 
+func.func @test_element(%input: tensor<?xcomplex<f32>>,
+                        %func: (complex<f32>) -> f32) {
+  // Apply %func to every element of %input and print each f32 result.
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %len = tensor.dim %input, %zero : tensor<?xcomplex<f32>>
+  scf.for %idx = %zero to %len step %one {
+    %z = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
+    %r = func.call_indirect %func(%z) : (complex<f32>) -> f32
+    vector.print %r : f32
+    scf.yield
+  }
+
+  func.return
+}
+
+func.func @angle(%z: complex<f32>) -> f32 {
+  %theta = complex.angle %z : complex<f32>  // arg(z), radians
+  func.return %theta : f32
+}
+
 func.func @entry() {
   // complex.sqrt test
   %sqrt_test = arith.constant dense<[
@@ -251,6 +272,30 @@ func.func @entry() {
   %conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
   call @test_unary(%conj_test_cast, %conj_func)
     : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
-    
+
+  // complex.angle test
+  %angle_test = arith.constant dense<[
+    (-1.0, -1.0),
+    // CHECK:      -2.356
+    (-1.0, 1.0),
+    // CHECK-NEXT:  2.356
+    (0.0, 0.0),
+    // CHECK-NEXT:  0
+    (0.0, 1.0),
+    // CHECK-NEXT:  1.570
+    (1.0, -1.0),
+    // CHECK-NEXT:  -0.785
+    (1.0, 0.0),
+    // CHECK-NEXT:  0
+    (1.0, 1.0)
+    // CHECK-NEXT:  0.785
+  ]> : tensor<7xcomplex<f32>>
+  %angle_test_cast = tensor.cast %angle_test
+    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %angle_func = func.constant @angle : (complex<f32>) -> f32
+  call @test_element(%angle_test_cast, %angle_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
+
   func.return
 }


        


More information about the Mlir-commits mailing list