[llvm-branch-commits] [mlir] 11f4c58 - [mlir] Add `complex.abs`, `complex.div` and `complex.mul` to ComplexOps.

Alexander Belyaev via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 19 03:15:09 PST 2021


Author: Alexander Belyaev
Date: 2021-01-19T12:09:59+01:00
New Revision: 11f4c58c153cedf6fe04cab49d4a4f02d00e3383

URL: https://github.com/llvm/llvm-project/commit/11f4c58c153cedf6fe04cab49d4a4f02d00e3383
DIFF: https://github.com/llvm/llvm-project/commit/11f4c58c153cedf6fe04cab49d4a4f02d00e3383.diff

LOG: [mlir] Add `complex.abs`, `complex.div` and `complex.mul` to ComplexOps.

Differential Revision: https://reviews.llvm.org/D94911

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
    mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/Complex/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index a4329df7c1aa..960f5f64eec3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -43,11 +43,37 @@ def AddOp : ComplexArithmeticOp<"add"> {
     Example:
 
     ```mlir
-    %a = add %b, %c : complex<f32>
+    %a = complex.add %b, %c : complex<f32>
     ```
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AbsOp
+//===----------------------------------------------------------------------===//
+
+def AbsOp : Complex_Op<"abs",
+    [NoSideEffect,
+     TypesMatchWith<"complex element type matches result type",
+                    "complex", "result",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "computes absolute value of a complex number";
+  let description = [{
+    The `abs` op takes a single complex number and computes its absolute value.
+
+    Example:
+
+    ```mlir
+    %a = complex.abs %b : complex<f32>
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$complex);
+  let results = (outs AnyFloat:$result);
+
+  let assemblyFormat = "$complex attr-dict `:` type($complex)";
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
@@ -70,7 +96,7 @@ def CreateOp : Complex_Op<"create",
     Example:
 
     ```mlir
-    %a = create_complex %b, %c : complex<f32>
+    %a = complex.create %b, %c : complex<f32>
     ```
   }];
 
@@ -80,6 +106,22 @@ def CreateOp : Complex_Op<"create",
   let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
 }
 
+//===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+def DivOp : ComplexArithmeticOp<"div"> {
+  let summary = "complex division";
+  let description = [{
+    The `div` operation takes two complex numbers and returns the result of
+    their division:
+
+    ```mlir
+    %a = complex.div %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ImOp
 //===----------------------------------------------------------------------===//
@@ -96,7 +138,7 @@ def ImOp : Complex_Op<"im",
     Example:
 
     ```mlir
-    %a = im %b : complex<f32>
+    %a = complex.im %b : complex<f32>
     ```
   }];
 
@@ -106,6 +148,21 @@ def ImOp : Complex_Op<"im",
   let assemblyFormat = "$complex attr-dict `:` type($complex)";
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+def MulOp : ComplexArithmeticOp<"mul"> {
+  let summary = "complex multiplication";
+  let description = [{
+    The `mul` operation takes two complex numbers and returns their product:
+
+    ```mlir
+    %a = complex.mul %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
@@ -122,7 +179,7 @@ def ReOp : Complex_Op<"re",
     Example:
 
     ```mlir
-    %a = re %b : complex<f32>
+    %a = complex.re %b : complex<f32>
     ```
   }];
 
@@ -145,7 +202,7 @@ def SubOp : ComplexArithmeticOp<"sub"> {
     Example:
 
     ```mlir
-    %a = sub %b, %c : complex<f32>
+    %a = complex.sub %b, %c : complex<f32>
     ```
   }];
 }

diff  --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 1096b08da1f1..270b9489625c 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -17,6 +17,29 @@ using namespace mlir::LLVM;
 
 namespace {
 
+struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
+  using ConvertOpToLLVMPattern<complex::AbsOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::AbsOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+
+    ComplexStructBuilder complexStruct(transformed.complex());
+    Value real = complexStruct.real(rewriter, op.getLoc());
+    Value imag = complexStruct.imaginary(rewriter, op.getLoc());
+
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value sqNorm = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
+
+    rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
+    return success();
+  }
+};
+
 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
   using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
 
@@ -123,6 +146,88 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
   }
 };
 
+struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
+  using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<complex::DivOp>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit IR to divide complex numbers.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value rhsRe = arg.rhs.real();
+    Value rhsIm = arg.rhs.imag();
+    Value lhsRe = arg.lhs.real();
+    Value lhsIm = arg.lhs.imag();
+
+    Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
+
+    Value resultReal = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
+
+    Value resultImag = rewriter.create<LLVM::FSubOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+
+    result.setReal(
+        rewriter, loc,
+        rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
+    result.setImaginary(
+        rewriter, loc,
+        rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
+  using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<complex::MulOp>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = typeConverter->convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Emit IR to multiply complex numbers.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+    Value rhsRe = arg.rhs.real();
+    Value rhsIm = arg.rhs.imag();
+    Value lhsRe = arg.lhs.real();
+    Value lhsIm = arg.lhs.imag();
+
+    Value real = rewriter.create<LLVM::FSubOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
+
+    Value imag = rewriter.create<LLVM::FAddOp>(
+        loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
+        rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
   using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
 
@@ -156,9 +261,12 @@ void mlir::populateComplexToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // clang-format off
   patterns.insert<
+      AbsOpConversion,
       AddOpConversion,
       CreateOpConversion,
+      DivOpConversion,
       ImOpConversion,
+      MulOpConversion,
       ReOpConversion,
       SubOpConversion
     >(converter);

diff  --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index fde21df8abf3..ffc7bbfb3ec5 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -18,6 +18,8 @@ func @complex_numbers() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: llvm.func @complex_addition()
 // CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
 // CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
@@ -39,6 +41,8 @@ func @complex_addition() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: llvm.func @complex_substraction()
 // CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
 // CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
@@ -59,3 +63,79 @@ func @complex_substraction() {
   %c = complex.sub %a, %b : complex<f64>
   return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @complex_div
+// CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %div = complex.div %lhs, %rhs : complex<f32>
+  return %div : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]]  : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]]  : f32
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]]  : f32
+// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]]  : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]]  : f32
+// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]]  : f32
+
+// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]]  : f32
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]]  : f32
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// -----
+
+// CHECK-LABEL: llvm.func @complex_mul
+// CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs : complex<f32>
+  return %mul : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]]  : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]]  : f32
+// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]]  : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]]  : f32
+// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]]  : f32
+
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// -----
+
+// CHECK-LABEL: llvm.func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+func @complex_abs(%arg: complex<f32>) -> f32 {
+  %abs = complex.abs %arg: complex<f32>
+  return %abs : f32
+}
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]]  : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]]  : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]]  : f32
+// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: llvm.return %[[NORM]] : f32
+

diff  --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 152e8704c5ff..9685886ee525 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -3,21 +3,30 @@
 
 
 // CHECK-LABEL: func @ops(
-// CHECK-SAME:            [[F:%.*]]: f32) {
+// CHECK-SAME:            %[[F:.*]]: f32) {
 func @ops(%f: f32) {
-  // CHECK: [[C:%.*]] = complex.create [[F]], [[F]] : complex<f32>
+  // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
   %complex = complex.create %f, %f : complex<f32>
 
-  // CHECK: complex.re [[C]] : complex<f32>
+  // CHECK: complex.re %[[C]] : complex<f32>
   %real = complex.re %complex : complex<f32>
 
-  // CHECK: complex.im [[C]] : complex<f32>
+  // CHECK: complex.im %[[C]] : complex<f32>
   %imag = complex.im %complex : complex<f32>
 
-  // CHECK: complex.add [[C]], [[C]] : complex<f32>
+  // CHECK: complex.abs %[[C]] : complex<f32>
+  %abs = complex.abs %complex : complex<f32>
+
+  // CHECK: complex.add %[[C]], %[[C]] : complex<f32>
   %sum = complex.add %complex, %complex : complex<f32>
 
-  // CHECK: complex.sub [[C]], [[C]] : complex<f32>
+  // CHECK: complex.div %[[C]], %[[C]] : complex<f32>
+  %div = complex.div %complex, %complex : complex<f32>
+
+  // CHECK: complex.mul %[[C]], %[[C]] : complex<f32>
+  %prod = complex.mul %complex, %complex : complex<f32>
+
+  // CHECK: complex.sub %[[C]], %[[C]] : complex<f32>
   %diff = complex.sub %complex, %complex : complex<f32>
   return
 }


        


More information about the llvm-branch-commits mailing list