[Mlir-commits] [mlir] 036a699 - [mlir][complex] Canonicalization for consecutive complex.add and sub

Alexander Belyaev llvmlistbot at llvm.org
Tue Jun 28 02:41:23 PDT 2022


Author: lewuathe
Date: 2022-06-28T11:41:16+02:00
New Revision: 036a6996750dccfd4493cb169f79e12b585b1e75

URL: https://github.com/llvm/llvm-project/commit/036a6996750dccfd4493cb169f79e12b585b1e75
DIFF: https://github.com/llvm/llvm-project/commit/036a6996750dccfd4493cb169f79e12b585b1e75.diff

LOG: [mlir][complex] Canonicalization for consecutive complex.add and sub

Add a basic folder for complex.add that cancels a preceding complex.sub: add(sub(a, b), b) -> a and add(b, sub(a, b)) -> a.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D128702

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/test/Dialect/Complex/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index f98037c9a515e..5ca24398843cb 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -73,6 +73,8 @@ def AddOp : ComplexArithmeticOp<"add"> {
     %a = complex.add %b, %c : complex<f32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 7545d3feec17b..0390a00cf6844 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
 
 using namespace mlir;
 using namespace mlir::complex;
@@ -103,6 +104,26 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// AddOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary op takes 2 operands");
+
+  // If `candidate` is produced by `complex.sub(a, b)` and `b` is exactly the
+  // SSA value `other`, adding `other` back undoes the subtraction: yield `a`.
+  // Otherwise yield a null Value.
+  // NOTE(review): under IEEE-754, (a - b) + b is not always bit-identical to
+  // `a` (rounding, overflow, inf/NaN operands); confirm this fold is
+  // acceptable for all complex element types.
+  auto matchUndoneSub = [](Value candidate, Value other) -> Value {
+    auto sub = candidate.getDefiningOp<SubOp>();
+    if (!sub || sub.getRhs() != other)
+      return {};
+    return sub.getLhs();
+  };
+
+  // complex.add(complex.sub(a, b), b) -> a
+  if (Value folded = matchUndoneSub(getLhs(), getRhs()))
+    return folded;
+
+  // complex.add(b, complex.sub(a, b)) -> a
+  if (Value folded = matchUndoneSub(getRhs(), getLhs()))
+    return folded;
+
+  // No cancellation pattern matched; leave the op untouched.
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 2d492a223d4c7..8bca3232774a1 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -62,3 +62,25 @@ func.func @imag_of_create_op() -> f32 {
   %1 = complex.im %complex : complex<f32>
   return %1 : f32
 }
+
+// CHECK-LABEL: func @complex_add_sub_lhs
+func.func @complex_add_sub_lhs() -> complex<f32> {
+  %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %complex2 = complex.constant [0.0 : f32, 2.0 : f32] : complex<f32>
+  // add(sub(a, b), b) folds to `a`; the returned value must be that constant.
+  // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %sub = complex.sub %complex1, %complex2 : complex<f32>
+  %add = complex.add %sub, %complex2 : complex<f32>
+  return %add : complex<f32>
+}
+
+// CHECK-LABEL: func @complex_add_sub_rhs
+func.func @complex_add_sub_rhs() -> complex<f32> {
+  %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %complex2 = complex.constant [0.0 : f32, 2.0 : f32] : complex<f32>
+  // add(b, sub(a, b)) folds to `a`; the returned value must be that constant.
+  // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %sub = complex.sub %complex1, %complex2 : complex<f32>
+  %add = complex.add %complex2, %sub : complex<f32>
+  return %add : complex<f32>
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list