[clang] 9d91b07 - [CIR] Implement EqualOp for ComplexType (#145769)

via cfe-commits cfe-commits at lists.llvm.org
Thu Jun 26 09:06:25 PDT 2025


Author: Amr Hesham
Date: 2025-06-26T18:06:22+02:00
New Revision: 9d91b07e1e8bc2f3c6086efa6cfcf2565e98658c

URL: https://github.com/llvm/llvm-project/commit/9d91b07e1e8bc2f3c6086efa6cfcf2565e98658c
DIFF: https://github.com/llvm/llvm-project/commit/9d91b07e1e8bc2f3c6086efa6cfcf2565e98658c.diff

LOG: [CIR] Implement EqualOp for ComplexType (#145769)

This change adds support for the equality operation (`==`) on ComplexType.


https://github.com/llvm/llvm-project/issues/141365

Added: 
    

Modified: 
    clang/include/clang/CIR/Dialect/IR/CIROps.td
    clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
    clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
    clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
    clang/test/CIR/CodeGen/complex.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index ef77c46b011f7..b0a8c2bd9d332 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2455,6 +2455,31 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexEqualOp
+//===----------------------------------------------------------------------===//
+
+def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
+
+  let summary = "Computes whether two complex values are equal";
+  let description = [{
+   The `complex.eq` op takes two complex numbers and returns whether
+   they are equal.
+
+    ```mlir
+    %r = cir.complex.eq %a, %b : !cir.complex<!cir.float>
+    ```
+  }];
+
+  let results = (outs CIR_BoolType:$result);
+  let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    `:` qualified(type($lhs)) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Assume Operations
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 7f8dcd96a6bff..4bcbc6d7ce798 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -894,9 +894,17 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
       }
     } else {
       // Complex Comparison: can only be an equality comparison.
-      assert(!cir::MissingFeatures::complexType());
-      cgf.cgm.errorNYI(loc, "complex comparison");
-      result = builder.getBool(false, loc);
+      assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
+
+      BinOpInfo boInfo = emitBinOps(e);
+      if (e->getOpcode() == BO_EQ) {
+        result =
+            builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
+      } else {
+        assert(!cir::MissingFeatures::complexType());
+        cgf.cgm.errorNYI(loc, "complex not equal");
+        result = builder.getBool(false, loc);
+      }
     }
 
     return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),

diff  --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 1c13c88902d9a..d510f1c756936 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1900,29 +1900,30 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMBrOpLowering,
                CIRToLLVMCallOpLowering,
                CIRToLLVMCmpOpLowering,
+               CIRToLLVMComplexCreateOpLowering,
+               CIRToLLVMComplexEqualOpLowering,
+               CIRToLLVMComplexImagOpLowering,
+               CIRToLLVMComplexRealOpLowering,
                CIRToLLVMConstantOpLowering,
                CIRToLLVMExpectOpLowering,
                CIRToLLVMFuncOpLowering,
                CIRToLLVMGetGlobalOpLowering,
                CIRToLLVMGetMemberOpLowering,
                CIRToLLVMSelectOpLowering,
-               CIRToLLVMSwitchFlatOpLowering,
                CIRToLLVMShiftOpLowering,
-               CIRToLLVMStackSaveOpLowering,
                CIRToLLVMStackRestoreOpLowering,
+               CIRToLLVMStackSaveOpLowering,
+               CIRToLLVMSwitchFlatOpLowering,
                CIRToLLVMTrapOpLowering,
                CIRToLLVMUnaryOpLowering,
+               CIRToLLVMVecCmpOpLowering,
                CIRToLLVMVecCreateOpLowering,
                CIRToLLVMVecExtractOpLowering,
                CIRToLLVMVecInsertOpLowering,
-               CIRToLLVMVecCmpOpLowering,
-               CIRToLLVMVecSplatOpLowering,
-               CIRToLLVMVecShuffleOpLowering,
                CIRToLLVMVecShuffleDynamicOpLowering,
-               CIRToLLVMVecTernaryOpLowering,
-               CIRToLLVMComplexCreateOpLowering,
-               CIRToLLVMComplexRealOpLowering,
-               CIRToLLVMComplexImagOpLowering
+               CIRToLLVMVecShuffleOpLowering,
+               CIRToLLVMVecSplatOpLowering,
+               CIRToLLVMVecTernaryOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -2244,6 +2245,43 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
+    cir::ComplexEqualOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Value lhs = adaptor.getLhs();
+  mlir::Value rhs = adaptor.getRhs();
+
+  auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
+  mlir::Type complexElemTy =
+      getTypeConverter()->convertType(complexType.getElementType());
+
+  mlir::Location loc = op.getLoc();
+  auto lhsReal =
+      rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
+  auto lhsImag =
+      rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
+  auto rhsReal =
+      rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
+  auto rhsImag =
+      rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
+
+  if (complexElemTy.isInteger()) {
+    auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
+        loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
+    auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
+        loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
+    rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
+    return mlir::success();
+  }
+
+  auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
+      loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
+  auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
+      loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
+  rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
+  return mlir::success();
+}
+
 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
   return std::make_unique<ConvertCIRToLLVMPass>();
 }

diff  --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 8502cb1ae5d9f..25cf218cf8b6c 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -463,6 +463,16 @@ class CIRToLLVMComplexImagOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMComplexEqualOpLowering
+    : public mlir::OpConversionPattern<cir::ComplexEqualOp> {
+public:
+  using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 } // namespace direct
 } // namespace cir
 

diff  --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp
index ad3720097a795..1e9ce0e29fd46 100644
--- a/clang/test/CIR/CodeGen/complex.cpp
+++ b/clang/test/CIR/CodeGen/complex.cpp
@@ -368,4 +368,77 @@ int foo17(int _Complex a, int _Complex b) {
 // OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
 // OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4
 // OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
-// OGCG: ret i32 %[[ADD]]
\ No newline at end of file
+// OGCG: ret i32 %[[ADD]]
+
+bool foo18(int _Complex a, int _Complex b) {
+  return a == b;
+}
+
+// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
+// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
+// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
+
+// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
+// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
+// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
+// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
+// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
+// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
+// LLVM: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
+// LLVM: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
+// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
+// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
+// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
+// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
+// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
+// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
+// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
+// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
+// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
+// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
+// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
+// OGCG: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
+// OGCG: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
+// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
+bool foo19(double _Complex a, double _Complex b) {
+  return a == b;
+}
+
+// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
+// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
+// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
+
+// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
+// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
+// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0
+// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1
+// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0
+// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1
+// LLVM: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
+// LLVM: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
+// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
+// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8
+// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8
+// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
+// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8
+// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
+// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8
+// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
+// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8
+// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
+// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8
+// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
+// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8
+// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
+// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8
+// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
+// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8
+// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
+// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8
+// OGCG: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
+// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
+// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+


        


More information about the cfe-commits mailing list