[Mlir-commits] [mlir] d0cb0d3 - [mlir] Add Complex dialect.

Alexander Belyaev llvmlistbot at llvm.org
Fri Jan 15 11:02:54 PST 2021


Author: Alexander Belyaev
Date: 2021-01-15T19:58:10+01:00
New Revision: d0cb0d30a431578ecedb98c57780154789f3c594

URL: https://github.com/llvm/llvm-project/commit/d0cb0d30a431578ecedb98c57780154789f3c594
DIFF: https://github.com/llvm/llvm-project/commit/d0cb0d30a431578ecedb98c57780154789f3c594.diff

LOG: [mlir] Add Complex dialect.

Differential Revision: https://reviews.llvm.org/D94764

Added: 
    mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
    mlir/include/mlir/Dialect/Complex/CMakeLists.txt
    mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Complex/IR/Complex.h
    mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt
    mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
    mlir/lib/Dialect/Complex/CMakeLists.txt
    mlir/lib/Dialect/Complex/IR/CMakeLists.txt
    mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/Complex/ops.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h
    mlir/lib/Dialect/CMakeLists.txt
    mlir/test/mlir-opt/commandline.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
new file mode 100644
index 000000000000..3dab2a136b28
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
@@ -0,0 +1,29 @@
+//===- ComplexToLLVM.h - Utils to convert from the complex dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
+#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+template <typename T>
+class OperationPass;
+
+/// Populate `patterns` with lowerings of complex dialect ops to LLVM dialect.
+void populateComplexToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Create a module pass lowering Complex (and Std) ops to the LLVM dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertComplexToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2e07a795b6c7..121dae6f46f8 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e8ca058adedd..aa228784e48a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -88,6 +88,16 @@ def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToLLVM : Pass<"convert-complex-to-llvm", "ModuleOp"> {
+  let summary = "Convert Complex dialect to LLVM dialect";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+  let constructor = "mlir::createConvertComplexToLLVMPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // GPUCommon
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 51b423ee3b98..df0c751b6ba6 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Async)
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
 add_subdirectory(AVX512)
+add_subdirectory(Complex)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)

diff  --git a/mlir/include/mlir/Dialect/Complex/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/CMakeLists.txt
new file mode 100644
index 000000000000..f33061b2d87c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
new file mode 100644
index 000000000000..9fd6d4206b29
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(ComplexOps complex)
+add_mlir_doc(ComplexOps -gen-dialect-doc ComplexOps Dialects/)

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
new file mode 100644
index 000000000000..e55c099c6eb2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
@@ -0,0 +1,32 @@
+//===- Complex.h - Complex dialect --------------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
+#define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h" // NoSideEffect trait
+#include "mlir/Interfaces/VectorInterfaces.h" // VectorUnrollOpInterface
+
+//===----------------------------------------------------------------------===//
+// Complex Dialect (class generated by TableGen from Complex_Dialect)
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Complex Dialect Operations (TableGen-generated from ComplexOps.td)
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Complex/IR/ComplexOps.h.inc"
+
+#endif // MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
new file mode 100644
index 000000000000..ea398a07addb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
@@ -0,0 +1,23 @@
+//===- ComplexBase.td - Base definitions for complex dialect -*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef COMPLEX_BASE
+#define COMPLEX_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Complex_Dialect : Dialect {
+  let name = "complex";
+  let cppNamespace = "::mlir::complex";
+  let description = [{
+    The complex dialect holds ops that create complex numbers, extract their
+    real and imaginary parts, and perform arithmetic on them.
+  }];
+}
+
+#endif // COMPLEX_BASE

diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
new file mode 100644
index 000000000000..a4329df7c1aa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -0,0 +1,153 @@
+//===- ComplexOps.td - Complex op definitions ----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef COMPLEX_OPS
+#define COMPLEX_OPS
+
+include "mlir/Dialect/Complex/IR/ComplexBase.td"
+include "mlir/Interfaces/VectorInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class Complex_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Complex_Dialect, mnemonic, traits>;
+
+// Shared definition of binary arithmetic ops: two floating-point complex
+// operands and one result, all of one type. Op-specific `traits` come first,
+// followed by the common traits listed below.
+class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []>
+    : Complex_Op<mnemonic, !listconcat(traits, [
+          NoSideEffect, SameOperandsAndResultType,
+          DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+          ElementwiseMappable])> {
+  let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
+  let results = (outs Complex<AnyFloat>:$result);
+  let assemblyFormat = "$lhs `,` $rhs  attr-dict `:` type($result)";
+  // No hand-written verifier: ODS type constraints and traits suffice.
+  let verifier = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+def AddOp : ComplexArithmeticOp<"add"> {
+  let summary = "complex addition";
+  let description = [{
+    The `add` operation takes two complex numbers and returns their sum.
+
+    Example:
+
+    ```mlir
+    %a = complex.add %b, %c : complex<f32>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// CreateOp
+//===----------------------------------------------------------------------===//
+
+def CreateOp : Complex_Op<"create",
+    [NoSideEffect,
+     AllTypesMatch<["real", "imaginary"]>,
+     TypesMatchWith<"complex element type matches real operand type",
+                    "complex", "real",
+                    "$_self.cast<ComplexType>().getElementType()">,
+     TypesMatchWith<"complex element type matches imaginary operand type",
+                    "complex", "imaginary",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+
+  let summary = "complex number creation operation";
+  let description = [{
+    The `complex.create` operation creates a complex number from two
+    floating-point operands, the real and the imaginary part.
+
+    Example:
+
+    ```mlir
+    %a = complex.create %b, %c : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary);
+  let results = (outs Complex<AnyFloat>:$complex);
+
+  let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
+}
+
+//===----------------------------------------------------------------------===//
+// ImOp
+//===----------------------------------------------------------------------===//
+
+def ImOp : Complex_Op<"im",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "imaginary",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "extracts the imaginary part of a complex number";
+  let description = [{
+    The `im` op takes a single complex number and extracts the imaginary part.
+
+    Example:
+
+    ```mlir
+    %a = complex.im %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$imaginary);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
+//===----------------------------------------------------------------------===//
+// ReOp
+//===----------------------------------------------------------------------===//
+
+def ReOp : Complex_Op<"re",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "real",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "extracts the real part of a complex number";
+  let description = [{
+    The `re` op takes a single complex number and extracts the real part.
+
+    Example:
+
+    ```mlir
+    %a = complex.re %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$real);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
+
+//===----------------------------------------------------------------------===//
+// SubOp
+//===----------------------------------------------------------------------===//
+
+def SubOp : ComplexArithmeticOp<"sub"> {
+  let summary = "complex subtraction";
+  let description = [{
+    The `sub` operation takes two complex numbers and returns their difference.
+
+    Example:
+
+    ```mlir
+    %a = complex.sub %b, %c : complex<f32>
+    ```
+  }];
+}
+
+#endif // COMPLEX_OPS

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0367f9de2d18..7fd063d11464 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
@@ -52,6 +53,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   arm_neon::ArmNeonDialect,
                   async::AsyncDialect,
                   avx512::AVX512Dialect,
+                  complex::ComplexDialect,
                   gpu::GPUDialect,
                   LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9fc7a40d2d55..6ba8d415e30b 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
 add_subdirectory(ArmNeonToLLVM)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(AVX512ToLLVM)
+add_subdirectory(ComplexToLLVM)
 add_subdirectory(GPUCommon)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)

diff  --git a/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000..8c170dac2f27
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRComplexToLLVM
+  ComplexToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen  # generated pass base classes from Passes.td
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRComplex               # the complex dialect ops being lowered
+  MLIRLLVMIR
+  MLIRStandardOpsTransforms
+  MLIRStandardToLLVM        # LLVMTypeConverter and Std->LLVM patterns
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
new file mode 100644
index 000000000000..1096b08da1f1
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -0,0 +1,193 @@
+//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+/// Lowers `complex.create` by inserting both operands into a fresh struct.
+struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
+  using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::CreateOp::Adaptor adaptor(operands);
+    Location opLoc = complexOp.getLoc();
+
+    // Start from an undef struct of the converted type, then set both fields.
+    Type llvmStructTy = typeConverter->convertType(complexOp.getType());
+    auto packed = ComplexStructBuilder::undef(rewriter, opLoc, llvmStructTy);
+    packed.setReal(rewriter, opLoc, adaptor.real());
+    packed.setImaginary(rewriter, opLoc, adaptor.imaginary());
+    rewriter.replaceOp(complexOp, {packed});
+    return success();
+  }
+};
+
+struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
+  using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
+
+  /// Replaces `complex.re` by reading the real field of the lowered operand
+  /// struct; no new struct is built.
+  LogicalResult
+  matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::ReOp::Adaptor adaptor(operands);
+    ComplexStructBuilder loweredOperand(adaptor.complex());
+
+    Value realPart = loweredOperand.real(rewriter, op.getLoc());
+    rewriter.replaceOp(op, realPart);
+    return success();
+  }
+};
+
+struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
+  using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
+
+  /// Replaces `complex.im` by reading the imaginary field of the lowered
+  /// operand struct; no new struct is built.
+  LogicalResult
+  matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::ImOp::Adaptor adaptor(operands);
+    ComplexStructBuilder loweredOperand(adaptor.complex());
+
+    Value imagPart = loweredOperand.imaginary(rewriter, op.getLoc());
+    rewriter.replaceOp(op, imagPart);
+    return success();
+  }
+};
+
+/// {real, imag} halves of both operands of a binary complex op.
+struct BinaryComplexOperands {
+  std::complex<Value> lhs, rhs; // NOTE(review): non-float T is unspecified
+};
+
+/// Extracts real/imaginary parts of the converted `lhs`/`rhs` operands.
+template <typename OpTy>
+BinaryComplexOperands
+unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
+                            ConversionPatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  typename OpTy::Adaptor adaptor(operands);
+  // Reads both fields of one lowered operand (an LLVM complex struct).
+  auto split = [&](Value lowered) {
+    ComplexStructBuilder fields(lowered);
+    std::complex<Value> parts;
+    parts.real(fields.real(rewriter, loc));
+    parts.imag(fields.imaginary(rewriter, loc));
+    return parts;
+  };
+  // Braced-init elements are evaluated in order, so lhs is extracted first.
+  return {split(adaptor.lhs()), split(adaptor.rhs())};
+}
+
+struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
+  using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites `complex.add` as two `llvm.fadd`s (real + real, imag + imag)
+  /// whose results are packed into a new struct of the converted type.
+  LogicalResult
+  matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    BinaryComplexOperands halves =
+        unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
+
+    // Result starts as an undef struct of the lowered complex type.
+    Type llvmTy = typeConverter->convertType(op.getType());
+    auto sum = ComplexStructBuilder::undef(rewriter, loc, llvmTy);
+    // Add component-wise with an empty fast-math flag set.
+    auto noFastMath = LLVM::FMFAttr::get({}, op.getContext());
+    Value re = rewriter.create<LLVM::FAddOp>(loc, halves.lhs.real(),
+                                             halves.rhs.real(), noFastMath);
+    Value im = rewriter.create<LLVM::FAddOp>(loc, halves.lhs.imag(),
+                                             halves.rhs.imag(), noFastMath);
+    sum.setReal(rewriter, loc, re);
+    sum.setImaginary(rewriter, loc, im);
+    rewriter.replaceOp(op, {sum});
+    return success();
+  }
+};
+
+struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
+  using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites `complex.sub` as two `llvm.fsub`s (real - real, imag - imag)
+  /// whose results are packed into a new struct of the converted type.
+  LogicalResult
+  matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    BinaryComplexOperands halves =
+        unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
+
+    // Result starts as an undef struct of the lowered complex type.
+    Type llvmTy = typeConverter->convertType(op.getType());
+    auto difference = ComplexStructBuilder::undef(rewriter, loc, llvmTy);
+    // Subtract component-wise with an empty fast-math flag set.
+    auto noFastMath = LLVM::FMFAttr::get({}, op.getContext());
+    Value re = rewriter.create<LLVM::FSubOp>(loc, halves.lhs.real(),
+                                             halves.rhs.real(), noFastMath);
+    Value im = rewriter.create<LLVM::FSubOp>(loc, halves.lhs.imag(),
+                                             halves.rhs.imag(), noFastMath);
+    difference.setReal(rewriter, loc, re);
+    difference.setImaginary(rewriter, loc, im);
+    rewriter.replaceOp(op, {difference});
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateComplexToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // Register one lowering per complex op; every pattern is constructed with
+  // the same `converter`, which maps complex<T> to an LLVM struct type.
+  // clang-format off
+  patterns.insert<AddOpConversion,
+                  CreateOpConversion,
+                  ImOpConversion,
+                  ReOpConversion,
+                  SubOpConversion>(converter);
+  // clang-format on
+}
+
+namespace {
+/// Module pass lowering complex ops, together with std ops, to LLVM dialect.
+struct ConvertComplexToLLVMPass
+    : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    LLVMTypeConverter typeConverter(ctx);
+
+    // Std patterns are registered as well so that the std ops surrounding
+    // the complex ones are lowered by the same full conversion.
+    OwningRewritePatternList patterns;
+    populateStdToLLVMConversionPatterns(typeConverter, patterns);
+    populateComplexToLLVMConversionPatterns(typeConverter, patterns);
+    LLVMConversionTarget target(*ctx);
+    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+    if (failed(
+            applyFullConversion(getOperation(), target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertComplexToLLVMPass() {
+  return std::make_unique<ConvertComplexToLLVMPass>();
+}

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index ecd932f99c78..c0e1791dc59b 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -19,6 +19,10 @@ class StandardOpsDialect;
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace complex {
+class ComplexDialect;
+} // end namespace complex
+
 namespace gpu {
 class GPUDialect;
 class GPUModuleOp;

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index ae9afdc70552..295d9356e497 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
 add_subdirectory(Async)
 add_subdirectory(AVX512)
+add_subdirectory(Complex)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)

diff  --git a/mlir/lib/Dialect/Complex/CMakeLists.txt b/mlir/lib/Dialect/Complex/CMakeLists.txt
new file mode 100644
index 000000000000..f33061b2d87c
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
new file mode 100644
index 000000000000..dc8aa658174d
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRComplex
+  ComplexOps.cpp
+  ComplexDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Complex
+
+  DEPENDS
+  MLIRComplexOpsIncGen  # TableGen'd ComplexOps*.inc from add_mlir_dialect
+
+  LINK_LIBS PUBLIC  # NOTE(review): ops also use side-effect/vector interfaces
+  MLIRDialect
+  MLIRIR
+  )

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
new file mode 100644
index 000000000000..44330361e95d
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -0,0 +1,16 @@
+//===- ComplexDialect.cpp - MLIR Complex Dialect --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+
+void mlir::complex::ComplexDialect::initialize() {
+  addOperations< // Registers every op in the TableGen-generated GET_OP_LIST.
+#define GET_OP_LIST
+#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
+      >();
+}

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
new file mode 100644
index 000000000000..6b4855dc4339
--- /dev/null
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -0,0 +1,19 @@
+//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/Complex.h"
+
+using namespace mlir;
+using namespace mlir::complex;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions for the op classes declared in Complex.h
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"

diff  --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
new file mode 100644
index 000000000000..fde21df8abf3
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s
+
+// CHECK-LABEL: llvm.func @complex_numbers()
+// CHECK-NEXT:    %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32
+// CHECK-NEXT:    %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32
+// CHECK-NEXT:    %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2]][0] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    llvm.return
+func @complex_numbers() {
+  %real0 = constant 1.2 : f32
+  %imag0 = constant 3.4 : f32
+  %cplx2 = complex.create %real0, %imag0 : complex<f32>
+  %real1 = complex.re %cplx2 : complex<f32>
+  %imag1 = complex.im %cplx2 : complex<f32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-DAG:     %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[SUM0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[SUM_RE:.*]] = llvm.fadd %[[LHS_RE]], %[[RHS_RE]] : f64
+// CHECK-DAG:     %[[SUM_IM:.*]] = llvm.fadd %[[LHS_IM]], %[[RHS_IM]] : f64
+// CHECK:         %[[SUM1:.*]] = llvm.insertvalue %[[SUM_RE]], %[[SUM0]][0] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[SUM2:.*]] = llvm.insertvalue %[[SUM_IM]], %[[SUM1]][1] : !llvm.struct<(f64, f64)>
+func @complex_addition() {
+  %lhs_re = constant 1.2 : f64
+  %lhs_im = constant 3.4 : f64
+  %lhs = complex.create %lhs_re, %lhs_im : complex<f64>
+  %rhs_re = constant 5.6 : f64
+  %rhs_im = constant 7.8 : f64
+  %rhs = complex.create %rhs_re, %rhs_im : complex<f64>
+  %sum = complex.add %lhs, %rhs : complex<f64>
+  return
+}
+
+// CHECK-LABEL: llvm.func @complex_subtraction()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : f64
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : f64
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
+func @complex_subtraction() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = complex.create %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = complex.create %b_re, %b_im : complex<f64>
+  %c = complex.sub %a, %b : complex<f64>
+  return
+}

diff  --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
new file mode 100644
index 000000000000..152e8704c5ff
--- /dev/null
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// CHECK-LABEL: func @ops(
+// CHECK-SAME:            [[F:%.*]]: f32) {
+func @ops(%f: f32) {
+  // CHECK: [[C:%.*]] = complex.create [[F]], [[F]] : complex<f32>
+  %complex = complex.create %f, %f : complex<f32>
+
+  // CHECK: complex.re [[C]] : complex<f32>
+  %real = complex.re %complex : complex<f32>
+
+  // CHECK: complex.im [[C]] : complex<f32>
+  %imag = complex.im %complex : complex<f32>
+
+  // CHECK: complex.add [[C]], [[C]] : complex<f32>
+  %sum = complex.add %complex, %complex : complex<f32>
+
+  // CHECK: complex.sub [[C]], [[C]] : complex<f32>
+  %diff = complex.sub %complex, %complex : complex<f32>
+  return
+}
+

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 94eb94483790..bde2de3cd985 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -6,6 +6,7 @@
 // CHECK-NEXT: arm_sve
 // CHECK-NEXT: async
 // CHECK-NEXT: avx512
+// CHECK-NEXT: complex
 // CHECK-NEXT: gpu
 // CHECK-NEXT: linalg
 // CHECK-NEXT: llvm


        


More information about the Mlir-commits mailing list