[Mlir-commits] [mlir] f112bd6 - [mlir] Add SignOp to complex dialect.
Adrian Kuegel
llvmlistbot at llvm.org
Tue Jun 15 06:22:43 PDT 2021
Author: Adrian Kuegel
Date: 2021-06-15T15:22:31+02:00
New Revision: f112bd61ebf315b563fdd3dae947f0c67d02a6cc
URL: https://github.com/llvm/llvm-project/commit/f112bd61ebf315b563fdd3dae947f0c67d02a6cc
DIFF: https://github.com/llvm/llvm-project/commit/f112bd61ebf315b563fdd3dae947f0c67d02a6cc.diff
LOG: [mlir] Add SignOp to complex dialect.
Also add a conversion pattern from Complex Dialect to Standard/Math Dialect.
Differential Revision: https://reviews.llvm.org/D104292
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 989df0e7d368..d533b5db6b41 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -280,6 +280,25 @@ def ReOp : ComplexUnaryOp<"re",
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SignOp
+//===----------------------------------------------------------------------===//
+
+// Unary op mapping a complex value x to x / |x| (the complex number of
+// modulus 1 pointing in the same direction), or to zero when x is zero.
+// `SameOperandsAndResultType` requires the result type to equal the operand
+// type, e.g. complex<f32> -> complex<f32>.
+def SignOp : ComplexUnaryOp<"sign", [SameOperandsAndResultType]> {
+ let summary = "computes sign of a complex number";
+ let description = [{
+ The `sign` op takes a single complex number and computes the sign of
+ it, i.e. `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
+
+ Example:
+
+ ```mlir
+ %a = complex.sign %b : complex<f32>
+ ```
+ }];
+
+ // NOTE(review): `results` is redeclared here, presumably because the
+ // ComplexUnaryOp base class (defined outside this view) declares a different
+ // result type; confirm against its definition.
+ let results = (outs Complex<AnyFloat>:$result);
+}
//===----------------------------------------------------------------------===//
// SubOp
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 51aa671d995c..bcd9572a471f 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -335,15 +336,47 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
return success();
}
};
+
+/// Lowers `complex.sign` into component-wise arithmetic.
+///
+/// For an input z = a + bi this emits IR computing (a / |z|) + (b / |z|)i,
+/// with |z| obtained from a `complex.abs` op. When both components compare
+/// equal to 0.0, a `select` forwards the input unchanged instead of the
+/// quotient, so a zero input never goes through a 0 / 0 division.
+struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
+  using OpConversionPattern<complex::SignOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::SignOp::Adaptor adaptor(operands);
+    Value input = adaptor.complex();
+    auto complexTy = input.getType().cast<ComplexType>();
+    auto floatTy = complexTy.getElementType().cast<FloatType>();
+    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+    // Split the operand into components and test each one against 0.0.
+    Value re = builder.create<complex::ReOp>(floatTy, input);
+    Value im = builder.create<complex::ImOp>(floatTy, input);
+    Value zeroConst =
+        builder.create<ConstantOp>(floatTy, builder.getZeroAttr(floatTy));
+    Value reEqZero = builder.create<CmpFOp>(CmpFPredicate::OEQ, re, zeroConst);
+    Value imEqZero = builder.create<CmpFOp>(CmpFPredicate::OEQ, im, zeroConst);
+    Value inputIsZero = builder.create<AndOp>(reEqZero, imEqZero);
+
+    // Scale both components by the reciprocal of the modulus.
+    Value modulus = builder.create<complex::AbsOp>(floatTy, input);
+    Value unitRe = builder.create<DivFOp>(re, modulus);
+    Value unitIm = builder.create<DivFOp>(im, modulus);
+    Value unit = builder.create<complex::CreateOp>(complexTy, unitRe, unitIm);
+
+    // A zero input is returned as-is rather than divided by a zero modulus.
+    rewriter.replaceOpWithNewOp<SelectOp>(op, inputIsZero, input, unit);
+    return success();
+  }
+};
} // namespace
+/// Registers the complex-dialect lowering patterns listed below (abs, eq/neq
+/// comparisons, div, exp, neg, sign) into `patterns`.
void mlir::populateComplexToStandardConversionPatterns(
RewritePatternSet &patterns) {
- patterns.add<AbsOpConversion,
- ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
- ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
- DivOpConversion, ExpOpConversion, NegOpConversion>(
- patterns.getContext());
+ // One pattern per line; the clang-format markers keep the formatter from
+ // re-packing this list.
+ // clang-format off
+ patterns.add<
+ AbsOpConversion,
+ ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
+ ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
+ DivOpConversion,
+ ExpOpConversion,
+ NegOpConversion,
+ SignOpConversion>(patterns.getContext());
+ // clang-format on
}
namespace {
@@ -363,7 +396,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
- complex::ExpOp, complex::NotEqualOp, complex::NegOp>();
+ complex::ExpOp, complex::NotEqualOp, complex::NegOp,
+ complex::SignOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index fe75575b59cf..62edf1e74a4a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -181,3 +181,27 @@ func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
// CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
// CHECK: return %[[NOT_EQUAL]] : i1
+
+// Lowering of complex.sign: components are compared against 0.0, the input is
+// divided by its modulus (itself expanded from complex.abs), and a select
+// returns the input unchanged when it is zero.
+// CHECK-LABEL: func @complex_sign
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_sign(%arg: complex<f32>) -> complex<f32> {
+  %sign = complex.sign %arg : complex<f32>
+  return %sign : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IS_ZERO:.*]] = and %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL2]], %[[REAL2]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG2]], %[[IMAG2]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = divf %[[REAL]], %[[NORM]] : f32
+// CHECK: %[[IMAG_SIGN:.*]] = divf %[[IMAG]], %[[NORM]] : f32
+// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
More information about the Mlir-commits mailing list