[Mlir-commits] [mlir] 0ff5c32 - [mlir] Use custom mlir::Complex type for non-float complex numbers (#191821)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 6 05:51:35 PDT 2026
Author: David Truby
Date: 2026-05-06T13:51:29+01:00
New Revision: 0ff5c32c28219ea7e75869678fb5fe3b1b4b0e0d
URL: https://github.com/llvm/llvm-project/commit/0ff5c32c28219ea7e75869678fb5fe3b1b4b0e0d
DIFF: https://github.com/llvm/llvm-project/commit/0ff5c32c28219ea7e75869678fb5fe3b1b4b0e0d.diff
LOG: [mlir] Use custom mlir::Complex type for non-float complex numbers (#191821)
The C++ standard leaves the effect of instantiating std::complex<T> unspecified
when std::is_floating_point<T> is false, and the MSVC STL emits warnings for
such instantiations. This patch fixes those warnings by introducing an
mlir::Complex<T> alias template, which names std::complex<T> when T satisfies
std::is_floating_point and a custom mlir::NonFloatComplex<T> type otherwise.
The std::complex implementation from libc++ has been used as a guide for
implementing the custom type.
Fixes #65255
Added:
mlir/include/mlir/Support/Complex.h
mlir/unittests/Support/ComplexTest.cpp
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp
mlir/unittests/Support/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
index 52fd824f65e74..a47dc9927dd8c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td
@@ -57,8 +57,8 @@ def Complex_NumberAttr : Complex_Attr<"Number", "number",
];
let extraClassDeclaration = [{
- std::complex<APFloat> getValue() {
- return std::complex<APFloat>(getReal(), getImag());
+ mlir::Complex<APFloat> getValue() {
+ return mlir::Complex<APFloat>(getReal(), getImag());
}
}];
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 714e664dd0f4e..7e2190dc28084 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -22,6 +22,7 @@
#include "mlir/ExecutionEngine/SparseTensor/MapRef.h"
#include "mlir/ExecutionEngine/SparseTensor/Storage.h"
+#include "mlir/Support/Complex.h"
#include <fstream>
@@ -36,6 +37,9 @@ struct is_complex final : public std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> final : public std::true_type {};
+template <typename T>
+struct is_complex<mlir::NonFloatComplex<T>> final : public std::true_type {};
+
/// Returns an element-value of non-complex type. If `IsPattern` is true,
/// then returns an arbitrary value. If `IsPattern` is false, then
/// reads the value from the current line buffer beginning at `linePtr`.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 1f805882db276..c7eddf44fb29b 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -10,9 +10,9 @@
#define MLIR_IR_BUILTINATTRIBUTES_H
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/Support/Complex.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
-#include <complex>
#include <optional>
namespace mlir {
@@ -75,6 +75,8 @@ template <typename T>
struct is_complex_t : public std::false_type {};
template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};
+template <typename T>
+struct is_complex_t<mlir::NonFloatComplex<T>> : public std::true_type {};
} // namespace detail
/// An attribute that represents a reference to a dense vector or tensor
@@ -167,7 +169,7 @@ class DenseElementsAttr : public Attribute {
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type,
- ArrayRef<std::complex<APInt>> values);
+ ArrayRef<mlir::Complex<APInt>> values);
/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
@@ -180,7 +182,7 @@ class DenseElementsAttr : public Attribute {
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type,
- ArrayRef<std::complex<APFloat>> values);
+ ArrayRef<mlir::Complex<APFloat>> values);
/// Construct a dense elements attribute for an initializer_list of values.
/// Each value is expected to be the same bitwidth of the element type of
@@ -298,11 +300,11 @@ class DenseElementsAttr : public Attribute {
/// values.
class ComplexIntElementIterator
: public detail::DenseElementIndexedIteratorImpl<
- ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
- std::complex<APInt>> {
+ ComplexIntElementIterator, mlir::Complex<APInt>,
+ mlir::Complex<APInt>, mlir::Complex<APInt>> {
public:
- /// Accesses the raw std::complex<APInt> value at this iterator position.
- std::complex<APInt> operator*() const;
+ /// Accesses the raw mlir::Complex<APInt> value at this iterator position.
+ mlir::Complex<APInt> operator*() const;
private:
friend DenseElementsAttr;
@@ -339,10 +341,10 @@ class DenseElementsAttr : public Attribute {
class ComplexFloatElementIterator final
: public llvm::mapped_iterator_base<ComplexFloatElementIterator,
ComplexIntElementIterator,
- std::complex<APFloat>> {
+ mlir::Complex<APFloat>> {
public:
/// Map the element to the iterator result type.
- std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
+ mlir::Complex<APFloat> mapElement(const mlir::Complex<APInt> &value) const {
return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
}
@@ -442,7 +444,7 @@ class DenseElementsAttr : public Attribute {
ElementIterator<T>(rawData, splat, getNumElements()));
}
- /// Try to get the held element values as a range of std::complex.
+ /// Try to get the held element values as a range of mlir::Complex.
template <typename T, typename ElementT>
using ComplexValueTemplateCheckT =
std::enable_if_t<detail::is_complex_t<T>::value &&
@@ -545,7 +547,7 @@ class DenseElementsAttr : public Attribute {
/// element type of this attribute must be a complex of integer type.
template <typename T>
using ComplexAPIntValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, std::complex<APInt>>::value>;
+ std::enable_if_t<std::is_same<T, mlir::Complex<APInt>>::value>;
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<ComplexIntElementIterator>>
tryGetValues() const {
@@ -566,7 +568,7 @@ class DenseElementsAttr : public Attribute {
/// element type of this attribute must be a complex of float type.
template <typename T>
using ComplexAPFloatValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, std::complex<APFloat>>::value>;
+ std::enable_if_t<std::is_same<T, mlir::Complex<APFloat>>::value>;
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
tryGetValues() const {
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 299200788136a..6165a24c0d34f 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -296,19 +296,19 @@ def Builtin_DenseTypedElementsAttr : Builtin_Attr<
uint8_t, uint16_t, uint32_t, uint64_t,
int8_t, int16_t, int32_t, int64_t,
short, unsigned short, int, unsigned, long, unsigned long,
- std::complex<uint8_t>, std::complex<uint16_t>, std::complex<uint32_t>,
- std::complex<uint64_t>,
- std::complex<int8_t>, std::complex<int16_t>, std::complex<int32_t>,
- std::complex<int64_t>,
+ mlir::Complex<uint8_t>, mlir::Complex<uint16_t>, mlir::Complex<uint32_t>,
+ mlir::Complex<uint64_t>,
+ mlir::Complex<int8_t>, mlir::Complex<int16_t>, mlir::Complex<int32_t>,
+ mlir::Complex<int64_t>,
// Float types.
- float, double, std::complex<float>, std::complex<double>
+ float, double, mlir::Complex<float>, mlir::Complex<double>
>;
using NonContiguousIterableTypesT = std::tuple<
Attribute,
// Integer types.
- APInt, bool, std::complex<APInt>,
+ APInt, bool, mlir::Complex<APInt>,
// Float types.
- APFloat, std::complex<APFloat>
+ APFloat, mlir::Complex<APFloat>
>;
/// Provide a `try_value_begin_impl` to enable iteration within
@@ -931,12 +931,12 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
APInt, bool, uint8_t, uint16_t, uint32_t, uint64_t,
int8_t, int16_t, int32_t, int64_t,
short, unsigned short, int, unsigned, long, unsigned long,
- std::complex<APInt>, std::complex<uint8_t>, std::complex<uint16_t>,
- std::complex<uint32_t>, std::complex<uint64_t>, std::complex<int8_t>,
- std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>,
+ mlir::Complex<APInt>, mlir::Complex<uint8_t>, mlir::Complex<uint16_t>,
+ mlir::Complex<uint32_t>, mlir::Complex<uint64_t>, mlir::Complex<int8_t>,
+ mlir::Complex<int16_t>, mlir::Complex<int32_t>, mlir::Complex<int64_t>,
// Float types.
APFloat, float, double,
- std::complex<APFloat>, std::complex<float>, std::complex<double>,
+ mlir::Complex<APFloat>, mlir::Complex<float>, mlir::Complex<double>,
// String types.
StringRef
>;
@@ -978,7 +978,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
return getZeroAPInt();
}
template <typename T>
- std::enable_if_t<std::is_same<std::complex<APInt>, T>::value, T>
+ std::enable_if_t<std::is_same<mlir::Complex<APInt>, T>::value, T>
getZeroValue() const {
APInt intZero = getZeroAPInt();
return {intZero, intZero};
@@ -990,7 +990,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
return getZeroAPFloat();
}
template <typename T>
- std::enable_if_t<std::is_same<std::complex<APFloat>, T>::value, T>
+ std::enable_if_t<std::is_same<mlir::Complex<APFloat>, T>::value, T>
getZeroValue() const {
APFloat floatZero = getZeroAPFloat();
return {floatZero, floatZero};
@@ -1002,8 +1002,8 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
std::is_same<T, StringRef>::value ||
(detail::is_complex_t<T>::value &&
- !llvm::is_one_of<T, std::complex<APInt>,
- std::complex<APFloat>>::value),
+ !llvm::is_one_of<T, mlir::Complex<APInt>,
+ mlir::Complex<APFloat>>::value),
T>
getZeroValue() const {
return T();
diff --git a/mlir/include/mlir/Support/Complex.h b/mlir/include/mlir/Support/Complex.h
new file mode 100644
index 0000000000000..86c06f230b8ef
--- /dev/null
+++ b/mlir/include/mlir/Support/Complex.h
@@ -0,0 +1,269 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the declaration of the mlir::NonFloatComplex type and
+/// mlir::Complex type alias. The interface is intended to match the
+/// std::complex type, and the mlir::Complex alias defers to std::complex for
+/// builtin floating point types.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_COMPLEX_H
+#define MLIR_SUPPORT_COMPLEX_H
+
+#include <complex>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+
+// The converting constructors of NonFloatComplex should be implicit if and
+// only if the element conversion does not narrow, and explicit otherwise,
+// mirroring how std::complex treats its converting constructors. A conversion
+// does not narrow exactly when a `To` can be assigned from a braced
+// initializer holding a `From`, so define a helper trait detecting that case.
+namespace detail {
+// NOLINTBEGIN
+template <typename From, typename To>
+auto test_copy_list_initializable(int)
+    -> decltype(void(std::declval<To &>() = {std::declval<From &>()}),
+                std::true_type{});
+
+template <typename, typename>
+auto test_copy_list_initializable(...) -> std::false_type;
+
+template <typename From, typename To>
+struct is_copy_list_initializable
+    : std::bool_constant<
+          decltype(detail::test_copy_list_initializable<From, To>(0))::value> {
+};
+
+template <typename From, typename To>
+inline constexpr bool is_copy_list_initializable_v =
+    is_copy_list_initializable<From, To>::value;
+// NOLINTEND
+} // namespace detail
+
+/// A complex number whose real and imaginary parts have type T. It is used in
+/// place of std::complex<T> when T is not a floating-point type (see the
+/// mlir::Complex alias below); its interface mirrors std::complex.
+template <typename T>
+class NonFloatComplex {
+public:
+  using value_type = T;
+
+private:
+  T re;
+  T im;
+
+public:
+  /// Constructs the value `re + im*i`. Both parts default to `T{}`.
+  constexpr NonFloatComplex(const T &re = T{}, const T &im = T{})
+      : re(re), im(im) {}
+
+  // Copy and move operations are implicitly declared (rule of zero).
+  // Declaring the copy operations explicitly would suppress the implicit move
+  // operations and force heavyweight element types such as APInt to be copied.
+
+  /// Implicit conversion from a NonFloatComplex with another element type
+  /// whose conversion to T does not narrow. The parts are read through the
+  /// public accessors: NonFloatComplex<U> is a distinct class, so its private
+  /// members are not accessible from NonFloatComplex<T>.
+  template <typename U,
+            std::enable_if_t<detail::is_copy_list_initializable_v<U, T>,
+                             int> = 0>
+  constexpr NonFloatComplex(const NonFloatComplex<U> &other)
+      : re{other.real()}, im{other.imag()} {}
+
+  /// Explicit conversion from a NonFloatComplex whose element conversion to T
+  /// may narrow.
+  template <typename U,
+            std::enable_if_t<!detail::is_copy_list_initializable_v<U, T>,
+                             int> = 0>
+  constexpr explicit NonFloatComplex(const NonFloatComplex<U> &other)
+      : re(other.real()), im(other.imag()) {}
+
+  /// Implicit conversion from a std::complex whose element conversion to T
+  /// does not narrow.
+  template <typename U,
+            std::enable_if_t<detail::is_copy_list_initializable_v<U, T>,
+                             int> = 0>
+  constexpr NonFloatComplex(const std::complex<U> &other)
+      : re{other.real()}, im{other.imag()} {}
+
+  /// Explicit conversion from a std::complex whose element conversion to T
+  /// may narrow.
+  template <typename U,
+            std::enable_if_t<!detail::is_copy_list_initializable_v<U, T>,
+                             int> = 0>
+  constexpr explicit NonFloatComplex(const std::complex<U> &other)
+      : re(other.real()), im(other.imag()) {}
+
+  [[nodiscard]] constexpr T real() const { return re; }
+  constexpr void real(T value) { re = value; }
+  [[nodiscard]] constexpr T imag() const { return im; }
+  constexpr void imag(T value) { im = value; }
+
+  /// Assigns a real value: the real part becomes `real` and the imaginary part
+  /// is reset to `T{}`.
+  constexpr NonFloatComplex &operator=(const T &real) {
+    re = real;
+    im = T{};
+    return *this;
+  }
+
+  /// Adds a real value to the real part only.
+  constexpr NonFloatComplex &operator+=(const T &real) {
+    re += real;
+    return *this;
+  }
+
+  /// Subtracts a real value from the real part only.
+  constexpr NonFloatComplex &operator-=(const T &real) {
+    re -= real;
+    return *this;
+  }
+
+  /// Scales both parts by a real value.
+  constexpr NonFloatComplex &operator*=(const T &real) {
+    re *= real;
+    im *= real;
+    return *this;
+  }
+
+  /// Divides both parts by a real value.
+  constexpr NonFloatComplex &operator/=(const T &real) {
+    re /= real;
+    im /= real;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator+=(const NonFloatComplex &other) {
+    re += other.re;
+    im += other.im;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator-=(const NonFloatComplex &other) {
+    re -= other.re;
+    im -= other.im;
+    return *this;
+  }
+
+  // The binary operators read both operands into locals before producing the
+  // result, so `x *= x` and `x /= x` are safe.
+  constexpr NonFloatComplex &operator*=(const NonFloatComplex &other) {
+    *this = *this * other;
+    return *this;
+  }
+
+  constexpr NonFloatComplex &operator/=(const NonFloatComplex &other) {
+    *this = *this / other;
+    return *this;
+  }
+
+  /// Assigns from a std::complex, converting each part to T.
+  template <typename U>
+  constexpr NonFloatComplex &operator=(const std::complex<U> &other) {
+    re = other.real();
+    im = other.imag();
+    return *this;
+  }
+};
+
+/// Returns `x + y` for any `y` accepted by NonFloatComplex<T>::operator+=
+/// (a T or a NonFloatComplex<T>).
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator+(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t += y;
+  return t;
+}
+
+/// Returns `x + y` for a real `x`.
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator+(const T &x, const NonFloatComplex<T> &y) {
+  NonFloatComplex<T> t{y};
+  t += x;
+  return t;
+}
+
+/// Returns `x - y` for any `y` accepted by NonFloatComplex<T>::operator-=.
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator-(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t -= y;
+  return t;
+}
+
+/// Returns `x - y` for a real `x`.
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator-(const T &x, const NonFloatComplex<T> &y) {
+  NonFloatComplex<T> t = -y;
+  t += x;
+  return t;
+}
+
+/// Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator*(const NonFloatComplex<T> &x, const NonFloatComplex<T> &y) {
+  T a = x.real();
+  T b = x.imag();
+  T c = y.real();
+  T d = y.imag();
+
+  return {(a * c) - (b * d), (a * d) + (b * c)};
+}
+
+/// Returns `x * y` for any `y` accepted by NonFloatComplex<T>::operator*=.
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator*(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t *= y;
+  return t;
+}
+
+/// Returns `x * y` for a real `x`.
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator*(const T &x, const NonFloatComplex<T> &y) {
+  NonFloatComplex<T> t{y};
+  t *= x;
+  return t;
+}
+
+/// Complex division using the textbook formula
+/// (a + bi)/(c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2).
+/// Each part is computed with T's own division (truncating for integral T);
+/// a zero divisor is not special-cased.
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator/(const NonFloatComplex<T> &x, const NonFloatComplex<T> &y) {
+  T a = x.real();
+  T b = x.imag();
+  T c = y.real();
+  T d = y.imag();
+
+  T denom = c * c + d * d;
+  return {(a * c + b * d) / denom, (b * c - a * d) / denom};
+}
+
+/// Returns `x / y` for any `y` accepted by NonFloatComplex<T>::operator/=.
+template <typename T, typename U>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator/(const NonFloatComplex<T> &x, const U &y) {
+  NonFloatComplex<T> t{x};
+  t /= y;
+  return t;
+}
+
+/// Returns `x / y` for a real `x`.
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator/(const T &x, const NonFloatComplex<T> &y) {
+  return NonFloatComplex<T>(x) / y;
+}
+
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator+(const NonFloatComplex<T> &x) {
+  return x;
+}
+
+template <typename T>
+[[nodiscard]] constexpr NonFloatComplex<T>
+operator-(const NonFloatComplex<T> &x) {
+  return {-x.real(), -x.imag()};
+}
+
+/// Two complex values are equal when both parts compare equal.
+template <typename T>
+[[nodiscard]] constexpr bool operator==(const NonFloatComplex<T> &x,
+                                        const NonFloatComplex<T> &y) {
+  return x.real() == y.real() && x.imag() == y.imag();
+}
+
+/// Compares against any value brace-convertible to NonFloatComplex<T>.
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator==(const NonFloatComplex<T> &x,
+                                        const U &y) {
+  return x == NonFloatComplex<T>{y};
+}
+
+/// Compares any value brace-convertible to NonFloatComplex<U> against `y`.
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator==(const T &x,
+                                        const NonFloatComplex<U> &y) {
+  return NonFloatComplex<U>{x} == y;
+}
+
+template <typename T>
+[[nodiscard]] constexpr bool operator!=(const NonFloatComplex<T> &x,
+                                        const NonFloatComplex<T> &y) {
+  return !(x == y);
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator!=(const NonFloatComplex<T> &x,
+                                        const U &y) {
+  return !(x == y);
+}
+
+template <typename T, typename U>
+[[nodiscard]] constexpr bool operator!=(const U &x,
+                                        const NonFloatComplex<T> &y) {
+  return !(y == x);
+}
+
+/// Free-function accessor for the real part, matching std::real.
+template <typename T>
+[[nodiscard]] constexpr T real(const NonFloatComplex<T> &x) {
+  return x.real();
+}
+
+/// Free-function accessor for the imaginary part, matching std::imag.
+template <typename T>
+[[nodiscard]] constexpr T imag(const NonFloatComplex<T> &x) {
+  return x.imag();
+}
+
+/// Complex number type with element type T: std::complex<T> when T is a
+/// floating-point type, NonFloatComplex<T> otherwise (the standard leaves the
+/// effect of instantiating std::complex for other types unspecified).
+template <typename T>
+using Complex = std::conditional_t<std::is_floating_point_v<T>, std::complex<T>,
+                                   NonFloatComplex<T>>;
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index d7075b795ccb9..ca8e4ae2cecbc 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -592,7 +592,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
if (isComplex) {
// If this is a complex, treat the parsed values as complex values.
auto complexData = llvm::ArrayRef(
- reinterpret_cast<std::complex<APInt> *>(intValues.data()),
+ reinterpret_cast<mlir::Complex<APInt> *>(intValues.data()),
intValues.size() / 2);
return DenseElementsAttr::get(type, complexData);
}
@@ -606,7 +606,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
if (isComplex) {
// If this is a complex, treat the parsed values as complex values.
auto complexData = llvm::ArrayRef(
- reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
+ reinterpret_cast<mlir::Complex<APFloat> *>(floatValues.data()),
floatValues.size() / 2);
return DenseElementsAttr::get(type, complexData);
}
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index cdcb3cba55752..18b3c030090e2 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -151,8 +151,8 @@ struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
};
struct BinaryComplexOperands {
- std::complex<Value> lhs;
- std::complex<Value> rhs;
+ mlir::Complex<Value> lhs;
+ mlir::Complex<Value> rhs;
};
template <typename OpTy>
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75008d6cc2591..ec270db189081 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2688,7 +2688,7 @@ void AsmPrinter::Impl::printDenseTypedElementsAttr(DenseTypedElementsAttr attr,
// printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
// and hence was replaced.
if (llvm::isa<IntegerType>(complexElementType)) {
- auto valueIt = attr.value_begin<std::complex<APInt>>();
+ auto valueIt = attr.value_begin<mlir::Complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
os << "(";
@@ -2698,7 +2698,7 @@ void AsmPrinter::Impl::printDenseTypedElementsAttr(DenseTypedElementsAttr attr,
os << ")";
});
} else {
- auto valueIt = attr.value_begin<std::complex<APFloat>>();
+ auto valueIt = attr.value_begin<mlir::Complex<APFloat>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
os << "(";
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index c06ae5b178624..9fda2ef8e5059 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -649,15 +649,15 @@ APInt DenseElementsAttr::IntElementIterator::operator*() const {
DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
DenseElementsAttr attr, size_t dataIndex)
- : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
- std::complex<APInt>, std::complex<APInt>,
- std::complex<APInt>>(
- attr.getRawData().data(), attr.isSplat(), dataIndex) {
+ : DenseElementIndexedIteratorImpl<
+ ComplexIntElementIterator, mlir::Complex<APInt>, mlir::Complex<APInt>,
+ mlir::Complex<APInt>>(attr.getRawData().data(), attr.isSplat(),
+ dataIndex) {
auto complexType = llvm::cast<ComplexType>(attr.getElementType());
bitWidth = getDenseElementBitWidth(complexType.getElementType());
}
-std::complex<APInt>
+mlir::Complex<APInt>
DenseElementsAttr::ComplexIntElementIterator::operator*() const {
size_t storageWidth = getDenseElementStorageWidth(bitWidth);
size_t offset = getDataIndex() * storageWidth * 2;
@@ -922,8 +922,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
return DenseTypedElementsAttr::getRaw(type, storageBitWidth, values);
}
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<std::complex<APInt>> values) {
+DenseElementsAttr
+DenseElementsAttr::get(ShapedType type, ArrayRef<mlir::Complex<APInt>> values) {
ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
assert(llvm::isa<IntegerType>(complex.getElementType()));
assert(hasSameNumElementsOrSplat(type, values));
@@ -945,7 +945,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
}
DenseElementsAttr
DenseElementsAttr::get(ShapedType type,
- ArrayRef<std::complex<APFloat>> values) {
+ ArrayRef<mlir::Complex<APFloat>> values) {
ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
assert(llvm::isa<FloatType>(complex.getElementType()));
assert(hasSameNumElementsOrSplat(type, values));
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 900cacabd592e..642f3dd717b05 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -191,28 +191,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
TEST(DenseComplexTest, ComplexFloatSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(Float32Type::get(&context));
- std::complex<float> value(10.0, 15.0);
+ mlir::Complex<float> value(10.0, 15.0);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexIntSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
- std::complex<int64_t> value(10, 15);
+ mlir::Complex<int64_t> value(10, 15);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPFloatSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(Float32Type::get(&context));
- std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
+ mlir::Complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPIntSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
- std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
+ mlir::Complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);
}
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 3a6365b401d49..f4a5ab7d5b9a0 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRSupportTests
CyclicReplacerCacheTest.cpp
+ ComplexTest.cpp
IndentedOstreamTest.cpp
StorageUniquerTest.cpp
)
diff --git a/mlir/unittests/Support/ComplexTest.cpp b/mlir/unittests/Support/ComplexTest.cpp
new file mode 100644
index 0000000000000..a91199727737e
--- /dev/null
+++ b/mlir/unittests/Support/ComplexTest.cpp
@@ -0,0 +1,256 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains the tests for the mlir::NonFloatComplex type.
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/Complex.h"
+#include "gtest/gtest.h"
+
+namespace mlir {
+// Provide an ostream operator so that gtest can pretty-print NonFloatComplex
+// values as "(real,imag)" in assertion failure messages. The value is taken by
+// const reference to avoid copying element types such as APInt.
+template <typename T>
+static std::ostream &operator<<(std::ostream &os,
+                                const NonFloatComplex<T> &c) {
+  return os << "(" << c.real() << "," << c.imag() << ")";
+}
+
+} // namespace mlir
+
+// The majority of these tests just check that NonFloatComplex does exactly the
+// same as std::complex<float>.
+
+TEST(ComplexTest, Typedef) {
+  EXPECT_TRUE((std::is_same_v<mlir::Complex<float>, std::complex<float>>));
+
+  EXPECT_TRUE((std::is_same_v<mlir::Complex<int>, mlir::NonFloatComplex<int>>));
+}
+
+TEST(ComplexTest, DefaultConstructor) {
+  mlir::NonFloatComplex<float> mc;
+  std::complex<float> sc;
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, RealConstructor) {
+  mlir::NonFloatComplex<float> mc{10};
+  std::complex<float> sc{10};
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, MemberConstructor) {
+  mlir::NonFloatComplex<float> mc{10, 20};
+  std::complex<float> sc{10, 20};
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, ExplicitCopyConstructor) {
+  std::complex<double> sc{5, 10};
+  mlir::NonFloatComplex<float> mc{sc};
+  EXPECT_EQ(mc, sc);
+
+  // check the explicit constructors were used
+  EXPECT_FALSE((std::is_convertible_v<decltype(sc), decltype(mc)>));
+}
+
+TEST(ComplexTest, ImplicitCopyConstructor) {
+  std::complex<float> sc{};
+  mlir::NonFloatComplex<float> mc = sc;
+  EXPECT_EQ(mc, sc);
+
+  // check the implicit constructors were used
+  EXPECT_TRUE((std::is_convertible_v<decltype(sc), decltype(mc)>));
+}
+
+TEST(ComplexTest, RealAccessor) {
+  mlir::NonFloatComplex<float> mc{5};
+  std::complex<float> sc{5};
+  EXPECT_EQ(mc.real(), sc.real());
+}
+
+TEST(ComplexTest, RealSetter) {
+  mlir::NonFloatComplex<float> mc{5};
+  mc.real(7);
+  std::complex<float> sc{5};
+  sc.real(7);
+  EXPECT_EQ(mc.real(), sc.real());
+}
+
+TEST(ComplexTest, ImagAccessor) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  std::complex<float> sc{2, 5};
+  EXPECT_EQ(mc.imag(), sc.imag());
+}
+
+TEST(ComplexTest, ImagSetter) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  mc.imag(8);
+  std::complex<float> sc{2, 5};
+  sc.imag(8);
+  EXPECT_EQ(mc.imag(), sc.imag());
+}
+
+TEST(ComplexTest, CopyAssignment) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  mlir::NonFloatComplex<int> mc2{7, 9};
+  // Assign into an already-constructed object so that the copy-assignment
+  // operator (rather than the copy constructor) is exercised.
+  mc2 = mc;
+
+  EXPECT_EQ(mc, mc2);
+}
+
+TEST(ComplexTest, StdCopyAssignment) {
+  std::complex<float> sc{2, 5};
+  mlir::NonFloatComplex<float> mc{7, 9};
+  mc = sc;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, RealAssignment) {
+  std::complex<float> sc{3, 4};
+  sc = 2.f;
+  mlir::NonFloatComplex<float> mc{3, 4};
+  // Assigning a real value must also reset the imaginary part to zero.
+  mc = 2.f;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, PlusEqualsReal) {
+  mlir::NonFloatComplex<float> mc{2, 5};
+  mc += 7;
+  std::complex<float> sc{2, 5};
+  sc += 7;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, MinusEqualsReal) {
+  mlir::NonFloatComplex<float> mc{3, 6};
+  mc -= 8;
+  std::complex<float> sc{3, 6};
+  sc -= 8;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, TimesEqualsReal) {
+  mlir::NonFloatComplex<float> mc{1, 4};
+  mc *= 2;
+  std::complex<float> sc{1, 4};
+  sc *= 2;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, DivideEqualsReal) {
+  mlir::NonFloatComplex<float> mc{1, 4};
+  mc /= 2;
+  std::complex<float> sc{1, 4};
+  sc /= 2;
+
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, AssignmentOp) {
+  mlir::NonFloatComplex<int> mc{2, 5};
+  mlir::NonFloatComplex<int> mc2;
+  mlir::NonFloatComplex<int> mc3;
+  // Copy assignment returns a reference to the assigned-to object, so it
+  // chains.
+  mc3 = mc2 = mc;
+
+  EXPECT_EQ(mc, mc2);
+  EXPECT_EQ(mc, mc3);
+  EXPECT_EQ(&(mc2 = mc), &mc2);
+}
+
+TEST(ComplexTest, StdAssignmentOp) {
+  std::complex<float> sc{2, 5};
+  mlir::NonFloatComplex<float> mc;
+
+  // Assignment from std::complex returns a reference to the assigned-to
+  // object.
+  EXPECT_EQ(&(mc = sc), &mc);
+  EXPECT_EQ(mc, sc);
+}
+
+TEST(ComplexTest, AddOp) {
+  mlir::NonFloatComplex<float> mc1{2, 5};
+  mlir::NonFloatComplex<float> mc2{3, 7};
+  std::complex<float> sc1{2, 5};
+  std::complex<float> sc2{3, 7};
+
+  EXPECT_EQ(mc1 + mc2, sc1 + sc2);
+  EXPECT_EQ(mc1 + 5.f, sc1 + 5.f);
+}
+
+TEST(ComplexTest, MinusOp) {
+  mlir::NonFloatComplex<float> mc1{2, 5};
+  mlir::NonFloatComplex<float> mc2{3, 7};
+  std::complex<float> sc1{2, 5};
+  std::complex<float> sc2{3, 7};
+
+  EXPECT_EQ(mc1 - mc2, sc1 - sc2);
+  EXPECT_EQ(mc1 - 5.f, sc1 - 5.f);
+}
+
+TEST(ComplexTest, TimesOp) {
+  mlir::NonFloatComplex<float> mc1{2, 5};
+  mlir::NonFloatComplex<float> mc2{3, 7};
+  std::complex<float> sc1{2, 5};
+  std::complex<float> sc2{3, 7};
+
+  EXPECT_EQ(mc1 * mc2, sc1 * sc2);
+  EXPECT_EQ(mc1 * 5.f, sc1 * 5.f);
+}
+
+TEST(ComplexTest, DivideOp) {
+  mlir::NonFloatComplex<float> mc1{5, 10};
+  mlir::NonFloatComplex<float> mc2{3, 4};
+  std::complex<float> sc1{5, 10};
+  std::complex<float> sc2{3, 4};
+
+  EXPECT_EQ(mc1 / mc2, sc1 / sc2);
+  EXPECT_EQ(mc1 / 5.f, sc1 / 5.f);
+}
+
+TEST(ComplexTest, EqualityOp) {
+  mlir::NonFloatComplex<float> mc1{3, 4};
+  mlir::NonFloatComplex<float> mc2{3, 4};
+
+  EXPECT_EQ(mc1, mc2);
+  EXPECT_EQ(mc2, mc1);
+}
+
+TEST(ComplexTest, StdEqualityOp) {
+  mlir::NonFloatComplex<float> mc{7, 8};
+  std::complex<float> sc{7, 8};
+
+  EXPECT_EQ(mc, sc);
+  EXPECT_EQ(sc, mc);
+}
+
+TEST(ComplexTest, InequalityOp) {
+  mlir::NonFloatComplex<float> mc1{3, 4};
+  mlir::NonFloatComplex<float> mc2{7, 8};
+
+  EXPECT_NE(mc1, mc2);
+  EXPECT_NE(mc2, mc1);
+}
+
+TEST(ComplexTest, StdInequalityOp) {
+  mlir::NonFloatComplex<float> mc{7, 8};
+  std::complex<float> sc{3, 4};
+
+  EXPECT_NE(mc, sc);
+  EXPECT_NE(sc, mc);
+}
+
+TEST(ComplexTest, RealFn) {
+  mlir::NonFloatComplex<float> mc{4, 6};
+  std::complex<float> sc{4, 6};
+
+  EXPECT_EQ(real(mc), real(sc));
+}
+
+TEST(ComplexTest, ImagFn) {
+  mlir::NonFloatComplex<float> mc{4, 6};
+  std::complex<float> sc{4, 6};
+
+  EXPECT_EQ(imag(mc), imag(sc));
+}
More information about the Mlir-commits
mailing list