[llvm-branch-commits] [mlir] 11f4c58 - [mlir] Add `complex.abs`, `complex.div` and `complex.mul` to ComplexOps.
Alexander Belyaev via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 03:15:09 PST 2021
Author: Alexander Belyaev
Date: 2021-01-19T12:09:59+01:00
New Revision: 11f4c58c153cedf6fe04cab49d4a4f02d00e3383
URL: https://github.com/llvm/llvm-project/commit/11f4c58c153cedf6fe04cab49d4a4f02d00e3383
DIFF: https://github.com/llvm/llvm-project/commit/11f4c58c153cedf6fe04cab49d4a4f02d00e3383.diff
LOG: [mlir] Add `complex.abs`, `complex.div` and `complex.mul` to ComplexOps.
Differential Revision: https://reviews.llvm.org/D94911
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Complex/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index a4329df7c1aa..960f5f64eec3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -43,11 +43,37 @@ def AddOp : ComplexArithmeticOp<"add"> {
Example:
```mlir
- %a = add %b, %c : complex<f32>
+ %a = complex.add %b, %c : complex<f32>
```
}];
}
+//===----------------------------------------------------------------------===//
+// AbsOp
+//===----------------------------------------------------------------------===//
+
+// `complex.abs`: maps a complex<T> value to its magnitude of type T.
+// The TypesMatchWith trait below ties the result type to the operand's
+// element type (e.g. complex<f32> -> f32). That is why the assembly format
+// only needs to print the operand type.
+def AbsOp : Complex_Op<"abs",
+ [NoSideEffect,
+ TypesMatchWith<"complex element type matches result type",
+ "complex", "result",
+ "$_self.cast<ComplexType>().getElementType()">]> {
+ let summary = "computes absolute value of a complex number";
+ let description = [{
+ The `abs` op takes a single complex number and computes its absolute value.
+
+ Example:
+
+ ```mlir
+ %a = complex.abs %b : complex<f32>
+ ```
+ }];
+
+ let arguments = (ins Complex<AnyFloat>:$complex);
+ let results = (outs AnyFloat:$result);
+
+ // Result type is not printed; it is derived from type($complex) via the
+ // TypesMatchWith trait above.
+ let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
@@ -70,7 +96,7 @@ def CreateOp : Complex_Op<"create",
Example:
```mlir
- %a = create_complex %b, %c : complex<f32>
+ %a = complex.create %b, %c : complex<f32>
```
}];
@@ -80,6 +106,22 @@ def CreateOp : Complex_Op<"create",
let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
}
+//===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+// `complex.div`: elementwise-typed binary division of two complex values.
+def DivOp : ComplexArithmeticOp<"div"> {
+  let summary = "complex division";
+  let description = [{
+    The `div` operation takes two complex numbers and returns the result of
+    their division.
+
+    Example:
+
+    ```mlir
+    %a = complex.div %b, %c : complex<f32>
+    ```
+  }];
+}
+
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
@@ -96,7 +138,7 @@ def ImOp : Complex_Op<"im",
Example:
```mlir
- %a = im %b : complex<f32>
+ %a = complex.im %b : complex<f32>
```
}];
@@ -106,6 +148,21 @@ def ImOp : Complex_Op<"im",
let assemblyFormat = "$complex attr-dict `:` type($complex)";
}
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+// `complex.mul`: binary multiplication of two complex values.
+def MulOp : ComplexArithmeticOp<"mul"> {
+  let summary = "complex multiplication";
+  let description = [{
+    The `mul` operation takes two complex numbers and returns their product.
+
+    Example:
+
+    ```mlir
+    %a = complex.mul %b, %c : complex<f32>
+    ```
+  }];
+}
+
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
@@ -122,7 +179,7 @@ def ReOp : Complex_Op<"re",
Example:
```mlir
- %a = re %b : complex<f32>
+ %a = complex.re %b : complex<f32>
```
}];
@@ -145,7 +202,7 @@ def SubOp : ComplexArithmeticOp<"sub"> {
Example:
```mlir
- %a = sub %b, %c : complex<f32>
+ %a = complex.sub %b, %c : complex<f32>
```
}];
}
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 1096b08da1f1..270b9489625c 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -17,6 +17,29 @@ using namespace mlir::LLVM;
namespace {
+/// Lowers `complex.abs` to LLVM dialect ops computing
+///   |re + im*i| = sqrt(re * re + im * im)
+/// and replaces the op with the resulting `LLVM::SqrtOp`.
+///
+/// NOTE(review): the squared norm is formed directly, so it can overflow or
+/// underflow when the components are very large or very small even though
+/// the true magnitude is representable. A hypot-style scaled computation
+/// would be more robust, but it would change the IR that
+/// convert-to-llvm.mlir currently checks.
+struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
+  using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::AbsOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+
+    // Pull the real and imaginary fields out of the struct-typed operand.
+    ComplexStructBuilder complexStruct(transformed.complex());
+    Value real = complexStruct.real(rewriter, loc);
+    Value imag = complexStruct.imaginary(rewriter, loc);
+
+    // Squared norm re^2 + im^2. Every emitted op carries a fast-math flags
+    // attribute built from an empty `{}` flag set.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value sqNorm = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
+
+    rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
+    return success();
+  }
+};
+
struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
@@ -123,6 +146,88 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
}
};
+/// Lowers `complex.div` to LLVM dialect arithmetic. With lhs = a + bi and
+/// rhs = c + di it emits the textbook formula
+///   (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2).
+///
+/// NOTE(review): no scaling is applied, so c^2 + d^2 can overflow or
+/// underflow for very large or very small divisors. A zero divisor is not
+/// special-cased and yields whatever IEEE `fdiv` produces (inf/nan). A
+/// Smith-style scaled algorithm would be more robust, but it would change
+/// the IR that convert-to-llvm.mlir currently checks.
+struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
+  using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit IR to divide complex numbers. Every emitted op carries a
+    // fast-math flags attribute built from an empty `{}` flag set.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value rhsRe = arg.rhs.real();
+    Value rhsIm = arg.rhs.imag();
+    Value lhsRe = arg.lhs.real();
+    Value lhsIm = arg.lhs.imag();
+
+    // Denominator: |rhs|^2 = c^2 + d^2.
+    Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
+
+    // Real part of lhs * conj(rhs): ac + bd.
+    Value resultReal = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
+
+    // Imaginary part of lhs * conj(rhs): bc - ad.
+    Value resultImag = rewriter.create<LLVM::FSubOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+
+    // Scale both parts by 1 / |rhs|^2.
+    result.setReal(
+        rewriter, loc,
+        rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
+    result.setImaginary(
+        rewriter, loc,
+        rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+/// Lowers `complex.mul` to LLVM dialect arithmetic. With lhs = a + bi and
+/// rhs = c + di it emits
+///   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.
+struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
+  using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit IR to multiply complex numbers. Every emitted op carries a
+    // fast-math flags attribute built from an empty `{}` flag set.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value rhsRe = arg.rhs.real();
+    Value rhsIm = arg.rhs.imag();
+    Value lhsRe = arg.lhs.real();
+    Value lhsIm = arg.lhs.imag();
+
+    // Real part: ac - bd. Operands are multiplied rhs-first here; the
+    // FileCheck test in convert-to-llvm.mlir pins this operand order.
+    Value real = rewriter.create<LLVM::FSubOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
+
+    // Imaginary part: bc + ad.
+    Value imag = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
@@ -156,9 +261,12 @@ void mlir::populateComplexToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<
+ AbsOpConversion,
AddOpConversion,
CreateOpConversion,
+ DivOpConversion,
ImOpConversion,
+ MulOpConversion,
ReOpConversion,
SubOpConversion
>(converter);
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index fde21df8abf3..ffc7bbfb3ec5 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -18,6 +18,8 @@ func @complex_numbers() {
return
}
+// -----
+
// CHECK-LABEL: llvm.func @complex_addition()
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
@@ -39,6 +41,8 @@ func @complex_addition() {
return
}
+// -----
+
// CHECK-LABEL: llvm.func @complex_substraction()
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
@@ -59,3 +63,79 @@ func @complex_substraction() {
%c = complex.sub %a, %b : complex<f64>
return
}
+
+// -----
+
+// complex.div lowering: for lhs = a + bi, rhs = c + di the expected IR is
+//   ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
+// C_TY captures the converted struct type, anchored on its closing `>`.
+// CHECK-LABEL: llvm.func @complex_div
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %div = complex.div %lhs, %rhs : complex<f32>
+ return %div : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// Denominator c^2 + d^2.
+// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] : f32
+
+// Real numerator ac + bd.
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] : f32
+// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32
+
+// Imaginary numerator bc - ad.
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32
+// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32
+
+// Both numerators are divided by the denominator and packed into the result.
+// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// -----
+
+// complex.mul lowering: for lhs = a + bi, rhs = c + di the expected IR is
+//   (ac - bd) + (bc + ad)i
+// The real-part products are checked with rhs as the first fmul operand.
+// CHECK-LABEL: llvm.func @complex_mul
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %mul = complex.mul %lhs, %rhs : complex<f32>
+ return %mul : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// Real part ac - bd.
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] : f32
+// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] : f32
+
+// Imaginary part bc + ad.
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] : f32
+// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] : f32
+
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// -----
+
+// complex.abs lowering: |a + bi| = sqrt(a*a + b*b), with the square root
+// emitted as the llvm.intr.sqrt intrinsic.
+// CHECK-LABEL: llvm.func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+func @complex_abs(%arg: complex<f32>) -> f32 {
+ %abs = complex.abs %arg: complex<f32>
+ return %abs : f32
+}
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// Squared norm a*a + b*b.
+// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32
+// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: llvm.return %[[NORM]] : f32
+
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 152e8704c5ff..9685886ee525 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -3,21 +3,30 @@
// CHECK-LABEL: func @ops(
-// CHECK-SAME: [[F:%.*]]: f32) {
+// CHECK-SAME: %[[F:.*]]: f32) {
func @ops(%f: f32) {
- // CHECK: [[C:%.*]] = complex.create [[F]], [[F]] : complex<f32>
+ // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
%complex = complex.create %f, %f : complex<f32>
- // CHECK: complex.re [[C]] : complex<f32>
+ // CHECK: complex.re %[[C]] : complex<f32>
%real = complex.re %complex : complex<f32>
- // CHECK: complex.im [[C]] : complex<f32>
+ // CHECK: complex.im %[[C]] : complex<f32>
%imag = complex.im %complex : complex<f32>
- // CHECK: complex.add [[C]], [[C]] : complex<f32>
+ // CHECK: complex.abs %[[C]] : complex<f32>
+ %abs = complex.abs %complex : complex<f32>
+
+ // CHECK: complex.add %[[C]], %[[C]] : complex<f32>
%sum = complex.add %complex, %complex : complex<f32>
- // CHECK: complex.sub [[C]], [[C]] : complex<f32>
+ // CHECK: complex.div %[[C]], %[[C]] : complex<f32>
+ %div = complex.div %complex, %complex : complex<f32>
+
+ // CHECK: complex.mul %[[C]], %[[C]] : complex<f32>
+ %prod = complex.mul %complex, %complex : complex<f32>
+
+ // CHECK: complex.sub %[[C]], %[[C]] : complex<f32>
%diff = complex.sub %complex, %complex : complex<f32>
return
}
More information about the llvm-branch-commits mailing list