[Mlir-commits] [mlir] [mlir] Use custom mlir::Complex type for non-float complex numbers (PR #191821)

David Truby llvmlistbot at llvm.org
Mon Apr 13 07:20:52 PDT 2026


https://github.com/DavidTruby created https://github.com/llvm/llvm-project/pull/191821

Instantiating std::complex<T> for a T where std::is_floating_point<T> is false has unspecified effects according to the C++ standard, and triggers warnings when building with the Microsoft STL (MSVC STL). This patch fixes those warnings by introducing mlir::Complex<T>, an alias template that resolves to std::complex<T> when T satisfies std::is_floating_point and to a custom mlir::NonFloatComplex<T> type otherwise.

The std::complex implementation from libc++ has been used as a guide for implementing the custom type.

Fixes #65255

>From 945ac07a4fcd05564af892595215b00e0e2a79c0 Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Tue, 7 Apr 2026 14:47:43 +0100
Subject: [PATCH] [mlir] Use custom mlir::Complex type for non-float complex
 numbers

Instantiating std::complex<T> for a T where std::is_floating_point<T>
is false has unspecified effects according to the C++ standard, and
triggers warnings when building with the Microsoft STL (MSVC STL). This
patch fixes those warnings by introducing mlir::Complex<T>, an alias
template that resolves to std::complex<T> when T satisfies
std::is_floating_point and to a custom mlir::NonFloatComplex<T> type
otherwise.

The std::complex implementation from libc++ has been used as a guide for
implementing the custom type.

Fixes #65255
---
 .../Dialect/Complex/IR/ComplexAttributes.td   |   4 +-
 .../mlir/ExecutionEngine/SparseTensor/File.h  |   4 +
 mlir/include/mlir/IR/BuiltinAttributes.h      |  26 +-
 mlir/include/mlir/IR/BuiltinAttributes.td     |  30 +-
 mlir/include/mlir/Support/Complex.h           | 262 ++++++++++++++++++
 mlir/lib/AsmParser/AttributeParser.cpp        |   4 +-
 .../ComplexToLLVM/ComplexToLLVM.cpp           |   4 +-
 mlir/lib/IR/AsmPrinter.cpp                    |   4 +-
 mlir/lib/IR/BuiltinAttributes.cpp             |  16 +-
 mlir/unittests/IR/AttributeTest.cpp           |   8 +-
 mlir/unittests/Support/CMakeLists.txt         |   1 +
 mlir/unittests/Support/ComplexTest.cpp        | 251 +++++++++++++++++
 12 files changed, 567 insertions(+), 47 deletions(-)
 create mode 100644 mlir/include/mlir/Support/Complex.h
 create mode 100644 mlir/unittests/Support/ComplexTest.cpp

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
index 52fd824f65e74..a47dc9927dd8c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
@@ -57,8 +57,8 @@ def Complex_NumberAttr : Complex_Attr<"Number", "number",
   ];
 
   let extraClassDeclaration = [{
-    std::complex<APFloat> getValue() {
-      return std::complex<APFloat>(getReal(), getImag());
+    mlir::Complex<APFloat> getValue() {
+      return mlir::Complex<APFloat>(getReal(), getImag());
     }
   }];
 
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 714e664dd0f4e..7e2190dc28084 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -22,6 +22,7 @@
 
 #include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
 #include "mlir/ExecutionEngine/SparseTensor/Storage.h"
+#include "mlir/Support/Complex.h"
 
 #include <fstream>
 
@@ -36,6 +37,9 @@ struct is_complex final : public std::false_type {};
 template <typename T>
 struct is_complex<std::complex<T>> final : public std::true_type {};
 
+template <typename T>
+struct is_complex<mlir::NonFloatComplex<T>> final : public std::true_type {};
+
 /// Returns an element-value of non-complex type.  If `IsPattern` is true,
 /// then returns an arbitrary value.  If `IsPattern` is false, then
 /// reads the value from the current line buffer beginning at `linePtr`.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 1f805882db276..c7eddf44fb29b 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -10,9 +10,9 @@
 #define MLIR_IR_BUILTINATTRIBUTES_H
 
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/Support/Complex.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
-#include <complex>
 #include <optional>
 
 namespace mlir {
@@ -75,6 +75,8 @@ template <typename T>
 struct is_complex_t : public std::false_type {};
 template <typename T>
 struct is_complex_t<std::complex<T>> : public std::true_type {};
+template <typename T>
+struct is_complex_t<mlir::NonFloatComplex<T>> : public std::true_type {};
 } // namespace detail
 
 /// An attribute that represents a reference to a dense vector or tensor
@@ -167,7 +169,7 @@ class DenseElementsAttr : public Attribute {
   /// element type of 'type'. 'type' must be a vector or tensor with static
   /// shape.
   static DenseElementsAttr get(ShapedType type,
-                               ArrayRef<std::complex<APInt>> values);
+                               ArrayRef<mlir::Complex<APInt>> values);
 
   /// Constructs a dense float elements attribute from an array of APFloat
   /// values. Each APFloat value is expected to have the same bitwidth as the
@@ -180,7 +182,7 @@ class DenseElementsAttr : public Attribute {
   /// element type of 'type'. 'type' must be a vector or tensor with static
   /// shape.
   static DenseElementsAttr get(ShapedType type,
-                               ArrayRef<std::complex<APFloat>> values);
+                               ArrayRef<mlir::Complex<APFloat>> values);
 
   /// Construct a dense elements attribute for an initializer_list of values.
   /// Each value is expected to be the same bitwidth of the element type of
@@ -298,11 +300,11 @@ class DenseElementsAttr : public Attribute {
   /// values.
   class ComplexIntElementIterator
       : public detail::DenseElementIndexedIteratorImpl<
-            ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
-            std::complex<APInt>> {
+            ComplexIntElementIterator, mlir::Complex<APInt>,
+            mlir::Complex<APInt>, mlir::Complex<APInt>> {
   public:
-    /// Accesses the raw std::complex<APInt> value at this iterator position.
-    std::complex<APInt> operator*() const;
+    /// Accesses the raw mlir::Complex<APInt> value at this iterator position.
+    mlir::Complex<APInt> operator*() const;
 
   private:
     friend DenseElementsAttr;
@@ -339,10 +341,10 @@ class DenseElementsAttr : public Attribute {
   class ComplexFloatElementIterator final
       : public llvm::mapped_iterator_base<ComplexFloatElementIterator,
                                           ComplexIntElementIterator,
-                                          std::complex<APFloat>> {
+                                          mlir::Complex<APFloat>> {
   public:
     /// Map the element to the iterator result type.
-    std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
+    mlir::Complex<APFloat> mapElement(const mlir::Complex<APInt> &value) const {
       return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
     }
 
@@ -442,7 +444,7 @@ class DenseElementsAttr : public Attribute {
         ElementIterator<T>(rawData, splat, getNumElements()));
   }
 
-  /// Try to get the held element values as a range of std::complex.
+  /// Try to get the held element values as a range of mlir::Complex.
   template <typename T, typename ElementT>
   using ComplexValueTemplateCheckT =
       std::enable_if_t<detail::is_complex_t<T>::value &&
@@ -545,7 +547,7 @@ class DenseElementsAttr : public Attribute {
   /// element type of this attribute must be a complex of integer type.
   template <typename T>
   using ComplexAPIntValueTemplateCheckT =
-      std::enable_if_t<std::is_same<T, std::complex<APInt>>::value>;
+      std::enable_if_t<std::is_same<T, mlir::Complex<APInt>>::value>;
   template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
   FailureOr<iterator_range_impl<ComplexIntElementIterator>>
   tryGetValues() const {
@@ -566,7 +568,7 @@ class DenseElementsAttr : public Attribute {
   /// element type of this attribute must be a complex of float type.
   template <typename T>
   using ComplexAPFloatValueTemplateCheckT =
-      std::enable_if_t<std::is_same<T, std::complex<APFloat>>::value>;
+      std::enable_if_t<std::is_same<T, mlir::Complex<APFloat>>::value>;
   template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
   FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
   tryGetValues() const {
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 299200788136a..6165a24c0d34f 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -296,19 +296,19 @@ def Builtin_DenseTypedElementsAttr : Builtin_Attr<
       uint8_t, uint16_t, uint32_t, uint64_t,
       int8_t, int16_t, int32_t, int64_t,
       short, unsigned short, int, unsigned, long, unsigned long,
-      std::complex<uint8_t>, std::complex<uint16_t>, std::complex<uint32_t>,
-      std::complex<uint64_t>,
-      std::complex<int8_t>, std::complex<int16_t>, std::complex<int32_t>,
-      std::complex<int64_t>,
+      mlir::Complex<uint8_t>, mlir::Complex<uint16_t>, mlir::Complex<uint32_t>,
+      mlir::Complex<uint64_t>,
+      mlir::Complex<int8_t>, mlir::Complex<int16_t>, mlir::Complex<int32_t>,
+      mlir::Complex<int64_t>,
       // Float types.
-      float, double, std::complex<float>, std::complex<double>
+      float, double, mlir::Complex<float>, mlir::Complex<double>
     >;
     using NonContiguousIterableTypesT = std::tuple<
       Attribute,
       // Integer types.
-      APInt, bool, std::complex<APInt>,
+      APInt, bool, mlir::Complex<APInt>,
       // Float types.
-      APFloat, std::complex<APFloat>
+      APFloat, mlir::Complex<APFloat>
     >;
 
     /// Provide a `try_value_begin_impl` to enable iteration within
@@ -931,12 +931,12 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
       APInt, bool, uint8_t, uint16_t, uint32_t, uint64_t,
       int8_t, int16_t, int32_t, int64_t,
       short, unsigned short, int, unsigned, long, unsigned long,
-      std::complex<APInt>, std::complex<uint8_t>, std::complex<uint16_t>,
-      std::complex<uint32_t>, std::complex<uint64_t>, std::complex<int8_t>,
-      std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>,
+      mlir::Complex<APInt>, mlir::Complex<uint8_t>, mlir::Complex<uint16_t>,
+      mlir::Complex<uint32_t>, mlir::Complex<uint64_t>, mlir::Complex<int8_t>,
+      mlir::Complex<int16_t>, mlir::Complex<int32_t>, mlir::Complex<int64_t>,
       // Float types.
       APFloat, float, double,
-      std::complex<APFloat>, std::complex<float>, std::complex<double>,
+      mlir::Complex<APFloat>, mlir::Complex<float>, mlir::Complex<double>,
       // String types.
       StringRef
     >;
@@ -978,7 +978,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
       return getZeroAPInt();
     }
     template <typename T>
-    std::enable_if_t<std::is_same<std::complex<APInt>, T>::value, T>
+    std::enable_if_t<std::is_same<mlir::Complex<APInt>, T>::value, T>
     getZeroValue() const {
       APInt intZero = getZeroAPInt();
       return {intZero, intZero};
@@ -990,7 +990,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
       return getZeroAPFloat();
     }
     template <typename T>
-    std::enable_if_t<std::is_same<std::complex<APFloat>, T>::value, T>
+    std::enable_if_t<std::is_same<mlir::Complex<APFloat>, T>::value, T>
     getZeroValue() const {
       APFloat floatZero = getZeroAPFloat();
       return {floatZero, floatZero};
@@ -1002,8 +1002,8 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
                          DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
                          std::is_same<T, StringRef>::value ||
                          (detail::is_complex_t<T>::value &&
-                          !llvm::is_one_of<T, std::complex<APInt>,
-                                           std::complex<APFloat>>::value),
+                          !llvm::is_one_of<T, mlir::Complex<APInt>,
+                                           mlir::Complex<APFloat>>::value),
                      T>
     getZeroValue() const {
       return T();
diff --git a/mlir/include/mlir/Support/Complex.h b/mlir/include/mlir/Support/Complex.h
new file mode 100644
index 0000000000000..970c9dc03cf61
--- /dev/null
+++ b/mlir/include/mlir/Support/Complex.h
@@ -0,0 +1,298 @@
+//===- Complex.h - Complex Number type for use in MLIR ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines mlir::Complex<T>, an alias for std::complex<T> when T is a
+// floating-point type and for mlir::NonFloatComplex<T> otherwise. The effect of
+// instantiating std::complex with any other element type is unspecified by the
+// C++ standard, so element types such as integers, APInt, APFloat or Value use
+// the NonFloatComplex class defined here instead.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_COMPLEX_H
+#define MLIR_SUPPORT_COMPLEX_H
+
+#include <complex>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+
+namespace detail {
+// NonFloatComplex's converting constructors should be implicit iff the
+// conversion from the source element type to the destination element type is
+// implicit and does not narrow. That is the case when the destination type is
+// copy-list-initializable from the source type, so define a helper trait to
+// detect it.
+// NOLINTBEGIN
+template <typename From, typename To>
+auto test_copy_list_initializable(int)
+    -> decltype(void(std::declval<To &>() = {std::declval<From &>()}),
+                std::true_type{});
+
+template <typename, typename>
+auto test_copy_list_initializable(...) -> std::false_type;
+
+template <typename From, typename To>
+struct is_copy_list_initializable
+    : std::bool_constant<
+          decltype(detail::test_copy_list_initializable<From, To>(0))::value> {
+};
+
+template <typename From, typename To>
+inline constexpr bool is_copy_list_initializable_v =
+    is_copy_list_initializable<From, To>::value;
+// NOLINTEND
+
+// Constraints selecting the implicit or the explicit converting constructors,
+// used as defaulted `int` template parameters.
+template <typename From, typename To>
+using EnableIfImplicitConversion =
+    std::enable_if_t<is_copy_list_initializable_v<From, To>, int>;
+template <typename From, typename To>
+using EnableIfExplicitConversion =
+    std::enable_if_t<!is_copy_list_initializable_v<From, To>, int>;
+} // namespace detail
+
+/// A complex number whose real and imaginary parts have type T. The C++
+/// standard only specifies std::complex for floating-point element types, so
+/// this class provides a similar interface for all other element types; use the
+/// mlir::Complex alias to choose between the two. Arithmetic uses the textbook
+/// formulas with no special handling of overflow, and dividing by a zero
+/// complex value divides by a zero T.
+template <typename T>
+class NonFloatComplex {
+public:
+  using value_type = T;
+
+private:
+  T re;
+  T im;
+
+public:
+  /// Constructs `re + im*i`; omitted parts default to T{}.
+  constexpr NonFloatComplex(const T &re = T{}, const T &im = T{})
+      : re(re), im(im) {}
+
+  constexpr NonFloatComplex(const NonFloatComplex &other) = default;
+
+  // Converting constructors from complex numbers with another element type.
+  // NonFloatComplex<U> is a distinct class whose private members are not
+  // accessible here, so its parts are read through the public accessors.
+  template <typename U, detail::EnableIfImplicitConversion<U, T> = 0>
+  constexpr NonFloatComplex(const NonFloatComplex<U> &other)
+      : re{other.real()}, im{other.imag()} {}
+
+  template <typename U, detail::EnableIfExplicitConversion<U, T> = 0>
+  constexpr explicit NonFloatComplex(const NonFloatComplex<U> &other)
+      : re(other.real()), im(other.imag()) {}
+
+  template <typename U, detail::EnableIfImplicitConversion<U, T> = 0>
+  constexpr NonFloatComplex(const std::complex<U> &other)
+      : re{other.real()}, im{other.imag()} {}
+
+  template <typename U, detail::EnableIfExplicitConversion<U, T> = 0>
+  constexpr explicit NonFloatComplex(const std::complex<U> &other)
+      : re(other.real()), im(other.imag()) {}
+
+  [[nodiscard]] constexpr T real() const { return re; }
+  constexpr void real(T value) { re = value; }
+  [[nodiscard]] constexpr T imag() const { return im; }
+  constexpr void imag(T value) { im = value; }
+
+  constexpr NonFloatComplex &operator=(const NonFloatComplex &other) = default;
+
+  /// Assigns a real number, resetting the imaginary part to T{}.
+  constexpr NonFloatComplex &operator=(const T &real) {
+    re = real;
+    im = T{};
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator+=(const T &real) {
+    re += real;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator-=(const T &real) {
+    re -= real;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator*=(const T &real) {
+    re *= real;
+    im *= real;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator/=(const T &real) {
+    re /= real;
+    im /= real;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator+=(const NonFloatComplex &other) {
+    re += other.re;
+    im += other.im;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator-=(const NonFloatComplex &other) {
+    re -= other.re;
+    im -= other.im;
+    return *this;
+  }
+
+  // Delegate to the binary operators below, which read all four parts before
+  // writing, so `x *= x` and `x /= x` are safe.
+  constexpr NonFloatComplex &operator*=(const NonFloatComplex &other) {
+    *this = *this * other;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator/=(const NonFloatComplex &other) {
+    *this = *this / other;
+    return *this;
+  }
+
+  template <typename U>
+  constexpr NonFloatComplex &operator=(const std::complex<U> &other) {
+    re = other.real();
+    im = other.imag();
+    return *this;
+  }
+};
+
+// Binary arithmetic. The right-hand operand can be a NonFloatComplex<T> or any
+// value accepted by the corresponding compound-assignment operator, e.g. a T.
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator+(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t += y;
+  return t;
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator-(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t -= y;
+  return t;
+}
+
+/// (a + bi)(c + di) = (ac - bd) + (ad + bc)i
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator*(const NonFloatComplex<T> &x, const NonFloatComplex<T> &y) {
+  T a = x.real();
+  T b = x.imag();
+  T c = y.real();
+  T d = y.imag();
+
+  return {(a * c) - (b * d), (a * d) + (b * c)};
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator*(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t *= y;
+  return t;
+}
+
+/// (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2), with each
+/// part divided using T's operator/ (so built-in integer types truncate).
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator/(const NonFloatComplex<T> &x, const NonFloatComplex<T> &y) {
+  T a = x.real();
+  T b = x.imag();
+  T c = y.real();
+  T d = y.imag();
+
+  T denom = c * c + d * d;
+  return {(a * c + b * d) / denom, (b * c - a * d) / denom};
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator/(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t /= y;
+  return t;
+}
+
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator+(const NonFloatComplex<T> &x) {
+  return x;
+}
+
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator-(const NonFloatComplex<T> &x) {
+  return {-x.real(), -x.imag()};
+}
+
+// Equality compares both parts. When only one operand is a NonFloatComplex, the
+// other (e.g. a T or a std::complex) is first converted to that type.
+template <typename T>
+[[nodiscard]] constexpr bool operator==(const NonFloatComplex<T> &x,
+                                        const NonFloatComplex<T> &y) {
+  return x.real() == y.real() && x.imag() == y.imag();
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator==(const NonFloatComplex<T> &x,
+                                        const U &y) {
+  return x == NonFloatComplex<T>{y};
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator==(const T &x,
+                                        const NonFloatComplex<U> &y) {
+  return NonFloatComplex<U>{x} == y;
+}
+
+template <typename T>
+[[nodiscard]] constexpr bool operator!=(const NonFloatComplex<T> &x,
+                                        const NonFloatComplex<T> &y) {
+  return !(x == y);
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator!=(const NonFloatComplex<T> &x,
+                                        const U &y) {
+  return !(x == y);
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator!=(const U &x,
+                                        const NonFloatComplex<T> &y) {
+  return !(y == x);
+}
+
+template <typename T>
+[[nodiscard]] constexpr T real(const NonFloatComplex<T> &x) {
+  return x.real();
+}
+
+template <typename T>
+[[nodiscard]] constexpr T imag(const NonFloatComplex<T> &x) {
+  return x.imag();
+}
+
+/// The complex-number type to use in MLIR: std::complex<T> for floating-point
+/// T and NonFloatComplex<T> for any other element type.
+template <typename T>
+using Complex = std::conditional_t<std::is_floating_point_v<T>, std::complex<T>,
+                                   NonFloatComplex<T>>;
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_COMPLEX_H
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index d7075b795ccb9..ca8e4ae2cecbc 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -592,7 +592,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
     if (isComplex) {
       // If this is a complex, treat the parsed values as complex values.
       auto complexData = llvm::ArrayRef(
-          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
+          reinterpret_cast<mlir::Complex<APInt> *>(intValues.data()),
           intValues.size() / 2);
       return DenseElementsAttr::get(type, complexData);
     }
@@ -606,7 +606,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
     if (isComplex) {
       // If this is a complex, treat the parsed values as complex values.
       auto complexData = llvm::ArrayRef(
-          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
+          reinterpret_cast<mlir::Complex<APFloat> *>(floatValues.data()),
           floatValues.size() / 2);
       return DenseElementsAttr::get(type, complexData);
     }
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index cdcb3cba55752..18b3c030090e2 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -151,8 +151,8 @@ struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
 };
 
 struct BinaryComplexOperands {
-  std::complex<Value> lhs;
-  std::complex<Value> rhs;
+  mlir::Complex<Value> lhs;
+  mlir::Complex<Value> rhs;
 };
 
 template <typename OpTy>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75008d6cc2591..ec270db189081 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2688,7 +2688,7 @@ void AsmPrinter::Impl::printDenseTypedElementsAttr(DenseTypedElementsAttr attr,
     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
     // and hence was replaced.
     if (llvm::isa<IntegerType>(complexElementType)) {
-      auto valueIt = attr.value_begin<std::complex<APInt>>();
+      auto valueIt = attr.value_begin<mlir::Complex<APInt>>();
       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
         auto complexValue = *(valueIt + index);
         os << "(";
@@ -2698,7 +2698,7 @@ void AsmPrinter::Impl::printDenseTypedElementsAttr(DenseTypedElementsAttr attr,
         os << ")";
       });
     } else {
-      auto valueIt = attr.value_begin<std::complex<APFloat>>();
+      auto valueIt = attr.value_begin<mlir::Complex<APFloat>>();
       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
         auto complexValue = *(valueIt + index);
         os << "(";
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index c06ae5b178624..9fda2ef8e5059 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -649,15 +649,15 @@ APInt DenseElementsAttr::IntElementIterator::operator*() const {
 
 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
     DenseElementsAttr attr, size_t dataIndex)
-    : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
-                                      std::complex<APInt>, std::complex<APInt>,
-                                      std::complex<APInt>>(
-          attr.getRawData().data(), attr.isSplat(), dataIndex) {
+    : DenseElementIndexedIteratorImpl<
+          ComplexIntElementIterator, mlir::Complex<APInt>, mlir::Complex<APInt>,
+          mlir::Complex<APInt>>(attr.getRawData().data(), attr.isSplat(),
+                                dataIndex) {
   auto complexType = llvm::cast<ComplexType>(attr.getElementType());
   bitWidth = getDenseElementBitWidth(complexType.getElementType());
 }
 
-std::complex<APInt>
+mlir::Complex<APInt>
 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
   size_t offset = getDataIndex() * storageWidth * 2;
@@ -922,8 +922,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
   return DenseTypedElementsAttr::getRaw(type, storageBitWidth, values);
 }
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
-                                         ArrayRef<std::complex<APInt>> values) {
+DenseElementsAttr
+DenseElementsAttr::get(ShapedType type, ArrayRef<mlir::Complex<APInt>> values) {
   ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
   assert(llvm::isa<IntegerType>(complex.getElementType()));
   assert(hasSameNumElementsOrSplat(type, values));
@@ -945,7 +945,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 }
 DenseElementsAttr
 DenseElementsAttr::get(ShapedType type,
-                       ArrayRef<std::complex<APFloat>> values) {
+                       ArrayRef<mlir::Complex<APFloat>> values) {
   ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
   assert(llvm::isa<FloatType>(complex.getElementType()));
   assert(hasSameNumElementsOrSplat(type, values));
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 900cacabd592e..642f3dd717b05 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -191,28 +191,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
 TEST(DenseComplexTest, ComplexFloatSplat) {
   MLIRContext context;
   ComplexType complexType = ComplexType::get(Float32Type::get(&context));
-  std::complex<float> value(10.0, 15.0);
+  mlir::Complex<float> value(10.0, 15.0);
   testSplat(complexType, value);
 }
 
 TEST(DenseComplexTest, ComplexIntSplat) {
   MLIRContext context;
   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
-  std::complex<int64_t> value(10, 15);
+  mlir::Complex<int64_t> value(10, 15);
   testSplat(complexType, value);
 }
 
 TEST(DenseComplexTest, ComplexAPFloatSplat) {
   MLIRContext context;
   ComplexType complexType = ComplexType::get(Float32Type::get(&context));
-  std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
+  mlir::Complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
   testSplat(complexType, value);
 }
 
 TEST(DenseComplexTest, ComplexAPIntSplat) {
   MLIRContext context;
   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
-  std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
+  mlir::Complex<APInt> value(APInt(64, 10), APInt(64, 15));
   testSplat(complexType, value);
 }
 
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 3a6365b401d49..f4a5ab7d5b9a0 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRSupportTests
+  ComplexTest.cpp
   CyclicReplacerCacheTest.cpp
   IndentedOstreamTest.cpp
   StorageUniquerTest.cpp
 )
diff --git a/mlir/unittests/Support/ComplexTest.cpp b/mlir/unittests/Support/ComplexTest.cpp
new file mode 100644
index 0000000000000..a58413ee2e458
--- /dev/null
+++ b/mlir/unittests/Support/ComplexTest.cpp
@@ -0,0 +1,296 @@
+//===- ComplexTest.cpp - mlir::Complex Tests ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/Complex.h"
+#include "gtest/gtest.h"
+
+namespace mlir {
+// Provide an ostream operator so that gtest pretty-prints NonFloatComplex
+// values in failure messages.
+template <typename T>
+static std::ostream &operator<<(std::ostream &os, const NonFloatComplex<T> &c) {
+  os << "(" << c.real() << "," << c.imag() << ")";
+  return os;
+}
+
+} // namespace mlir
+
+// The majority of these tests just check that NonFloatComplex does exactly the
+// same as std::complex<float>.
+
+TEST(ComplexTest, Typedef) {
+  EXPECT_TRUE((std::is_same_v<mlir::Complex<float>, std::complex<float>>));
+
+  EXPECT_TRUE((std::is_same_v<mlir::Complex<int>, mlir::NonFloatComplex<int>>));
+}
+
+TEST(ComplexTest, DefaultConstructor) {
+  mlir::NonFloatComplex<float> mc;
+  std::complex<float> sc;
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, RealConstructor) {
+  mlir::NonFloatComplex<float> mc{10};
+  std::complex<float> sc{10};
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, MemberConstructor) {
+  mlir::NonFloatComplex<float> mc{10, 20};
+  std::complex<float> sc{10, 20};
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, ExplicitCopyConstructor) {
+  std::complex<double> sc{5, 10};
+  mlir::NonFloatComplex<float> mc{sc};
+  EXPECT_EQ(mc, sc);
+
+  // check the explicit constructors were used
+  EXPECT_FALSE((std::is_convertible_v<decltype(sc), decltype(mc)>));
+}
+
+TEST(ComplexTest, ImplicitCopyConstructor) {
+  std::complex<float> sc{};
+  mlir::NonFloatComplex<float> mc = sc;
+  EXPECT_EQ(mc, sc);
+
+  // check the implicit constructors were used
+  EXPECT_TRUE((std::is_convertible_v<decltype(sc), decltype(mc)>));
+}
+
+TEST(ComplexTest, RealAccessor) {
+  mlir::NonFloatComplex<float> mc{5};
+  std::complex<float> sc{5};
+  EXPECT_EQ(mc.real(), sc.real());
+}
+
+TEST(ComplexTest, RealSetter) {
+  mlir::NonFloatComplex<float> mc{5};
+  mc.real(7);
+  std::complex<float> sc{5};
+  sc.real(7);
+  EXPECT_EQ(mc.real(), sc.real());
+}
+
+TEST(ComplexTest, ImagAccessor) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  std::complex<float> sc{2, 5};
+  EXPECT_EQ(mc.imag(), sc.imag());
+}
+
+TEST(ComplexTest, ImagSetter) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  mc.imag(8);
+  std::complex<float> sc{2, 5};
+  sc.imag(8);
+  EXPECT_EQ(mc.imag(), sc.imag());
+}
+
+TEST(ComplexTest, CopyAssignment) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  mlir::NonFloatComplex<int> mc2{7, 9};
+  mc2 = mc;
+
+  EXPECT_EQ(mc, mc2);
+}
+
+TEST(ComplexTest, StdCopyAssignment) {
+  std::complex<float> sc{2, 5};
+  mlir::NonFloatComplex<float> mc{7, 9};
+  mc = sc;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, RealAssignment) {
+  std::complex<float> sc{3, 4};
+  sc = 2.f;
+  mlir::NonFloatComplex<float> mc{3, 4};
+  mc = 2.f;
+
+  // Assigning a real value also resets the imaginary part to zero.
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, PlusEqualsReal) {
+  mlir::NonFloatComplex<float> mc{2, 5};
+  mc += 7;
+  std::complex<float> sc{2, 5};
+  sc += 7;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, MinusEqualsReal) {
+  mlir::NonFloatComplex<float> mc{3, 6};
+  mc -= 8;
+  std::complex<float> sc{3, 6};
+  sc -= 8;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, TimesEqualsReal) {
+  mlir::NonFloatComplex<float> mc{1, 4};
+  mc *= 2;
+  std::complex<float> sc{1, 4};
+  sc *= 2;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, DivideEqualsReal) {
+  mlir::NonFloatComplex<float> mc{1, 4};
+  mc /= 2;
+  std::complex<float> sc{1, 4};
+  sc /= 2;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, CompoundAssignmentComplex) {
+  mlir::NonFloatComplex<float> mc{5, 10};
+  const mlir::NonFloatComplex<float> mo{3, 4};
+  std::complex<float> sc{5, 10};
+  const std::complex<float> so{3, 4};
+
+  // The operands are chosen so that every intermediate result is exactly
+  // representable as a float.
+  mc += mo;
+  sc += so;
+  EXPECT_EQ(mc, sc);
+
+  mc -= mo;
+  sc -= so;
+  EXPECT_EQ(mc, sc);
+
+  mc *= mo;
+  sc *= so;
+  EXPECT_EQ(mc, sc);
+
+  mc /= mo;
+  sc /= so;
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, AssignmentOp) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  mlir::NonFloatComplex<int> mc2;
+
+  // Assignment returns a reference to the assigned-to object.
+  EXPECT_EQ(&(mc2 = mc), &mc2);
+  EXPECT_EQ(mc, mc2);
+}
+
+TEST(ComplexTest, StdAssignmentOp) {
+  std::complex<float> sc{2, 5};
+  mlir::NonFloatComplex<float> mc;
+
+  EXPECT_EQ(&(mc = sc), &mc);
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, AddOp) {
+  mlir::NonFloatComplex<float> mc1{2, 5};
+  mlir::NonFloatComplex<float> mc2{3, 7};
+  std::complex<float> sc1{2, 5};
+  std::complex<float> sc2{3, 7};
+
+  EXPECT_EQ(mc1 + mc2, sc1 + sc2);
+  EXPECT_EQ(mc1 + 5.f, sc1 + 5.f);
+}
+
+TEST(ComplexTest, MinusOp) {
+  mlir::NonFloatComplex<float> mc1{2, 5};
+  mlir::NonFloatComplex<float> mc2{3, 7};
+  std::complex<float> sc1{2, 5};
+  std::complex<float> sc2{3, 7};
+
+  EXPECT_EQ(mc1 - mc2, sc1 - sc2);
+  EXPECT_EQ(mc1 - 5.f, sc1 - 5.f);
+}
+
+TEST(ComplexTest, TimesOp) {
+  mlir::NonFloatComplex<float> mc1{2, 5};
+  mlir::NonFloatComplex<float> mc2{3, 7};
+  std::complex<float> sc1{2, 5};
+  std::complex<float> sc2{3, 7};
+
+  EXPECT_EQ(mc1 * mc2, sc1 * sc2);
+  EXPECT_EQ(mc1 * 5.f, sc1 * 5.f);
+}
+
+TEST(ComplexTest, DivideOp) {
+  mlir::NonFloatComplex<float> mc1{5, 10};
+  mlir::NonFloatComplex<float> mc2{3, 4};
+  std::complex<float> sc1{5, 10};
+  std::complex<float> sc2{3, 4};
+
+  EXPECT_EQ(mc1 / mc2, sc1 / sc2);
+  EXPECT_EQ(mc1 / 5.f, sc1 / 5.f);
+}
+
+TEST(ComplexTest, IntegerArithmetic) {
+  mlir::NonFloatComplex<int> a{1, 2};
+  mlir::NonFloatComplex<int> b{3, 4};
+
+  EXPECT_EQ(a + b, (mlir::NonFloatComplex<int>{4, 6}));
+  EXPECT_EQ(a - b, (mlir::NonFloatComplex<int>{-2, -2}));
+  EXPECT_EQ(a * b, (mlir::NonFloatComplex<int>{-5, 10}));
+  EXPECT_EQ(-a, (mlir::NonFloatComplex<int>{-1, -2}));
+  // (3 + 4i) / (1 + 2i) == 2.2 - 0.4i; each part uses truncating int division.
+  EXPECT_EQ(b / a, (mlir::NonFloatComplex<int>{2, 0}));
+}
+
+TEST(ComplexTest, EqualityOp) {
+  mlir::NonFloatComplex<float> mc1{3, 4};
+  mlir::NonFloatComplex<float> mc2{3, 4};
+
+  EXPECT_EQ(mc1, mc2);
+  EXPECT_EQ(mc2, mc1);
+}
+
+TEST(ComplexTest, StdEqualityOp) {
+  mlir::NonFloatComplex<float> mc{7, 8};
+  std::complex<float> sc{7, 8};
+
+  EXPECT_EQ(mc, sc);
+  EXPECT_EQ(sc, mc);
+}
+
+TEST(ComplexTest, InequalityOp) {
+  mlir::NonFloatComplex<float> mc1{3, 4};
+  mlir::NonFloatComplex<float> mc2{7, 8};
+
+  EXPECT_NE(mc1, mc2);
+  EXPECT_NE(mc2, mc1);
+}
+
+TEST(ComplexTest, StdInequalityOp) {
+  mlir::NonFloatComplex<float> mc{7, 8};
+  std::complex<float> sc{3, 4};
+
+  EXPECT_NE(mc, sc);
+  EXPECT_NE(sc, mc);
+}
+
+TEST(ComplexTest, RealFn) {
+  mlir::NonFloatComplex<float> mc{4, 6};
+  std::complex<float> sc{4, 6};
+
+  EXPECT_EQ(real(mc), real(sc));
+}
+
+TEST(ComplexTest, ImagFn) {
+  mlir::NonFloatComplex<float> mc{4, 6};
+  std::complex<float> sc{4, 6};
+
+  EXPECT_EQ(imag(mc), imag(sc));
+}



More information about the Mlir-commits mailing list