[Mlir-commits] [mlir] 031265a - [MLIR] Add complex numbers to standard dialect

Frederik Gossen llvmlistbot at llvm.org
Mon May 4 07:04:49 PDT 2020


Author: Frederik Gossen
Date: 2020-05-04T14:04:28Z
New Revision: 031265ad8a2e02f34dd947414bc7cba40342a1c5

URL: https://github.com/llvm/llvm-project/commit/031265ad8a2e02f34dd947414bc7cba40342a1c5
DIFF: https://github.com/llvm/llvm-project/commit/031265ad8a2e02f34dd947414bc7cba40342a1c5.diff

LOG: [MLIR] Add complex numbers to standard dialect

Add `CreateComplexOp`, `ReOp`, and `ImOp` to the standard dialect.
This is the first step to support complex numbers.

Differential Revision: https://reviews.llvm.org/D79159

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 555666753360..2eae578fc966 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -26,6 +26,7 @@ class Type;
 
 namespace mlir {
 
+class ComplexType;
 class LLVMTypeConverter;
 class UnrankedMemRefType;
 
@@ -139,24 +140,29 @@ class LLVMTypeConverter : public TypeConverter {
   LLVM::LLVMDialect *llvmDialect;
 
 private:
-  // Convert a function type.  The arguments and results are converted one by
-  // one.  Additionally, if the function returns more than one value, pack the
-  // results into an LLVM IR structure type so that the converted function type
-  // returns at most one result.
+  /// Convert a function type.  The arguments and results are converted one by
+  /// one.  Additionally, if the function returns more than one value, pack the
+  /// results into an LLVM IR structure type so that the converted function type
+  /// returns at most one result.
   Type convertFunctionType(FunctionType type);
 
-  // Convert the index type.  Uses llvmModule data layout to create an integer
-  // of the pointer bitwidth.
+  /// Convert the index type.  Uses llvmModule data layout to create an integer
+  /// of the pointer bitwidth.
   Type convertIndexType(IndexType type);
 
-  // Convert an integer type `i*` to `!llvm<"i*">`.
+  /// Convert an integer type `i*` to `!llvm<"i*">`.
   Type convertIntegerType(IntegerType type);
 
-  // Convert a floating point type: `f16` to `!llvm.half`, `f32` to
-  // `!llvm.float` and `f64` to `!llvm.double`.  `bf16` is not supported
-  // by LLVM.
+  /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to
+  /// `!llvm.float` and `f64` to `!llvm.double`.  `bf16` is not supported
+  /// by LLVM.
   Type convertFloatType(FloatType type);
 
+  /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
+  /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
+  /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
+  Type convertComplexType(ComplexType type);
+
   /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type);
 
@@ -221,6 +227,25 @@ class StructBuilder {
   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
 };
 
+/// Helper class to build and access the LLVM struct of a complex number.
+class ComplexStructBuilder : public StructBuilder {
+public:
+  /// Construct a helper for the given complex number value.
+  using StructBuilder::StructBuilder;
+  /// Build IR creating an `undef` value of the complex number type.
+  static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
+                                    Type type);
+
+  /// Build IR extracting the real value from the complex number struct.
+  Value real(OpBuilder &builder, Location loc);
+  /// Build IR inserting the real value into the complex number struct.
+  void setReal(OpBuilder &builder, Location loc, Value real);
+  /// Build IR extracting the imaginary value from the complex number struct.
+  Value imaginary(OpBuilder &builder, Location loc);
+  /// Build IR inserting the imaginary value into the complex number struct.
+  void setImaginary(OpBuilder &builder, Location loc, Value imaginary);
+};
+
 /// Helper class to produce LLVM dialect operations extracting or inserting
 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
 /// The Value may be null, in which case none of the operations are valid.
@@ -476,8 +501,8 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   }
 };
 
-/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops
-/// with one result. This supports higher-dimensional vector types.
+/// Basic lowering implementation to rewrite Ops with just one result to the
+/// LLVM Dialect. This supports higher-dimensional vector types.
 template <typename SourceOp, typename TargetOp>
 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:

diff  --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
index 6a7c8c53aafa..f3b9304f0c7f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -19,6 +19,7 @@ using std_addf = ValueBuilder<AddFOp>;
 using std_alloc = ValueBuilder<AllocOp>;
 using std_alloca = ValueBuilder<AllocaOp>;
 using std_call = OperationBuilder<CallOp>;
+using std_create_complex = ValueBuilder<CreateComplexOp>;
 using std_constant = ValueBuilder<ConstantOp>;
 using std_constant_float = ValueBuilder<ConstantFloatOp>;
 using std_constant_index = ValueBuilder<ConstantIndexOp>;
@@ -26,10 +27,12 @@ using std_constant_int = ValueBuilder<ConstantIntOp>;
 using std_dealloc = OperationBuilder<DeallocOp>;
 using std_dim = ValueBuilder<DimOp>;
 using std_extract_element = ValueBuilder<ExtractElementOp>;
+using std_im = ValueBuilder<ImOp>;
 using std_index_cast = ValueBuilder<IndexCastOp>;
 using std_muli = ValueBuilder<MulIOp>;
 using std_mulf = ValueBuilder<MulFOp>;
 using std_memref_cast = ValueBuilder<MemRefCastOp>;
+using std_re = ValueBuilder<ReOp>;
 using std_ret = OperationBuilder<ReturnOp>;
 using std_select = ValueBuilder<SelectOp>;
 using std_load = ValueBuilder<LoadOp>;

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2d7ca0f48fdb..efcbdf63983e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -981,6 +981,40 @@ def CmpIOp : Std_Op<"cmpi",
   let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
 }
 
+//===----------------------------------------------------------------------===//
+// CreateComplexOp
+//===----------------------------------------------------------------------===//
+
+def CreateComplexOp : Std_Op<"create_complex",
+    [NoSideEffect,
+     AllTypesMatch<["real", "imaginary"]>,
+     TypesMatchWith<"complex element type matches real operand type",
+                    "complex", "real",
+                    "$_self.cast<ComplexType>().getElementType()">,
+     TypesMatchWith<"complex element type matches imaginary operand type",
+                    "complex", "imaginary",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "creates a complex number";
+  let description = [{
+    The `create_complex` operation creates a complex number from two
+    floating-point operands, the real and the imaginary part.
+
+    Example:
+
+    ```mlir
+    %a = create_complex %b, %c : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary);
+  let results = (outs Complex<AnyFloat>:$complex);
+
+  let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
+
+  // Traits fully verify this op and let the format infer its operand types.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // CondBranchOp
 //===----------------------------------------------------------------------===//
@@ -1497,6 +1531,36 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
   let hasFolder = 0;
 }
 
+//===----------------------------------------------------------------------===//
+// ImOp
+//===----------------------------------------------------------------------===//
+
+def ImOp : Std_Op<"im",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "imaginary",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "extracts the imaginary part of a complex number";
+  let description = [{
+    The `im` operation takes a single complex number as its operand and extracts
+    the imaginary part as a floating-point value.
+
+    Example:
+
+    ```mlir
+    %a = im %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$imaginary);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+
+  // Traits fully verify this op and let the format infer its result type.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
@@ -1877,6 +1941,36 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
+//===----------------------------------------------------------------------===//
+// ReOp
+//===----------------------------------------------------------------------===//
+
+def ReOp : Std_Op<"re",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "real",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "extracts the real part of a complex number";
+  let description = [{
+    The `re` operation takes a single complex number as its operand and extracts
+    the real part as a floating-point value.
+
+    Example:
+
+    ```mlir
+    %a = re %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$real);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+
+  // Traits fully verify this op and let the format infer its result type.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // RemFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index e08986ebe59e..d6c0cde2b86a 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -138,6 +138,7 @@ LLVMTypeConverter::LLVMTypeConverter(
         module->getDataLayout().getPointerSizeInBits();
 
   // Register conversions for the standard types.
+  addConversion([&](ComplexType type) { return convertComplexType(type); });
   addConversion([&](FloatType type) { return convertFloatType(type); });
   addConversion([&](FunctionType type) { return convertFunctionType(type); });
   addConversion([&](IndexType type) { return convertIndexType(type); });
@@ -191,6 +192,17 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
   }
 }
 
+// Lower `complex<T>` to an LLVM `{real, imag}` struct (null if `T` fails).
+static constexpr unsigned kRealPosInComplexNumberStruct = 0;
+static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
+Type LLVMTypeConverter::convertComplexType(ComplexType type) {
+  auto elementType =
+      convertType(type.getElementType()).dyn_cast_or_null<LLVM::LLVMType>();
+  if (!elementType)
+    return {};
+  return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType});
+}
+
 // Except for signatures, MLIR function types are converted into LLVM
 // pointer-to-function types.
 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
@@ -392,6 +404,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
 /*============================================================================*/
 /* StructBuilder implementation                                               */
 /*============================================================================*/
+
 StructBuilder::StructBuilder(Value v) : value(v) {
   assert(value != nullptr && "value cannot be null");
   structType = value.getType().dyn_cast<LLVM::LLVMType>();
@@ -410,6 +423,35 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
   value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
                                               builder.getI64ArrayAttr(pos));
 }
+
+/*============================================================================*/
+/* ComplexStructBuilder implementation                                        */
+/*============================================================================*/
+
+ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
+                                                 Location loc, Type type) {
+  Value val = builder.create<LLVM::UndefOp>(loc, type.cast<LLVM::LLVMType>());
+  return ComplexStructBuilder(val);
+}
+
+Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
+  return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
+}
+
+void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
+                                   Value real) {
+  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
+}
+
+Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
+  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
+}
+
+void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
+                                        Value imaginary) {
+  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
+}
+
 /*============================================================================*/
 /* MemRefDescriptor implementation                                            */
 /*============================================================================*/
@@ -1284,6 +1326,65 @@ using UnsignedShiftRightOpLowering =
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
 
+// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and
+// `ImOp`.
+
+struct CreateComplexOpLowering
+    : public ConvertOpToLLVMPattern<CreateComplexOp> {
+  using ConvertOpToLLVMPattern<CreateComplexOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<CreateComplexOp> transformed(operands);
+    // Bail out (instead of crashing in `undef`) if the type cannot convert.
+    auto structType = typeConverter.convertType(op->getResult(0).getType());
+    if (!structType)
+      return failure();
+    // Pack real and imaginary part in a complex number struct.
+    auto loc = op->getLoc();
+    auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
+    complexStruct.setReal(rewriter, loc, transformed.real());
+    complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
+    rewriter.replaceOp(op, {complexStruct});
+    return success();
+  }
+};
+
+struct ReOpLowering : public ConvertOpToLLVMPattern<ReOp> {
+  using ConvertOpToLLVMPattern<ReOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The single converted operand is the LLVM struct of the complex number;
+    // replace `std.re` with a read of its real-part entry.
+    OperandAdaptor<ReOp> adaptor(operands);
+    ComplexStructBuilder complexStruct(adaptor.complex());
+    Location loc = op->getLoc();
+    Value realPart = complexStruct.real(rewriter, loc);
+    rewriter.replaceOp(op, realPart);
+    return success();
+  }
+};
+
+struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
+  using ConvertOpToLLVMPattern<ImOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The single converted operand is the LLVM struct of the complex number;
+    // replace `std.im` with a read of its imaginary-part entry.
+    OperandAdaptor<ImOp> adaptor(operands);
+    ComplexStructBuilder complexStruct(adaptor.complex());
+    Location loc = op->getLoc();
+    Value imaginaryPart = complexStruct.imaginary(rewriter, loc);
+    rewriter.replaceOp(op, imaginaryPart);
+    return success();
+  }
+};
+
 // Check if the MemRefType `type` is supported by the lowering. We currently
 // only support memrefs with identity maps.
 static bool isSupportedMemRefType(MemRefType type) {
@@ -2896,6 +2997,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       CopySignOpLowering,
       CosOpLowering,
       ConstLLVMOpLowering,
+      CreateComplexOpLowering,
       DialectCastOpLowering,
       DivFOpLowering,
       ExpOpLowering,
@@ -2906,12 +3008,14 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       Log2OpLowering,
       FPExtLowering,
       FPTruncLowering,
+      ImOpLowering,
       IndexCastOpLowering,
       MulFOpLowering,
       MulIOpLowering,
       NegFOpLowering,
       OrOpLowering,
       PrefetchOpLowering,
+      ReOpLowering,
       RemFOpLowering,
       ReturnOpLowering,
       RsqrtOpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index fd4ea071b8a6..1b17e46ccc1b 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -65,6 +65,24 @@ func @simple_loop() {
   return
 }
 
+// CHECK-LABEL: llvm.func @complex_numbers()
+// CHECK-NEXT:    %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : !llvm.float
+// CHECK-NEXT:    %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : !llvm.float
+// CHECK-NEXT:    %[[CPLX0:.*]] = llvm.mlir.undef : !llvm<"{ float, float }">
+// CHECK-NEXT:    %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm<"{ float, float }">
+// CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm<"{ float, float }">
+// CHECK-NEXT:    %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2]][0] : !llvm<"{ float, float }">
+// CHECK-NEXT:    %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2]][1] : !llvm<"{ float, float }">
+// CHECK-NEXT:    llvm.return
+func @complex_numbers() {
+  %real0 = constant 1.2 : f32
+  %imag0 = constant 3.4 : f32
+  %cplx2 = create_complex %real0, %imag0 : complex<f32>
+  %real1 = re %cplx2 : complex<f32>
+  %imag1 = im %cplx2 : complex<f32>
+  return
+}
+
 // CHECK-LABEL: func @simple_caller() {
 // CHECK-NEXT:  llvm.call @simple_loop() : () -> ()
 // CHECK-NEXT:  llvm.return
@@ -367,6 +385,12 @@ func @more_imperfectly_nested_loops() {
 func @get_i64() -> (i64)
 // CHECK-LABEL: func @get_f32() -> !llvm.float
 func @get_f32() -> (f32)
+// CHECK-LABEL: func @get_c16() -> !llvm<"{ half, half }">
+func @get_c16() -> (complex<f16>)
+// CHECK-LABEL: func @get_c32() -> !llvm<"{ float, float }">
+func @get_c32() -> (complex<f32>)
+// CHECK-LABEL: func @get_c64() -> !llvm<"{ double, double }">
+func @get_c64() -> (complex<f64>)
 // CHECK-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
 // CHECK32-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i32, [4 x i32], [4 x i32] }">
 func @get_memref() -> (memref<42x?x10x?xf32>)

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 69ba75ab481f..21718864d94b 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -86,6 +86,24 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %13 = muli %4, %4 : i32
   %i6 = muli %i2, %i2 : i32
 
+  // CHECK: %[[C0:.*]] = create_complex %[[F2:.*]], %[[F2]] : complex<f32>
+  %c0 = "std.create_complex"(%f2, %f2) : (f32, f32) -> complex<f32>
+
+  // CHECK: %[[C1:.*]] = create_complex %[[F2]], %[[F2]] : complex<f32>
+  %c1 = create_complex %f2, %f2 : complex<f32>
+
+  // CHECK: %[[REAL0:.*]] = re %[[C0]] : complex<f32>
+  %real0 = "std.re"(%c0) : (complex<f32>) -> f32
+
+  // CHECK: %[[REAL1:.*]] = re %[[C0]] : complex<f32>
+  %real1 = re %c0 : complex<f32>
+
+  // CHECK: %[[IMAG0:.*]] = im %[[C0]] : complex<f32>
+  %imag0 = "std.im"(%c0) : (complex<f32>) -> f32
+
+  // CHECK: %[[IMAG1:.*]] = im %[[C0]] : complex<f32>
+  %imag1 = im %c0 : complex<f32>
+
   // CHECK: %c42_i32 = constant 42 : i32
   %x = "std.constant"(){value = 42 : i32} : () -> i32
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 80fdf3342995..2145c1bbc172 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1220,3 +1220,47 @@ func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error at -1 {{requires an ancestor op with AutomaticAllocationScope trait}}
   return
 }) : () -> ()
+
+// -----
+
+func @complex_number_from_non_float_operands(%real: i32, %imag: i32) {
+  // expected-error@+1 {{'complex' must be complex type with floating-point elements, but got 'complex<i32>'}}
+  std.create_complex %real, %imag : complex<i32>
+  return
+}
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func @complex_number_from_different_float_types(%real: f32, %imag: f64) {
+  // expected-error@+1 {{expects different type than prior uses: 'f32' vs 'f64'}}
+  std.create_complex %real, %imag : complex<f32>
+  return
+}
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func @complex_number_from_incompatible_float_type(%real: f32, %imag: f32) {
+  // expected-error@+1 {{expects different type than prior uses: 'f64' vs 'f32'}}
+  std.create_complex %real, %imag : complex<f64>
+  return
+}
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func @real_part_from_incompatible_complex_type(%cplx: complex<f32>) {
+  // expected-error@+1 {{expects different type than prior uses: 'complex<f64>' vs 'complex<f32>'}}
+  std.re %cplx : complex<f64>
+  return
+}
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func @imaginary_part_from_incompatible_complex_type(%cplx: complex<f64>) {
+  // expected-error@+1 {{expects different type than prior uses: 'complex<f32>' vs 'complex<f64>'}}
+  std.im %cplx : complex<f32>
+  return
+}


        


More information about the Mlir-commits mailing list