[Mlir-commits] [mlir] 8fa2e67 - [mlir][complex] complex.arg op to calculate the angle of complex number
Alexander Belyaev
llvmlistbot at llvm.org
Mon Jun 27 07:46:01 PDT 2022
Author: Lewuathe
Date: 2022-06-27T16:45:41+02:00
New Revision: 8fa2e67979e56db3cc511ff1af920b4fa02fb473
URL: https://github.com/llvm/llvm-project/commit/8fa2e67979e56db3cc511ff1af920b4fa02fb473
DIFF: https://github.com/llvm/llvm-project/commit/8fa2e67979e56db3cc511ff1af920b4fa02fb473.diff
LOG: [mlir][complex] complex.arg op to calculate the angle of complex number
Add a complex op which calculates the angle (argument) of a complex number; in the diff below the op is named `complex.angle`. The op is inspired by the function carg in libm.
See: https://sourceware.org/newlib/libm.html#carg
Differential Revision: https://reviews.llvm.org/D128531
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 21797d32a22d..f98037c9a515 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -584,4 +584,25 @@ def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> {
let results = (outs Complex<AnyFloat>:$result);
}
+//===----------------------------------------------------------------------===//
+// AngleOp
+//===----------------------------------------------------------------------===//
+
+def AngleOp : ComplexUnaryOp<"angle",
+    [TypesMatchWith<"complex element type matches result type",
+                    "complex", "result",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "computes argument value of a complex number";
+  let description = [{
+    The `angle` op takes a single complex number and computes its argument
+    value with a branch cut along the negative real axis. The result is a
+    floating-point value of the complex operand's element type, computed as
+    `atan2(im(z), re(z))`, and therefore lies in the range [-pi, pi].
+
+    Example:
+
+    ```mlir
+    %a = complex.angle %b : complex<f32>
+    ```
+  }];
+  let results = (outs AnyFloat:$result);
+}
+
#endif // COMPLEX_OPS
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0a5124ada7a4..b104826b757f 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -1009,6 +1009,26 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
}
};
+/// Lowers `complex.angle` to `math.atan2(im(z), re(z))`.
+struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
+  using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value operand = adaptor.getComplex();
+    // The op's TypesMatchWith constraint makes its result type equal to the
+    // operand's element type, so it is also the type of re(z) and im(z).
+    Type elementType = op.getType();
+
+    Value re = rewriter.create<complex::ReOp>(loc, elementType, operand);
+    Value im = rewriter.create<complex::ImOp>(loc, elementType, operand);
+    // atan2(y, x): the imaginary part is y and the real part is x.
+    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, im, re);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateComplexToStandardConversionPatterns(
@@ -1016,6 +1036,7 @@ void mlir::populateComplexToStandardConversionPatterns(
// clang-format off
patterns.add<
AbsOpConversion,
+ AngleOpConversion,
Atan2OpConversion,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5b37899075a4..9aff4ecc80e4 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -694,3 +694,16 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
%rsqrt = complex.rsqrt %arg : complex<f32>
return %rsqrt : complex<f32>
}
+
+// -----
+
+// complex.angle is expected to lower to atan2(im, re) on the element type.
+// CHECK-LABEL: func.func @complex_angle
+// CHECK-SAME:    %[[ARG:.*]]: complex<f32>
+// CHECK:         %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK:         %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK:         %[[ANGLE:.*]] = math.atan2 %[[IM]], %[[RE]] : f32
+// CHECK:         return %[[ANGLE]] : f32
+func.func @complex_angle(%arg: complex<f32>) -> f32 {
+  %angle = complex.angle %arg : complex<f32>
+  return %angle : f32
+}
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index e2df73ed3c9b..a7e166906f4c 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -82,6 +82,27 @@ func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
func.return %pow : complex<f32>
}
+// Applies %func to each element of the 1-D complex tensor %input and prints
+// every f32 result with vector.print, one element per loop iteration.
+func.func @test_element(%input: tensor<?xcomplex<f32>>,
+                        %func: (complex<f32>) -> f32) {
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %len = tensor.dim %input, %zero : tensor<?xcomplex<f32>>
+  scf.for %idx = %zero to %len step %one {
+    %z = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
+    %res = func.call_indirect %func(%z) : (complex<f32>) -> f32
+    vector.print %res : f32
+    scf.yield
+  }
+  func.return
+}
+
+// Wraps complex.angle in a function so it can be passed to @test_element
+// as a func.constant.
+func.func @angle(%arg: complex<f32>) -> f32 {
+  %theta = complex.angle %arg : complex<f32>
+  func.return %theta : f32
+}
+
func.func @entry() {
// complex.sqrt test
%sqrt_test = arith.constant dense<[
@@ -251,6 +272,30 @@ func.func @entry() {
%conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
call @test_unary(%conj_test_cast, %conj_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
-
+
+ // complex.angle test
+ %angle_test = arith.constant dense<[
+ (-1.0, -1.0),
+ // CHECK: -2.356
+ (-1.0, 1.0),
+ // CHECK-NEXT: 2.356
+ (0.0, 0.0),
+ // CHECK-NEXT: 0
+ (0.0, 1.0),
+ // CHECK-NEXT: 1.570
+ (1.0, -1.0),
+ // CHECK-NEXT: -0.785
+ (1.0, 0.0),
+ // CHECK-NEXT: 0
+ (1.0, 1.0)
+ // CHECK-NEXT: 0.785
+ ]> : tensor<7xcomplex<f32>>
+ %angle_test_cast = tensor.cast %angle_test
+ : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+ %angle_func = func.constant @angle : (complex<f32>) -> f32
+ call @test_element(%angle_test_cast, %angle_func)
+ : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
+
func.return
}
More information about the Mlir-commits
mailing list