[Mlir-commits] [mlir] 672b908 - [mlir] Add sin & cos ops to complex dialect

Goran Flegar llvmlistbot at llvm.org
Tue May 3 10:37:12 PDT 2022


Author: Goran Flegar
Date: 2022-05-03T19:36:12+02:00
New Revision: 672b908bca672ad5b0ecffce653242f26a87cd20

URL: https://github.com/llvm/llvm-project/commit/672b908bca672ad5b0ecffce653242f26a87cd20
DIFF: https://github.com/llvm/llvm-project/commit/672b908bca672ad5b0ecffce653242f26a87cd20.diff

LOG: [mlir] Add sin & cos ops to complex dialect

Also adds conversions for those ops to math + arith.

Differential Revision: https://reviews.llvm.org/D124773

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Dialect/Complex/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 8e176e3b2d922..b215d0de5bc5e 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -109,6 +109,26 @@ def ConstantOp : Complex_Op<"constant", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// CosOp
+//===----------------------------------------------------------------------===//
+
+def CosOp : ComplexUnaryOp<"cos", [SameOperandsAndResultType]> {
+  let summary = "computes cosine of a complex number";
+  let description = [{
+    The `cos` op takes a single complex number and computes the cosine of
+    it, i.e. `cos(x)`, where `x` is the input value.
+
+    Example:
+
+    ```mlir
+    %a = complex.cos %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);  // same type as the input
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
@@ -369,6 +389,26 @@ def SignOp : ComplexUnaryOp<"sign", [SameOperandsAndResultType]> {
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// SinOp
+//===----------------------------------------------------------------------===//
+
+def SinOp : ComplexUnaryOp<"sin", [SameOperandsAndResultType]> {
+  let summary = "computes sine of a complex number";
+  let description = [{
+    The `sin` op takes a single complex number and computes the sine of
+    it, i.e. `sin(x)`, where `x` is the input value.
+
+    Example:
+
+    ```mlir
+    %a = complex.sin %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);  // same type as the input
+}
+
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4bfd5e257ac94..a676e0afe9a88 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -103,6 +103,71 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
   }
 };
 
+template <typename TrigonometricOp>
+struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
+  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
+
+  using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto type = adaptor.getComplex().getType().template cast<ComplexType>();
+    auto elementType = type.getElementType().template cast<FloatType>();
+
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+
+    // Building blocks shared by sin/cos, with t := exp(Im(z)): 0.5*t, 0.5/t,
+    // sin(Re(z)), cos(Re(z)). Subclass `combine` turns them into (Re, Im).
+    // NOTE(review): t may overflow/underflow for large |Im(z)| -> inf/NaN.
+    Value half = rewriter.create<arith::ConstantOp>(
+        loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
+    Value exp = rewriter.create<math::ExpOp>(loc, imag);
+    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
+    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
+    Value sin = rewriter.create<math::SinOp>(loc, real);
+    Value cos = rewriter.create<math::CosOp>(loc, real);
+
+    auto resultPair =
+        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
+                                                   resultPair.second);
+    return success();
+  }
+
+  virtual std::pair<Value, Value>
+  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
+          Value cos, ConversionPatternRewriter &rewriter) const = 0;
+};
+
+struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
+  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
+
+  std::pair<Value, Value>
+  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
+          Value cos, ConversionPatternRewriter &rewriter) const override {
+    // Complex cosine is defined as:
+    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
+    // Plugging in:
+    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
+    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
+    // and defining t := exp(y)
+    // We get:
+    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
+    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
+    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
+    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+    return {resultReal, resultImag};
+  }
+};
+
 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
 
@@ -588,6 +653,29 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
   }
 };
 
+struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
+  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
+
+  std::pair<Value, Value>
+  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
+          Value cos, ConversionPatternRewriter &rewriter) const override {
+    // Complex sine is defined as:
+    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
+    // Plugging in:
+    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
+    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
+    // and defining t := exp(y)
+    // We get:
+    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
+    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
+    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
+    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+    return {resultReal, resultImag};
+  }
+};
+
 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
@@ -627,13 +715,15 @@ void mlir::populateComplexToStandardConversionPatterns(
       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
+      CosOpConversion,
       DivOpConversion,
       ExpOpConversion,
       LogOpConversion,
       Log1pOpConversion,
       MulOpConversion,
       NegOpConversion,
-      SignOpConversion>(patterns.getContext());
+      SignOpConversion,
+      SinOpConversion>(patterns.getContext());
   // clang-format on
 }
 

diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 03e80fc67f9b1..8e7098f832bb8 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -29,6 +29,27 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_cos
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
+  %cos = complex.cos %arg : complex<f32>  // lowered to math + arith (below)
+  return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]]
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]]
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]]
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]]
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK:     return %[[RESULT]]
+
 // CHECK-LABEL: func @complex_div
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -358,6 +379,27 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
 
+// CHECK-LABEL: func @complex_sin
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
+  %sin = complex.sin %arg : complex<f32>  // lowered to math + arith (below)
+  return %sin : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]]
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]]
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]]
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]]
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK:     return %[[RESULT]]
+
 // CHECK-LABEL: func @complex_sign
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {

diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index a24d780d05687..6c2ed8bc9dfee 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -26,6 +26,9 @@ func.func @ops(%f: f32) {
   // CHECK: complex.add %[[C]], %[[C]] : complex<f32>
   %sum = complex.add %complex, %complex : complex<f32>
 
+  // CHECK: complex.cos %[[C]] : complex<f32>
+  %cos = complex.cos %complex : complex<f32>
+
   // CHECK: complex.div %[[C]], %[[C]] : complex<f32>
   %div = complex.div %complex, %complex : complex<f32>
 
@@ -53,6 +56,9 @@ func.func @ops(%f: f32) {
   // CHECK: complex.sign %[[C]] : complex<f32>
   %sign = complex.sign %complex : complex<f32>
 
+  // CHECK: complex.sin %[[C]] : complex<f32>
+  %sin = complex.sin %complex : complex<f32>
+
   // CHECK: complex.sub %[[C]], %[[C]] : complex<f32>
   %
diff  = complex.sub %complex, %complex : complex<f32>
   return


        


More information about the Mlir-commits mailing list