[Mlir-commits] [mlir] 5d5f61f - [MLIR] Add complex addition and subtraction to the standard dialect

Frederik Gossen llvmlistbot at llvm.org
Fri May 8 02:54:52 PDT 2020


Author: Frederik Gossen
Date: 2020-05-08T09:54:18Z
New Revision: 5d5f61fc894bd4a2e100548ec65d56684883baf8

URL: https://github.com/llvm/llvm-project/commit/5d5f61fc894bd4a2e100548ec65d56684883baf8
DIFF: https://github.com/llvm/llvm-project/commit/5d5f61fc894bd4a2e100548ec65d56684883baf8.diff

LOG: [MLIR] Add complex addition and subtraction to the standard dialect

Complex addition and subtraction are the first two binary operations on complex
numbers.
The remaining operations will follow the same pattern.

Differential Revision: https://reviews.llvm.org/D79479

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 4e04df9b9215..354ff6a89e7c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -109,6 +109,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 // integer tensor.  The custom assembly form of the operation is as follows
 //
 //     <op>i %0, %1 : i32
+//
 class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     ArithmeticOp<mnemonic, traits>,
     Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -121,10 +122,23 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
 // is as follows
 //
 //     <op>f %0, %1 : f32
+//
 class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     ArithmeticOp<mnemonic, traits>,
     Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
 
+// Base class for standard arithmetic operations on complex numbers with a
+// floating-point element type.
+// These operations take two operands and return one result, all of which must
+// be complex numbers of the same type.
+// Both operands are constrained to `Complex<AnyFloat>`, i.e. complex numbers
+// whose element type is any floating-point type.
+// The assembly format is as follows:
+//
+//     <op>cf %0, %1 : complex<f32>
+//
+class ComplexFloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    ArithmeticOp<mnemonic, traits>,
+    Arguments<(ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs)>;
+
 // Base class for memref allocating ops: alloca and alloc.
 //
 //   %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -201,6 +215,26 @@ def AbsFOp : FloatUnaryOp<"absf"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AddCFOp
+//===----------------------------------------------------------------------===//
+
+// Complex addition. The StandardToLLVM lowering adds the real parts and the
+// imaginary parts of the two operands independently.
+def AddCFOp : ComplexFloatArithmeticOp<"addcf"> {
+  let summary = "complex number addition";
+  let description = [{
+    The `addcf` operation takes two complex number operands and returns their
+    sum, a single complex number.
+    All operands and result must be of the same type, a complex number with a
+    floating-point element type.
+
+    Example:
+
+    ```mlir
+    %a = addcf %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -2407,6 +2441,26 @@ def StoreOp : Std_Op<"store",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SubCFOp
+//===----------------------------------------------------------------------===//
+
+// Complex subtraction (lhs - rhs). The StandardToLLVM lowering subtracts the
+// real parts and the imaginary parts of the two operands independently.
+def SubCFOp : ComplexFloatArithmeticOp<"subcf"> {
+  let summary = "complex number subtraction";
+  let description = [{
+    The `subcf` operation takes two complex number operands and returns their
+    difference, a single complex number.
+    All operands and result must be of the same type, a complex number with a
+    floating-point element type.
+
+    Example:
+
+    ```mlir
+    %a = subcf %b, %c : complex<f32>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 8ced1b0667a2..ca7b5c2607a6 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -719,7 +719,6 @@ def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
     SignlessIntegerLike.predicate, FloatLike.predicate]>,
     "signless-integer-like or floating-point-like">;
 
-
 //===----------------------------------------------------------------------===//
 // Attribute definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 1a01daa1188e..e9fd083b198e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -443,12 +443,12 @@ Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
   return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
 }
 
-void ComplexStructBuilder ::setImaginary(OpBuilder &builder, Location loc,
-                                         Value imaginary) {
+void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
+                                        Value imaginary) {
   setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
 }
 
-Value ComplexStructBuilder ::imaginary(OpBuilder &builder, Location loc) {
+Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
   return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
 }
 
@@ -1326,8 +1326,7 @@ using UnsignedShiftRightOpLowering =
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
 
-// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and
-// `ImOp`.
+// Lowerings for operations on complex numbers.
 
 struct CreateComplexOpLowering
     : public ConvertOpToLLVMPattern<CreateComplexOp> {
@@ -1385,6 +1384,82 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
   }
 };
 
+/// Real and imaginary parts of both operands of a binary operation on complex
+/// numbers, as extracted by `unpackBinaryComplexOperands`.
+struct BinaryComplexOperands {
+  Value lhsReal;
+  Value lhsImag;
+  Value rhsReal;
+  Value rhsImag;
+};
+
+/// Extracts the real and imaginary parts of the `lhs` and `rhs` operands of
+/// `op`. `operands` are the operands already converted by the type converter;
+/// each is wrapped in a `ComplexStructBuilder` to read its parts. The
+/// extraction ops are created with `rewriter` at the location of `op`, in the
+/// order lhs.real, lhs.imag, rhs.real, rhs.imag.
+template <typename OpTy>
+BinaryComplexOperands
+unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
+                            ConversionPatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  OperandAdaptor<OpTy> transformed(operands);
+
+  // Extract real and imaginary values from operands.
+  BinaryComplexOperands unpacked;
+  ComplexStructBuilder lhs(transformed.lhs());
+  unpacked.lhsReal = lhs.real(rewriter, loc);
+  unpacked.lhsImag = lhs.imaginary(rewriter, loc);
+  ComplexStructBuilder rhs(transformed.rhs());
+  unpacked.rhsReal = rhs.real(rewriter, loc);
+  unpacked.rhsImag = rhs.imaginary(rewriter, loc);
+
+  return unpacked;
+}
+
+/// Lowers a binary complex-number operation `OpTy` whose result is computed
+/// component-wise: the LLVM floating-point operation `LLVMOpTy` is applied once
+/// to the two real parts and once to the two imaginary parts, and both results
+/// are packed into a fresh complex-number struct. This covers addition and
+/// subtraction; operations that mix components (e.g. multiplication) need
+/// their own pattern.
+template <typename OpTy, typename LLVMOpTy>
+struct ComponentwiseComplexOpLowering : public ConvertOpToLLVMPattern<OpTy> {
+  using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto op = cast<OpTy>(operation);
+    auto loc = op.getLoc();
+    BinaryComplexOperands arg =
+        unpackBinaryComplexOperands<OpTy>(op, operands, rewriter);
+
+    // Initialize complex number struct for result.
+    auto structType = this->typeConverter.convertType(op.getType());
+    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
+
+    // Combine the real parts and the imaginary parts independently.
+    Value real = rewriter.create<LLVMOpTy>(loc, arg.lhsReal, arg.rhsReal);
+    Value imag = rewriter.create<LLVMOpTy>(loc, arg.lhsImag, arg.rhsImag);
+    result.setReal(rewriter, loc, real);
+    result.setImaginary(rewriter, loc, imag);
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+// Complex addition and subtraction, following the same alias-of-template
+// convention as the other one-to-one lowerings in this file.
+using AddCFOpLowering = ComponentwiseComplexOpLowering<AddCFOp, LLVM::FAddOp>;
+using SubCFOpLowering = ComponentwiseComplexOpLowering<SubCFOp, LLVM::FSubOp>;
+
 // Check if the MemRefType `type` is supported by the lowering. We currently
 // only support memrefs with identity maps.
 static bool isSupportedMemRefType(MemRefType type) {
@@ -2874,6 +2949,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
   // clang-format off
   patterns.insert<
       AbsFOpLowering,
+      AddCFOpLowering,
       AddFOpLowering,
       AddIOpLowering,
       AllocaOpLowering,
@@ -2921,6 +2997,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       SplatOpLowering,
       SplatNdOpLowering,
       SqrtOpLowering,
+      SubCFOpLowering,
       SubFOpLowering,
       SubIOpLowering,
       TruncateIOpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index b7cb13e51ca2..cd16539edf02 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -83,6 +83,48 @@ func @complex_numbers() {
   return
 }
 
+// Complex addition lowers to one `llvm.fadd` of the real parts and one of the
+// imaginary parts, whose results are inserted into a fresh undef struct.
+// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : !llvm.double
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
+func @complex_addition() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = create_complex %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = create_complex %b_re, %b_im : complex<f64>
+  %c = addcf %a, %b : complex<f64>
+  return
+}
+
+// Complex subtraction lowers to one `llvm.fsub` of the real parts and one of
+// the imaginary parts, whose results are inserted into a fresh undef struct.
+// CHECK-LABEL: llvm.func @complex_subtraction()
+// CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
+// CHECK:         %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
+// CHECK-DAG:     %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : !llvm.double
+// CHECK-DAG:     %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
+// CHECK:         %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
+// CHECK:         %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
+func @complex_subtraction() {
+  %a_re = constant 1.2 : f64
+  %a_im = constant 3.4 : f64
+  %a = create_complex %a_re, %a_im : complex<f64>
+  %b_re = constant 5.6 : f64
+  %b_im = constant 7.8 : f64
+  %b = create_complex %b_re, %b_im : complex<f64>
+  %c = subcf %a, %b : complex<f64>
+  return
+}
+
 // CHECK-LABEL: func @simple_caller() {
 // CHECK-NEXT:  llvm.call @simple_loop() : () -> ()
 // CHECK-NEXT:  llvm.return


        


More information about the Mlir-commits mailing list