[Mlir-commits] [mlir] 672b908 - [mlir] Add sin & cos ops to complex dialect
Goran Flegar
llvmlistbot at llvm.org
Tue May 3 10:37:12 PDT 2022
Author: Goran Flegar
Date: 2022-05-03T19:36:12+02:00
New Revision: 672b908bca672ad5b0ecffce653242f26a87cd20
URL: https://github.com/llvm/llvm-project/commit/672b908bca672ad5b0ecffce653242f26a87cd20
DIFF: https://github.com/llvm/llvm-project/commit/672b908bca672ad5b0ecffce653242f26a87cd20.diff
LOG: [mlir] Add sin & cos ops to complex dialect
Also adds conversions for those ops to math + arith.
Differential Revision: https://reviews.llvm.org/D124773
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
mlir/test/Dialect/Complex/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 8e176e3b2d922..b215d0de5bc5e 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -109,6 +109,26 @@ def ConstantOp : Complex_Op<"constant", [
}];
}
+//===----------------------------------------------------------------------===//
+// CosOp
+//===----------------------------------------------------------------------===//
+
+// Unary complex-dialect op for the cosine of a complex number. The
+// SameOperandsAndResultType trait in the trait list presumably requires the
+// result type to equal the operand type -- NOTE(review): confirm against the
+// MLIR trait documentation.
+def CosOp : ComplexUnaryOp<"cos", [SameOperandsAndResultType]> {
+ let summary = "computes cosine of a complex number";
+ let description = [{
+ The `cos` op takes a single complex number and computes the cosine of
+ it, i.e. `cos(x)`, where `x` is the input value.
+
+ Example:
+
+ ```mlir
+ %a = complex.cos %b : complex<f32>
+ ```
+ }];
+
+ // Single result, constrained to a complex type with a float element type.
+ let results = (outs Complex<AnyFloat>:$result);
+}
+
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
@@ -369,6 +389,26 @@ def SignOp : ComplexUnaryOp<"sign", [SameOperandsAndResultType]> {
let results = (outs Complex<AnyFloat>:$result);
}
+//===----------------------------------------------------------------------===//
+// SinOp
+//===----------------------------------------------------------------------===//
+
+// Unary complex-dialect op for the sine of a complex number. The
+// SameOperandsAndResultType trait in the trait list presumably requires the
+// result type to equal the operand type -- NOTE(review): confirm against the
+// MLIR trait documentation.
+def SinOp : ComplexUnaryOp<"sin", [SameOperandsAndResultType]> {
+ let summary = "computes sine of a complex number";
+ let description = [{
+ The `sin` op takes a single complex number and computes the sine of
+ it, i.e. `sin(x)`, where `x` is the input value.
+
+ Example:
+
+ ```mlir
+ %a = complex.sin %b : complex<f32>
+ ```
+ }];
+
+ // Single result, constrained to a complex type with a float element type.
+ let results = (outs Complex<AnyFloat>:$result);
+}
+
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4bfd5e257ac94..a676e0afe9a88 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -103,6 +103,71 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
}
};
+/// Common lowering skeleton for complex trigonometric ops. The operand is
+/// split into its real and imaginary parts, the real-valued terms shared by
+/// all such lowerings are materialized (0.5 * exp(im), 0.5 / exp(im), sin(re)
+/// and cos(re)), and the op-specific `combine` hook turns them into the real
+/// and imaginary parts of the result, which are packed with complex.create.
+template <typename TrigonometricOp>
+struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
+  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
+
+  using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value operand = adaptor.getComplex();
+    auto complexTy = operand.getType().cast<ComplexType>();
+    auto floatTy = complexTy.getElementType().cast<FloatType>();
+
+    // Decompose the input into its real and imaginary components.
+    Value re = rewriter.create<complex::ReOp>(loc, floatTy, operand);
+    Value im = rewriter.create<complex::ImOp>(loc, floatTy, operand);
+
+    // Building blocks shared by the trigonometric lowerings. They are created
+    // here once; the subclass only decides how to combine them.
+    Value oneHalf = rewriter.create<arith::ConstantOp>(
+        loc, floatTy, rewriter.getFloatAttr(floatTy, 0.5));
+    Value expIm = rewriter.create<math::ExpOp>(loc, im);
+    Value halfExp = rewriter.create<arith::MulFOp>(loc, oneHalf, expIm);
+    Value halfOverExp = rewriter.create<arith::DivFOp>(loc, oneHalf, expIm);
+    Value sinRe = rewriter.create<math::SinOp>(loc, re);
+    Value cosRe = rewriter.create<math::CosOp>(loc, re);
+
+    std::pair<Value, Value> parts =
+        combine(loc, halfExp, halfOverExp, sinRe, cosRe, rewriter);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexTy, parts.first,
+                                                   parts.second);
+    return success();
+  }
+
+  /// Op-specific hook: given 0.5 * exp(im) (`scaledExp`), 0.5 / exp(im)
+  /// (`reciprocalExp`), sin(re) and cos(re), returns the pair
+  /// {real part, imaginary part} of the lowered result.
+  virtual std::pair<Value, Value>
+  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
+          Value cos, ConversionPatternRewriter &rewriter) const = 0;
+};
+
+/// Lowers complex.cos by supplying the cosine-specific combination of the
+/// building blocks prepared by TrigonometricOpConversion::matchAndRewrite.
+struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
+  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
+
+  /// Returns {Re(cos(z)), Im(cos(z))} for z = x + iy. The base pattern passes
+  /// `scaledExp` = 0.5 * exp(y), `reciprocalExp` = 0.5 / exp(y),
+  /// `sin` = sin(x) and `cos` = cos(x).
+  std::pair<Value, Value>
+  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
+          Value cos, ConversionPatternRewriter &rewriter) const override {
+    // Complex cosine is defined as:
+    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
+    // Plugging in:
+    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
+    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
+    // and defining t := exp(y)
+    // We get:
+    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
+    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
+    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
+    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+    return {resultReal, resultImag};
+  }
+};
+
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
using OpConversionPattern<complex::DivOp>::OpConversionPattern;
@@ -588,6 +653,29 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
}
};
+/// Lowers complex.sin by supplying the sine-specific combination of the
+/// building blocks prepared by TrigonometricOpConversion::matchAndRewrite.
+struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
+  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
+
+  /// Returns {Re(sin(z)), Im(sin(z))} for z = x + iy. The base pattern passes
+  /// `scaledExp` = 0.5 * exp(y), `reciprocalExp` = 0.5 / exp(y),
+  /// `sin` = sin(x) and `cos` = cos(x).
+  std::pair<Value, Value>
+  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
+          Value cos, ConversionPatternRewriter &rewriter) const override {
+    // Complex sine is defined as:
+    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
+    // Plugging in:
+    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
+    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
+    // and defining t := exp(y)
+    // We get:
+    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
+    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
+    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
+    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+    return {resultReal, resultImag};
+  }
+};
+
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
@@ -627,13 +715,15 @@ void mlir::populateComplexToStandardConversionPatterns(
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
+ CosOpConversion,
DivOpConversion,
ExpOpConversion,
LogOpConversion,
Log1pOpConversion,
MulOpConversion,
NegOpConversion,
- SignOpConversion>(patterns.getContext());
+ SignOpConversion,
+ SinOpConversion>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 03e80fc67f9b1..8e7098f832bb8 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -29,6 +29,27 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// Lowering test for complex.cos on complex<f32>. The expectations below look
+// for complex.re / complex.im of the argument, a 0.5 constant, math.exp of
+// the imaginary part, math.sin / math.cos of the real part, and the arith
+// ops that combine them into the two operands of complex.create.
+// CHECK-LABEL: func @complex_cos
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
+ %cos = complex.cos %arg : complex<f32>
+ return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]]
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]]
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]]
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]]
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]]
+
// CHECK-LABEL: func @complex_div
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -358,6 +379,27 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
// CHECK: return %[[NOT_EQUAL]] : i1
+// Lowering test for complex.sin on complex<f32>. The expectations below look
+// for complex.re / complex.im of the argument, a 0.5 constant, math.exp of
+// the imaginary part, math.sin / math.cos of the real part, and the arith
+// ops that combine them into the two operands of complex.create.
+// CHECK-LABEL: func @complex_sin
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
+ %sin = complex.sin %arg : complex<f32>
+ return %sin : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]]
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]]
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]]
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]]
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]]
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]]
+
// CHECK-LABEL: func @complex_sign
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index a24d780d05687..6c2ed8bc9dfee 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -26,6 +26,9 @@ func.func @ops(%f: f32) {
// CHECK: complex.add %[[C]], %[[C]] : complex<f32>
%sum = complex.add %complex, %complex : complex<f32>
+ // CHECK: complex.cos %[[C]] : complex<f32>
+ %cos = complex.cos %complex : complex<f32>
+
// CHECK: complex.div %[[C]], %[[C]] : complex<f32>
%div = complex.div %complex, %complex : complex<f32>
@@ -53,6 +56,9 @@ func.func @ops(%f: f32) {
// CHECK: complex.sign %[[C]] : complex<f32>
%sign = complex.sign %complex : complex<f32>
+ // CHECK: complex.sin %[[C]] : complex<f32>
+ %sin = complex.sin %complex : complex<f32>
+
// CHECK: complex.sub %[[C]], %[[C]] : complex<f32>
%diff = complex.sub %complex, %complex : complex<f32>
return
More information about the Mlir-commits
mailing list