[Mlir-commits] [mlir] dee46d0 - [mlir] Fold complex.create(complex.re(op), complex.im(op))

Adrian Kuegel llvmlistbot at llvm.org
Wed May 26 05:03:12 PDT 2021


Author: Adrian Kuegel
Date: 2021-05-26T14:02:53+02:00
New Revision: dee46d08293f2ff693893d85c472029207ce750e

URL: https://github.com/llvm/llvm-project/commit/dee46d08293f2ff693893d85c472029207ce750e
DIFF: https://github.com/llvm/llvm-project/commit/dee46d08293f2ff693893d85c472029207ce750e.diff

LOG: [mlir] Fold complex.create(complex.re(op), complex.im(op))

Differential Revision: https://reviews.llvm.org/D103148

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/test/Dialect/Complex/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 930d768848011..8b3ea0c84fb04 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -88,7 +88,7 @@ def CreateOp : Complex_Op<"create",
 
   let summary = "complex number creation operation";
   let description = [{
-    The `complex.complex` operation creates a complex number from two
+    The `complex.create` operation creates a complex number from two
     floating-point operands, the real and the imaginary part.
 
     Example:
@@ -102,6 +102,7 @@ def CreateOp : Complex_Op<"create",
   let results = (outs Complex<AnyFloat>:$complex);
 
   let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 34e27d6882765..58412d37605b0 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -19,22 +19,35 @@ using namespace mlir::complex;
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
 
-OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary op takes two operands");
+  // Fold complex.create(complex.re(op), complex.im(op)).
+  if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
+    if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
+      if (reOp.getOperand() == imOp.getOperand()) {
+        return reOp.getOperand();
+      }
+    }
+  }
+  return {};
+}
+
+OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary op takes 1 operand");
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
-    return arrayAttr[0];
+    return arrayAttr[1];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
-    return createOp.getOperand(0);
+    return createOp.getOperand(1);
   return {};
 }
 
-OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary op takes 1 operand");
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
-    return arrayAttr[1];
+    return arrayAttr[0];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
-    return createOp.getOperand(1);
+    return createOp.getOperand(0);
   return {};
 }

diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 0f8e5fe80bfd5..07801b348a275 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -1,5 +1,28 @@
 // RUN: mlir-opt %s -canonicalize | FileCheck %s
 
+// CHECK-LABEL: func @create_of_real_and_imag
+// CHECK-SAME: (%[[CPLX:.*]]: complex<f32>)
+func @create_of_real_and_imag(%cplx: complex<f32>) -> complex<f32> {
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %real = complex.re %cplx : complex<f32>
+  %imag = complex.im %cplx : complex<f32>
+  %complex = complex.create %real, %imag : complex<f32>
+  return %complex : complex<f32>
+}
+
+// CHECK-LABEL: func @create_of_real_and_imag_different_operand
+// CHECK-SAME: (%[[CPLX:.*]]: complex<f32>, %[[CPLX2:.*]]: complex<f32>)
+func @create_of_real_and_imag_different_operand(
+    %cplx: complex<f32>, %cplx2 : complex<f32>) -> complex<f32> {
+  // CHECK-NEXT: %[[REAL:.*]] = complex.re %[[CPLX]] : complex<f32>
+  // CHECK-NEXT: %[[IMAG:.*]] = complex.im %[[CPLX2]] : complex<f32>
+  // CHECK-NEXT: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
+  %real = complex.re %cplx : complex<f32>
+  %imag = complex.im %cplx2 : complex<f32>
+  %complex = complex.create %real, %imag : complex<f32>
+  return %complex: complex<f32>
+}
+
 // CHECK-LABEL: func @real_of_const(
 func @real_of_const() -> f32 {
   // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32


        


More information about the Mlir-commits mailing list