[clang] [CIR] Update ComplexRealOp to work on scalar type (PR #161080)

Amr Hesham via cfe-commits cfe-commits at lists.llvm.org
Sun Sep 28 05:49:44 PDT 2025


https://github.com/AmrDeveloper created https://github.com/llvm/llvm-project/pull/161080

Update cir::ComplexRealOp so that it also accepts scalar (integer or floating-point) operands

Issue #160568

>From 48b11b37327d9f46f1630738a186e971c236acdf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9CAmr?= <amr96 at programmer.net>
Date: Sun, 28 Sep 2025 14:47:13 +0200
Subject: [PATCH] [CIR] Update ComplexRealOp to work on scalar type

---
 .../clang/CIR/Dialect/Builder/CIRBaseBuilder.h        |  7 ++++---
 clang/include/clang/CIR/Dialect/IR/CIROps.td          |  9 +++++----
 .../clang/CIR/Dialect/IR/CIRTypeConstraints.td        |  6 ++++++
 clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp            |  6 ++++--
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp               | 11 ++++++++++-
 clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp   |  9 +++++++--
 clang/test/CIR/CodeGen/complex.cpp                    |  9 ++++++---
 7 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index a3f167e3cde2c..7dbb6e42dbb65 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -148,9 +148,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
   }
 
   mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
-    auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
-    return cir::ComplexRealOp::create(*this, loc, operandTy.getElementType(),
-                                      operand);
+    auto resultType = operand.getType();
+    if (mlir::isa<cir::ComplexType>(resultType))
+      resultType = mlir::cast<cir::ComplexType>(resultType).getElementType();
+    return cir::ComplexRealOp::create(*this, loc, resultType, operand);
   }
 
   mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index bb394440bf8d8..eaf5b8958c7b8 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -3245,18 +3245,19 @@ def CIR_ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
 def CIR_ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
   let summary = "Extract the real part of a complex value";
   let description = [{
-    `cir.complex.real` operation takes an operand of `!cir.complex` type and
-    yields the real part of it.
+    `cir.complex.real` operation takes an operand of `!cir.complex` or scalar
+    type and yields the real part of it.
 
     Example:
 
     ```mlir
-    %1 = cir.complex.real %0 : !cir.complex<!cir.float> -> !cir.float
+    %real = cir.complex.real %complex : !cir.complex<!cir.float> -> !cir.float
+    %real = cir.complex.real %scalar : !cir.float -> !cir.float
     ```
   }];
 
   let results = (outs CIR_AnyIntOrFloatType:$result);
-  let arguments = (ins CIR_ComplexType:$operand);
+  let arguments = (ins CIR_AnyComplexOrIntOrFloatType:$operand);
 
   let assemblyFormat = [{
     $operand `:` qualified(type($operand)) `->` qualified(type($result))
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index 82f6e1d33043e..da03a291a7690 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -165,6 +165,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
 
 def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
 
+def CIR_AnyComplexOrIntOrFloatType : AnyTypeOf<[
+    CIR_AnyComplexType, CIR_AnyFloatType, CIR_AnyIntType
+], "complex, integer or floating point type"> {
+    let cppFunctionName = "isComplexOrIntegerOrFloatingPointType";
+}
+
 //===----------------------------------------------------------------------===//
 // Array Type predicates
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index bd09d78cd0eb6..f8dcae042740a 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -2146,8 +2146,10 @@ mlir::Value ScalarExprEmitter::VisitRealImag(const UnaryOperator *e,
   }
 
   if (e->getOpcode() == UO_Real) {
-    return promotionTy.isNull() ? Visit(op)
-                                : cgf.emitPromotedScalarExpr(op, promotionTy);
+    mlir::Value operand = promotionTy.isNull()
+                              ? Visit(op)
+                              : cgf.emitPromotedScalarExpr(op, promotionTy);
+    return builder.createComplexReal(loc, operand);
   }
 
   // __imag on a scalar returns zero. Emit the subexpr to ensure side
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 58ef500446aa7..c3dda37db37cf 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -2302,14 +2302,23 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult cir::ComplexRealOp::verify() {
-  if (getType() != getOperand().getType().getElementType()) {
+  mlir::Type operandTy = getOperand().getType();
+  if (mlir::isa<cir::ComplexType>(operandTy)) {
+    operandTy = mlir::cast<cir::ComplexType>(operandTy).getElementType();
+  }
+
+  if (getType() != operandTy) {
     emitOpError() << ": result type does not match operand type";
     return failure();
   }
+
   return success();
 }
 
 OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
+  if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
+    return nullptr;
+
   if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
     return complexCreateOp.getOperand(0);
 
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 876948d53010b..664bf55d2b7f5 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2992,8 +2992,13 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
     cir::ComplexRealOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
   mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
-  rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
-      op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
+  mlir::Value operand = adaptor.getOperand();
+  if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
+    operand = mlir::LLVM::ExtractValueOp::create(
+        rewriter, op.getLoc(), resultLLVMTy, operand,
+        llvm::ArrayRef<std::int64_t>{0});
+  }
+  rewriter.replaceOp(op, operand);
   return mlir::success();
 }
 
diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp
index e90163172d2df..abb3431e5f37d 100644
--- a/clang/test/CIR/CodeGen/complex.cpp
+++ b/clang/test/CIR/CodeGen/complex.cpp
@@ -1140,7 +1140,8 @@ void real_on_scalar_glvalue() {
 // CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
 // CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
 // CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.float>, !cir.float
-// CIR: cir.store{{.*}} %[[TMP_A]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
+// CIR: %[[A_REAL:.*]] = cir.complex.real %2 : !cir.float -> !cir.float
+// CIR: cir.store{{.*}} %[[A_REAL]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
 
 // LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
 // LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
@@ -1179,7 +1180,8 @@ void real_on_scalar_with_type_promotion() {
 // CIR: %[[B_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["b", init]
 // CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.f16>, !cir.f16
 // CIR: %[[TMP_A_F32:.*]] = cir.cast(floating, %[[TMP_A]] : !cir.f16), !cir.float
-// CIR: %[[TMP_A_F16:.*]] = cir.cast(floating, %[[TMP_A_F32]] : !cir.float), !cir.f16
+// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A_F32]] : !cir.float -> !cir.float
+// CIR: %[[TMP_A_F16:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.float), !cir.f16
 // CIR: cir.store{{.*}} %[[TMP_A_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
 
 // LLVM: %[[A_ADDR:.*]] = alloca half, i64 1, align 2
@@ -1248,7 +1250,8 @@ void real_on_scalar_from_real_with_type_promotion() {
 // CIR: %[[A_IMAG_F32:.*]] = cir.cast(floating, %[[A_IMAG]] : !cir.f16), !cir.float
 // CIR: %[[A_COMPLEX_F32:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
 // CIR: %[[A_REAL_F32:.*]] = cir.complex.real %[[A_COMPLEX_F32]] : !cir.complex<!cir.float> -> !cir.float
-// CIR: %[[A_REAL_F16:.*]] = cir.cast(floating, %[[A_REAL_F32]] : !cir.float), !cir.f16
+// CIR: %[[A_REAL:.*]] = cir.complex.real %[[A_REAL_F32]] : !cir.float -> !cir.float
+// CIR: %[[A_REAL_F16:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.float), !cir.f16
 // CIR: cir.store{{.*}} %[[A_REAL_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
 
 // LLVM: %[[A_ADDR:.*]] = alloca { half, half }, i64 1, align 2



More information about the cfe-commits mailing list