[Mlir-commits] [mlir] d6be277 - [mlir] turn complex-to-llvm into a partial conversion
Alex Zinenko
llvmlistbot at llvm.org
Thu Jan 28 10:14:10 PST 2021
Author: Alex Zinenko
Date: 2021-01-28T19:14:01+01:00
New Revision: d6be27734764ec9718b4206c4b77dfac6c67b1d8
URL: https://github.com/llvm/llvm-project/commit/d6be27734764ec9718b4206c4b77dfac6c67b1d8
DIFF: https://github.com/llvm/llvm-project/commit/d6be27734764ec9718b4206c4b77dfac6c67b1d8.diff
LOG: [mlir] turn complex-to-llvm into a partial conversion
It is no longer necessary to also convert other "standard" ops along with the
complex dialect: the element types are now built-in integers or floating point
types, and the top-level cast between complex and struct is automatically
inserted and removed in progressive lowering.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D95625
Added:
mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
Modified:
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/LLVMIR/dialect-cast.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 270b9489625c..4421a9fb4808 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -286,12 +286,13 @@ void ConvertComplexToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
- populateStdToLLVMConversionPatterns(converter, patterns);
populateComplexToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
- target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
- if (failed(applyFullConversion(module, target, std::move(patterns))))
+ target.addLegalOp<ModuleOp, FuncOp>();
+ target.addLegalOp<LLVM::DialectCastOp>();
+ target.addIllegalDialect<complex::ComplexDialect>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a31402e6b284..adf7ff7b74f5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1375,6 +1375,17 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
return success();
}
+ // Complex types are compatible with the two-element structs.
+ if (auto complexType = type.dyn_cast<ComplexType>()) {
+ auto structType = llvmType.dyn_cast<LLVMStructType>();
+ if (!structType || structType.getBody().size() != 2 ||
+ structType.getBody()[0] != structType.getBody()[1] ||
+ structType.getBody()[0] != complexType.getElementType())
+ return op->emitOpError("expected 'complex' to map to two-element struct "
+ "with identical element types");
+ return success();
+ }
+
// Everything else is not supported.
return op->emitError("unsupported cast");
}
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index ffc7bbfb3ec5..6ad3595915a0 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -1,14 +1,13 @@
-// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-complex-to-llvm | FileCheck %s
-// CHECK-LABEL: llvm.func @complex_numbers()
-// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32
-// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32
+// CHECK-LABEL: func @complex_numbers
+// CHECK-NEXT: %[[REAL0:.*]] = constant 1.200000e+00 : f32
+// CHECK-NEXT: %[[IMAG0:.*]] = constant 3.400000e+00 : f32
// CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)>
-// CHECK-NEXT: llvm.return
func @complex_numbers() {
%real0 = constant 1.2 : f32
%imag0 = constant 3.4 : f32
@@ -18,9 +17,7 @@ func @complex_numbers() {
return
}
-// -----
-
-// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-LABEL: func @complex_addition
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
@@ -41,9 +38,7 @@ func @complex_addition() {
return
}
-// -----
-
-// CHECK-LABEL: llvm.func @complex_substraction()
+// CHECK-LABEL: func @complex_substraction
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
@@ -64,18 +59,19 @@ func @complex_substraction() {
return
}
-// -----
-
-// CHECK-LABEL: llvm.func @complex_div
-// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+// CHECK-LABEL: func @complex_div
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
%div = complex.div %lhs, %rhs : complex<f32>
return %div : complex<f32>
}
-// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
-// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
-// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
-// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[CASTED_LHS:.*]] = llvm.mlir.cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[CASTED_RHS:.*]] = llvm.mlir.cast %[[RHS]] : complex<f32> to ![[C_TY]]
+
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
@@ -95,20 +91,23 @@ func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32
// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
-// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
-
-// -----
+//
+// CHECK: %[[CASTED_RESULT:.*]] = llvm.mlir.cast %[[RESULT_2]] : ![[C_TY]] to complex<f32>
+// CHECK: return %[[CASTED_RESULT]] : complex<f32>
-// CHECK-LABEL: llvm.func @complex_mul
-// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+// CHECK-LABEL: func @complex_mul
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
%mul = complex.mul %lhs, %rhs : complex<f32>
return %mul : complex<f32>
}
-// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
-// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
-// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
-// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[CASTED_LHS:.*]] = llvm.mlir.cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[CASTED_RHS:.*]] = llvm.mlir.cast %[[RHS]] : complex<f32> to ![[C_TY]]
+
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] : f32
@@ -121,21 +120,22 @@ func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
-// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
-// -----
+// CHECK: %[[CASTED_RESULT:.*]] = llvm.mlir.cast %[[RESULT_2]] : ![[C_TY]] to complex<f32>
+// CHECK: return %[[CASTED_RESULT]] : complex<f32>
-// CHECK-LABEL: llvm.func @complex_abs
-// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+// CHECK-LABEL: func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
-// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
-// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK: %[[CASTED_ARG:.*]] = llvm.mlir.cast %[[ARG]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[CASTED_ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[CASTED_ARG]][1] : ![[C_TY]]
// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
-// CHECK: llvm.return %[[NORM]] : f32
+// CHECK: return %[[NORM]] : f32
diff --git a/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
new file mode 100644
index 000000000000..6844f70c862b
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -convert-complex-to-llvm -convert-std-to-llvm | FileCheck %s
+
+// CHECK-LABEL: llvm.func @complex_div
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %div = complex.div %lhs, %rhs : complex<f32>
+ return %div : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] : f32
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] : f32
+// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32
+// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32
+
+// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// CHECK-LABEL: llvm.func @complex_mul
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %mul = complex.mul %lhs, %rhs : complex<f32>
+ return %mul : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] : f32
+// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32
+// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32
+
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// CHECK-LABEL: llvm.func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+func @complex_abs(%arg: complex<f32>) -> f32 {
+ %abs = complex.abs %arg: complex<f32>
+ return %abs : f32
+}
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
+// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: llvm.return %[[NORM]] : f32
+
diff --git a/mlir/test/Dialect/LLVMIR/dialect-cast.mlir b/mlir/test/Dialect/LLVMIR/dialect-cast.mlir
index 90eaaa24544f..b72141dc4142 100644
--- a/mlir/test/Dialect/LLVMIR/dialect-cast.mlir
+++ b/mlir/test/Dialect/LLVMIR/dialect-cast.mlir
@@ -222,3 +222,31 @@ func @mlir_dialect_cast_unranked_rank(%0: memref<*xf32>) {
// expected-error @+1 {{expected second element of a memref descriptor to be an !llvm.ptr<i8>}}
llvm.mlir.cast %0 : memref<*xf32> to !llvm.struct<(i64, f32)>
}
+
+// -----
+
+func @mlir_dialect_cast_complex_non_struct(%0: complex<f32>) {
+ // expected-error @+1 {{expected 'complex' to map to two-element struct with identical element types}}
+ llvm.mlir.cast %0 : complex<f32> to f32
+}
+
+// -----
+
+func @mlir_dialect_cast_complex_bad_size(%0: complex<f32>) {
+ // expected-error @+1 {{expected 'complex' to map to two-element struct with identical element types}}
+ llvm.mlir.cast %0 : complex<f32> to !llvm.struct<(f32, f32, f32)>
+}
+
+// -----
+
+func @mlir_dialect_cast_complex_mismatching_type_struct(%0: complex<f32>) {
+ // expected-error @+1 {{expected 'complex' to map to two-element struct with identical element types}}
+ llvm.mlir.cast %0 : complex<f32> to !llvm.struct<(f32, f64)>
+}
+
+// -----
+
+func @mlir_dialect_cast_complex_mismatching_element(%0: complex<f32>) {
+ // expected-error @+1 {{expected 'complex' to map to two-element struct with identical element types}}
+ llvm.mlir.cast %0 : complex<f32> to !llvm.struct<(f64, f64)>
+}
More information about the Mlir-commits
mailing list