[Mlir-commits] [mlir] f112bd6 - [mlir] Add SignOp to complex dialect.

Adrian Kuegel llvmlistbot at llvm.org
Tue Jun 15 06:22:43 PDT 2021


Author: Adrian Kuegel
Date: 2021-06-15T15:22:31+02:00
New Revision: f112bd61ebf315b563fdd3dae947f0c67d02a6cc

URL: https://github.com/llvm/llvm-project/commit/f112bd61ebf315b563fdd3dae947f0c67d02a6cc
DIFF: https://github.com/llvm/llvm-project/commit/f112bd61ebf315b563fdd3dae947f0c67d02a6cc.diff

LOG: [mlir] Add SignOp to complex dialect.

Also add a pattern that converts complex.sign from the Complex dialect to the Standard/Math dialects.

Differential Revision: https://reviews.llvm.org/D104292

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 989df0e7d368..d533b5db6b41 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -280,6 +280,25 @@ def ReOp : ComplexUnaryOp<"re",
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SignOp
+//===----------------------------------------------------------------------===//
+
+def SignOp : ComplexUnaryOp<"sign", [SameOperandsAndResultType]> {
+  let summary = "computes sign of a complex number";
+  let description = [{
+    The `sign` op takes a single complex number `x` and computes its sign,
+    i.e. `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
+
+    Example:
+
+    ```mlir
+    %a = complex.sign %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 51aa671d995c..bcd9572a471f 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -335,15 +336,47 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
     return success();
   }
 };
+
+struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
+  using OpConversionPattern<complex::SignOp>::OpConversionPattern;
+  /// Emits `select(x == 0, x, x / |x|)`, testing each component against 0.0.
+  LogicalResult
+  matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::SignOp::Adaptor adaptor(operands);
+    Value input = adaptor.complex();
+    auto cplxType = input.getType().cast<ComplexType>();
+    auto fltType = cplxType.getElementType().cast<FloatType>();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value re = b.create<complex::ReOp>(fltType, input);
+    Value im = b.create<complex::ImOp>(fltType, input);
+    Value zero = b.create<ConstantOp>(fltType, b.getZeroAttr(fltType));
+    Value reIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, re, zero);
+    Value imIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, im, zero);
+    Value isZero = b.create<AndOp>(reIsZero, imIsZero);
+    Value norm = b.create<complex::AbsOp>(fltType, input);
+    Value unitRe = b.create<DivFOp>(re, norm);
+    Value unitIm = b.create<DivFOp>(im, norm);
+    Value unit = b.create<complex::CreateOp>(cplxType, unitRe, unitIm);
+    // For x == 0 the input is returned as-is, keeping each zero's sign bit.
+    rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, input, unit);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<AbsOpConversion,
-               ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
-               ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
-               DivOpConversion, ExpOpConversion, NegOpConversion>(
-      patterns.getContext());
+  // clang-format off
+  patterns.add<
+      AbsOpConversion,
+      ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
+      ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
+      DivOpConversion,
+      ExpOpConversion,
+      NegOpConversion,
+      SignOpConversion>(patterns.getContext());
+  // clang-format on
 }
 
 namespace {
@@ -363,7 +396,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp, complex::NegOp>();
+                      complex::ExpOp, complex::NotEqualOp, complex::NegOp,
+                      complex::SignOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index fe75575b59cf..62edf1e74a4a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -181,3 +181,27 @@ func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
+
+// CHECK-LABEL: func @complex_sign
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_sign(%arg: complex<f32>) -> complex<f32> {
+  %sign = complex.sign %arg: complex<f32>
+  return %sign : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IS_ZERO:.*]] = and %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL2]], %[[REAL2]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG2]], %[[IMAG2]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = divf %[[REAL]], %[[NORM]] : f32
+// CHECK: %[[IMAG_SIGN:.*]] = divf %[[IMAG]], %[[NORM]] : f32
+// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>


        


More information about the Mlir-commits mailing list