[clang] [CIR] Implement simple folding for unary operations (PR #174882)
Andy Kaylor via cfe-commits
cfe-commits at lists.llvm.org
Thu Jan 8 11:23:51 PST 2026
https://github.com/andykaylor updated https://github.com/llvm/llvm-project/pull/174882
>From eb8f57dd03190c364c80843c03a47b34eed1d0fe Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Wed, 7 Jan 2026 14:44:20 -0800
Subject: [PATCH 1/3] [CIR] Implement simple folding for unary operations
This extends the UnaryOp folder to handle plus, minus, and not operations on
constant operands.
This is in preparation for a change that will attempt to fold these unary
operations as they are generated, but this change only performs the folding
via the cir-canonicalize pass.
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 6 +
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 31 +++
clang/test/CIR/CodeGen/basic.c | 9 +-
clang/test/CIR/CodeGen/basic.cpp | 13 +-
clang/test/CIR/CodeGen/bitfields_be.c | 12 +-
clang/test/CIR/CodeGen/goto.cpp | 5 +-
clang/test/CIR/CodeGen/pointers.cpp | 19 +-
clang/test/CIR/CodeGen/switch.cpp | 5 +-
clang/test/CIR/CodeGen/ternary.cpp | 3 +-
clang/test/CIR/Lowering/goto.cir | 10 +-
clang/test/CIR/Transforms/canonicalize.cir | 187 ++++++++++++++++++-
11 files changed, 253 insertions(+), 47 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 4274ed25542b1..fc76e41885435 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -420,6 +420,12 @@ def CIR_ConstantOp : CIR_Op<"const", [
llvm_unreachable("Expected an IntAttr in ConstantOp");
}
+ llvm::APFloat getFloatValue() {
+ if (const auto fpAttr = getValueAttr<cir::FPAttr>())
+ return fpAttr.getValue();
+ llvm_unreachable("Expected an FPAttr in ConstantOp");
+ }
+
bool getBoolValue() {
if (const auto boolAttr = getValueAttr<cir::BoolAttr>())
return boolAttr.getValue();
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index a17dade12ed24..e80d858e39e43 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -2479,6 +2479,37 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
if (isBoolNot(previous))
return previous.getInput();
+ // Fold constant unary operations.
+ if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>()) {
+ switch (getKind()) {
+ case cir::UnaryOpKind::Not:
+ if (mlir::isa<cir::IntType>(srcConst.getType())) {
+ APInt val = srcConst.getIntValue();
+ val.flipAllBits();
+ return cir::IntAttr::get(getType(), val);
+ }
+ assert(mlir::isa<cir::BoolType>(srcConst.getType()));
+ return cir::BoolAttr::get(getContext(), !srcConst.getBoolValue());
+ case cir::UnaryOpKind::Plus:
+ return srcConst.getResult();
+ case cir::UnaryOpKind::Minus:
+ if (mlir::isa<cir::FPTypeInterface>(srcConst.getType())) {
+ APFloat val = srcConst.getFloatValue();
+ val.changeSign();
+ return cir::FPAttr::get(getType(), val);
+ }
+ if (mlir::isa<cir::IntType>(srcConst.getType())) {
+ APInt val = srcConst.getIntValue();
+ val.negate();
+ return cir::IntAttr::get(getType(), val);
+ }
+ assert(mlir::isa<cir::BoolType>(srcConst.getType()));
+ return srcConst.getResult();
+ default:
+ return {};
+ }
+ }
+
return {};
}
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CIR/CodeGen/basic.c b/clang/test/CIR/CodeGen/basic.c
index 9268615bc9fb0..4646e5b771f8b 100644
--- a/clang/test/CIR/CodeGen/basic.c
+++ b/clang/test/CIR/CodeGen/basic.c
@@ -293,12 +293,9 @@ size_type max_size(void) {
}
// CIR: cir.func{{.*}} @max_size()
-// CIR: %0 = cir.alloca !u64i, !cir.ptr<!u64i>, ["__retval"] {alignment = 8 : i64}
-// CIR: %1 = cir.const #cir.int<0> : !s32i
-// CIR: %2 = cir.unary(not, %1) : !s32i, !s32i
-// CIR: %3 = cir.cast integral %2 : !s32i -> !u64i
-// CIR: %4 = cir.const #cir.int<8> : !u64i
-// CIR: %5 = cir.binop(div, %3, %4) : !u64i
+// CIR: %[[NOT_ZERO:.*]] = cir.const #cir.int<18446744073709551615> : !u64i
+// CIR: %[[SIZE_OF_TP:.*]] = cir.const #cir.int<8> : !u64i
+// CIR: %[[RESULT:.*]] = cir.binop(div, %[[NOT_ZERO]], %[[SIZE_OF_TP]]) : !u64i
// LLVM: define{{.*}} i64 @max_size()
// LLVM: store i64 2305843009213693951, ptr
diff --git a/clang/test/CIR/CodeGen/basic.cpp b/clang/test/CIR/CodeGen/basic.cpp
index af8de6fff047a..245a43710bf71 100644
--- a/clang/test/CIR/CodeGen/basic.cpp
+++ b/clang/test/CIR/CodeGen/basic.cpp
@@ -121,16 +121,9 @@ size_type max_size() {
}
// CHECK: cir.func{{.*}} @_Z8max_sizev() -> !u64i
-// CHECK: %0 = cir.alloca !u64i, !cir.ptr<!u64i>, ["__retval"] {alignment = 8 : i64}
-// CHECK: %1 = cir.const #cir.int<0> : !s32i
-// CHECK: %2 = cir.unary(not, %1) : !s32i, !s32i
-// CHECK: %3 = cir.cast integral %2 : !s32i -> !u64i
-// CHECK: %4 = cir.const #cir.int<8> : !u64i
-// CHECK: %5 = cir.binop(div, %3, %4) : !u64i
-// CHECK: cir.store{{.*}} %5, %0 : !u64i, !cir.ptr<!u64i>
-// CHECK: %6 = cir.load{{.*}} %0 : !cir.ptr<!u64i>, !u64i
-// CHECK: cir.return %6 : !u64i
-// CHECK: }
+// CHECK: %[[NOT_ZERO:.*]] = cir.const #cir.int<18446744073709551615> : !u64i
+// CHECK: %[[SIZE_OF_TP:.*]] = cir.const #cir.int<8> : !u64i
+// CHECK: %[[RESULT:.*]] = cir.binop(div, %[[NOT_ZERO]], %[[SIZE_OF_TP]]) : !u64i
void ref_arg(int &x) {
int y = x;
diff --git a/clang/test/CIR/CodeGen/bitfields_be.c b/clang/test/CIR/CodeGen/bitfields_be.c
index f4f3476d2ef23..730d845928578 100644
--- a/clang/test/CIR/CodeGen/bitfields_be.c
+++ b/clang/test/CIR/CodeGen/bitfields_be.c
@@ -52,12 +52,11 @@ void load(S* s) {
// field 'a'
// CIR: cir.func {{.*}} @load
-// CIR: %[[PTR0:.*]] = cir.alloca !cir.ptr<!rec_S>, !cir.ptr<!cir.ptr<!rec_S>>, ["s", init] {alignment = 8 : i64} loc(#loc35)
-// CIR: %[[CONST1:.*]] = cir.const #cir.int<4> : !s32i
-// CIR: %[[MIN1:.*]] = cir.unary(minus, %[[CONST1]]) nsw : !s32i, !s32i
+// CIR: %[[PTR0:.*]] = cir.alloca !cir.ptr<!rec_S>, !cir.ptr<!cir.ptr<!rec_S>>, ["s", init]
+// CIR: %[[CONST1:.*]] = cir.const #cir.int<-4> : !s32i
// CIR: %[[VAL0:.*]] = cir.load align(8) %[[PTR0]] : !cir.ptr<!cir.ptr<!rec_S>>, !cir.ptr<!rec_S>
// CIR: %[[GET0:.*]] = cir.get_member %[[VAL0]][0] {name = "a"} : !cir.ptr<!rec_S> -> !cir.ptr<!u32i>
-// CIR: %[[SET0:.*]] = cir.set_bitfield align(4) (#bfi_a, %[[GET0]] : !cir.ptr<!u32i>, %[[MIN1]] : !s32i) -> !s32i
+// CIR: %[[SET0:.*]] = cir.set_bitfield align(4) (#bfi_a, %[[GET0]] : !cir.ptr<!u32i>, %[[CONST1]] : !s32i) -> !s32i
// LLVM: define dso_local void @load{{.*}}{{.*}}
// LLVM: %[[PTR0:.*]] = load ptr
@@ -94,11 +93,10 @@ void load(S* s) {
// OGCG: store i32 %[[OR1]], ptr %[[PTR1]], align 4
// field 'c'
-// CIR: %[[CONST3:.*]] = cir.const #cir.int<12345> : !s32i
-// CIR: %[[MIN2:.*]] = cir.unary(minus, %[[CONST3]]) nsw : !s32i, !s32i
+// CIR: %[[CONST3:.*]] = cir.const #cir.int<-12345> : !s32i
// CIR: %[[VAL2:.*]] = cir.load align(8) %[[PTR0]] : !cir.ptr<!cir.ptr<!rec_S>>, !cir.ptr<!rec_S>
// CIR: %[[GET2:.*]] = cir.get_member %[[VAL2]][0] {name = "c"} : !cir.ptr<!rec_S> -> !cir.ptr<!u32i>
-// CIR: %[[SET2:.*]] = cir.set_bitfield align(4) (#bfi_c, %[[GET2]] : !cir.ptr<!u32i>, %[[MIN2]] : !s32i) -> !s32i
+// CIR: %[[SET2:.*]] = cir.set_bitfield align(4) (#bfi_c, %[[GET2]] : !cir.ptr<!u32i>, %[[CONST3]] : !s32i) -> !s32i
// LLVM: %[[PTR2:.*]] = load ptr
// LLVM: %[[GET2:.*]] = getelementptr %struct.S, ptr %[[PTR2]], i32 0, i32 0
diff --git a/clang/test/CIR/CodeGen/goto.cpp b/clang/test/CIR/CodeGen/goto.cpp
index 4b825d619c221..e9ef9fd81f160 100644
--- a/clang/test/CIR/CodeGen/goto.cpp
+++ b/clang/test/CIR/CodeGen/goto.cpp
@@ -24,9 +24,8 @@ int shouldNotGenBranchRet(int x) {
// CIR: cir.return [[RET]] : !s32i
// CIR: ^bb2:
// CIR: cir.label "err"
-// CIR: [[ONE:%.*]] = cir.const #cir.int<1> : !s32i
-// CIR: [[MINUS:%.*]] = cir.unary(minus, [[ONE]]) nsw : !s32i, !s32i
-// CIR: cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
+// CIR: [[MINUS_ONE:%.*]] = cir.const #cir.int<-1> : !s32i
+// CIR: cir.store [[MINUS_ONE]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
// CIR: cir.br ^bb1
// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
diff --git a/clang/test/CIR/CodeGen/pointers.cpp b/clang/test/CIR/CodeGen/pointers.cpp
index 68eea6210f1dd..b07c1b57c127d 100644
--- a/clang/test/CIR/CodeGen/pointers.cpp
+++ b/clang/test/CIR/CodeGen/pointers.cpp
@@ -9,14 +9,17 @@ void foo(int *iptr, char *cptr, unsigned ustride) {
cptr + 3;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
- iptr - 2;
- // CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
- // CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
- // CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
- cptr - 3;
- // CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
- // CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
- // CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
+
+ // We need to assign to a temporary in these cases because otherwise
+ // constant folding of the unary minus for the negative stride value also
+ // triggers erasing the unused result of the ptr_stride operation.
+ int* iptr2 = iptr - 2;
+ // CHECK: %[[#STRIDE:]] = cir.const #cir.int<-2> : !s32i
+ // CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
+ char* cptr2 = cptr - 3;
+
+ // CHECK: %[[#STRIDE:]] = cir.const #cir.int<-3> : !s32i
+ // CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
iptr + ustride;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !u32i) -> !cir.ptr<!s32i>
diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
index a7468954665c0..3449645643d82 100644
--- a/clang/test/CIR/CodeGen/switch.cpp
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -1218,9 +1218,8 @@ int sw_return_multi_cases(int x) {
// CIR-NEXT: cir.return %[[RET2]] : !s32i
// CIR-NEXT: }
// CIR-NEXT: cir.case(default, []) {
-// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
-// CIR: %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i
-// CIR: cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr<!s32i>
+// CIR: %[[MINUS_ONE:.*]] = cir.const #cir.int<-1> : !s32i
+// CIR: cir.store{{.*}} %[[MINUS_ONE]], %{{.*}} : !s32i, !cir.ptr<!s32i>
// CIR: %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
// CIR-NEXT: cir.return %[[RETDEF]] : !s32i
// CIR-NEXT: }
diff --git a/clang/test/CIR/CodeGen/ternary.cpp b/clang/test/CIR/CodeGen/ternary.cpp
index 847c0b4a04009..b57cdefdc26ce 100644
--- a/clang/test/CIR/CodeGen/ternary.cpp
+++ b/clang/test/CIR/CodeGen/ternary.cpp
@@ -71,8 +71,7 @@ int foo(int a, int b) {
// CIR: }) : (!cir.bool) -> !s32i
// CIR: [[CAST:%.+]] = cir.cast int_to_bool [[TERNARY_RES]] : !s32i -> !cir.bool
// CIR: cir.if [[CAST]] {
-// CIR: [[ONE:%.+]] = cir.const #cir.int<1> : !s32i
-// CIR: [[MINUS_ONE:%.+]] = cir.unary(minus, [[ONE]]) nsw : !s32i, !s32i
+// CIR: [[MINUS_ONE:%.+]] = cir.const #cir.int<-1> : !s32i
// CIR: cir.store [[MINUS_ONE]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
// CIR: [[RETVAL_VAL:%.+]] = cir.load [[RETVAL]] : !cir.ptr<!s32i>, !s32i
// CIR: cir.return [[RETVAL_VAL]] : !s32i
diff --git a/clang/test/CIR/Lowering/goto.cir b/clang/test/CIR/Lowering/goto.cir
index cd3a57d2e7138..f19fcdd8e8c8d 100644
--- a/clang/test/CIR/Lowering/goto.cir
+++ b/clang/test/CIR/Lowering/goto.cir
@@ -24,14 +24,13 @@ module {
cir.return %3 : !s32i
^bb2:
cir.label "err"
- %4 = cir.const #cir.int<1> : !s32i
- %5 = cir.unary(minus, %4) : !s32i, !s32i
- cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+ %4 = cir.const #cir.int<-1> : !s32i
+ cir.store %4, %1 : !s32i, !cir.ptr<!s32i>
cir.br ^bb1
}
// MLIR: llvm.func @gotoFromIf
-// MLIR: %[[#One:]] = llvm.mlir.constant(1 : i32) : i32
+// MLIR: %[[#Minus_one:]] = llvm.mlir.constant(-1 : i32) : i32
// MLIR: %[[#Zero:]] = llvm.mlir.constant(0 : i32) : i32
// MLIR: llvm.cond_br {{.*}}, ^bb[[#COND_YES:]], ^bb[[#COND_NO:]]
// MLIR: ^bb[[#COND_YES]]:
@@ -45,8 +44,7 @@ module {
// MLIR: %[[#Ret_val:]] = llvm.load %[[#Ret_val_addr]] {alignment = 4 : i64} : !llvm.ptr -> i32
// MLIR: llvm.return %[[#Ret_val]] : i32
// MLIR: ^bb[[#GOTO_BLK]]:
-// MLIR: %[[#Neg_one:]] = llvm.sub %[[#Zero]], %[[#One]] : i32
-// MLIR: llvm.store %[[#Neg_one]], %[[#Ret_val_addr]] {{.*}}: i32, !llvm.ptr
+// MLIR: llvm.store %[[#Minus_one]], %[[#Ret_val_addr]] {{.*}}: i32, !llvm.ptr
// MLIR: llvm.br ^bb[[#RETURN]]
// MLIR: }
}
diff --git a/clang/test/CIR/Transforms/canonicalize.cir b/clang/test/CIR/Transforms/canonicalize.cir
index 4f29fbc273801..cfac73ecdb738 100644
--- a/clang/test/CIR/Transforms/canonicalize.cir
+++ b/clang/test/CIR/Transforms/canonicalize.cir
@@ -7,6 +7,9 @@
!u32i = !cir.int<u, 32>
!u64i = !cir.int<u, 64>
+#true = #cir.bool<true> : !cir.bool
+#false = #cir.bool<false> : !cir.bool
+
module {
cir.func @redundant_br() {
cir.br ^bb1
@@ -60,12 +63,12 @@ module {
// CHECK-NEXT: cir.return
// CHECK-NEXT: }
- cir.func @unary_not(%arg0: !cir.bool) -> !cir.bool {
+ cir.func @unary_not_not(%arg0: !cir.bool) -> !cir.bool {
%0 = cir.unary(not, %arg0) : !cir.bool, !cir.bool
%1 = cir.unary(not, %0) : !cir.bool, !cir.bool
cir.return %1 : !cir.bool
}
- // CHECK: cir.func{{.*}} @unary_not(%arg0: !cir.bool) -> !cir.bool
+ // CHECK: cir.func{{.*}} @unary_not_not(%arg0: !cir.bool) -> !cir.bool
// CHECK-NEXT: cir.return %arg0 : !cir.bool
cir.func @unary_poison() -> !s32i {
@@ -78,6 +81,186 @@ module {
// CHECK-NEXT: cir.return %[[P]] : !s32i
// CHECK-NEXT: }
+ cir.func @unary_not_true() -> !cir.bool {
+ %0 = cir.const #true
+ %1 = cir.unary(not, %0) : !cir.bool, !cir.bool
+ cir.return %1 : !cir.bool
+ }
+ // CHECK: cir.func{{.*}} @unary_not_true() -> !cir.bool
+ // CHECK-NEXT: %[[FALSE:.*]] = cir.const #false
+ // CHECK-NEXT: cir.return %[[FALSE]] : !cir.bool
+
+ cir.func @unary_not_false() -> !cir.bool {
+ %0 = cir.const #false
+ %1 = cir.unary(not, %0) : !cir.bool, !cir.bool // not(false) must fold to a #true constant
+ cir.return %1 : !cir.bool
+ }
+ // CHECK: cir.func{{.*}} @unary_not_false() -> !cir.bool
+ // CHECK-NEXT: %[[TRUE:.*]] = cir.const #true
+ // CHECK-NEXT: cir.return %[[TRUE]] : !cir.bool
+
+ cir.func @unary_not_int() -> !s32i {
+ %0 = cir.const #cir.int<1> : !s32i
+ %1 = cir.unary(not, %0) : !s32i, !s32i
+ cir.return %1 : !s32i
+ }
+ // CHECK: cir.func{{.*}} @unary_not_int() -> !s32i
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.int<-2> : !s32i
+ // CHECK-NEXT: cir.return %[[CONST]] : !s32i
+
+ cir.func @unary_not_uint() -> !u32i {
+ %0 = cir.const #cir.int<1> : !u32i
+ %1 = cir.unary(not, %0) : !u32i, !u32i
+ cir.return %1 : !u32i
+ }
+ // CHECK: cir.func{{.*}} @unary_not_uint() -> !u32i
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.int<4294967294> : !u32i
+ // CHECK-NEXT: cir.return %[[CONST]] : !u32i
+
+ cir.func @unary_plus_true() -> !cir.bool {
+ %0 = cir.const #true
+ %1 = cir.unary(plus, %0) : !cir.bool, !cir.bool
+ cir.return %1 : !cir.bool
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_true() -> !cir.bool
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #true
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.bool
+
+ cir.func @unary_plus_false() -> !cir.bool {
+ %0 = cir.const #false
+ %1 = cir.unary(plus, %0) : !cir.bool, !cir.bool
+ cir.return %1 : !cir.bool
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_false() -> !cir.bool
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #false
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.bool
+
+ cir.func @unary_plus_int() -> !s32i {
+ %0 = cir.const #cir.int<1> : !s32i
+ %1 = cir.unary(plus, %0) : !s32i, !s32i
+ cir.return %1 : !s32i
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_int() -> !s32i
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.int<1> : !s32i
+ // CHECK-NEXT: cir.return %[[CONST]] : !s32i
+
+ cir.func @unary_plus_uint() -> !u32i {
+ %0 = cir.const #cir.int<1> : !u32i
+ %1 = cir.unary(plus, %0) : !u32i, !u32i
+ cir.return %1 : !u32i
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_uint() -> !u32i
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.int<1> : !u32i
+ // CHECK-NEXT: cir.return %[[CONST]] : !u32i
+
+ cir.func @unary_plus_float() -> !cir.float {
+ %0 = cir.const #cir.fp<1.100000e+00> : !cir.float
+ %1 = cir.unary(plus, %0) : !cir.float, !cir.float
+ cir.return %1 : !cir.float
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_float() -> !cir.float
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<1.100000e+00> : !cir.float
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.float
+
+ cir.func @unary_plus_double() -> !cir.double {
+ %0 = cir.const #cir.fp<1.100000e+00> : !cir.double
+ %1 = cir.unary(plus, %0) : !cir.double, !cir.double
+ cir.return %1 : !cir.double
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_double() -> !cir.double
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<1.100000e+00> : !cir.double
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.double
+
+ cir.func @unary_plus_nan() -> !cir.float {
+ %0 = cir.const #cir.fp<0x7FC00000> : !cir.float // positive quiet NaN (0x7F800000 would be +Inf)
+ %1 = cir.unary(plus, %0) : !cir.float, !cir.float
+ cir.return %1 : !cir.float
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_nan() -> !cir.float
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<0x7FC00000> : !cir.float
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.float
+
+ cir.func @unary_plus_neg_nan() -> !cir.float {
+ %0 = cir.const #cir.fp<0xFFC00000> : !cir.float // negative quiet NaN (0xFF800000 would be -Inf)
+ %1 = cir.unary(plus, %0) : !cir.float, !cir.float
+ cir.return %1 : !cir.float
+ }
+ // CHECK: cir.func{{.*}} @unary_plus_neg_nan() -> !cir.float
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<0xFFC00000> : !cir.float
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.float
+
+ cir.func @unary_minus_true() -> !cir.bool {
+ %0 = cir.const #true
+ %1 = cir.unary(minus, %0) : !cir.bool, !cir.bool
+ cir.return %1 : !cir.bool
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_true() -> !cir.bool
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #true
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.bool
+
+ cir.func @unary_minus_false() -> !cir.bool {
+ %0 = cir.const #false
+ %1 = cir.unary(minus, %0) : !cir.bool, !cir.bool
+ cir.return %1 : !cir.bool
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_false() -> !cir.bool
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #false
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.bool
+
+ cir.func @unary_minus_int() -> !s32i {
+ %0 = cir.const #cir.int<1> : !s32i
+ %1 = cir.unary(minus, %0) : !s32i, !s32i
+ cir.return %1 : !s32i
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_int() -> !s32i
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.int<-1> : !s32i
+ // CHECK-NEXT: cir.return %[[CONST]] : !s32i
+
+ cir.func @unary_minus_uint() -> !u32i {
+ %0 = cir.const #cir.int<1> : !u32i
+ %1 = cir.unary(minus, %0) : !u32i, !u32i
+ cir.return %1 : !u32i
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_uint() -> !u32i
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.int<4294967295> : !u32i
+ // CHECK-NEXT: cir.return %[[CONST]] : !u32i
+
+ cir.func @unary_minus_float() -> !cir.float {
+ %0 = cir.const #cir.fp<1.100000e+00> : !cir.float
+ %1 = cir.unary(minus, %0) : !cir.float, !cir.float
+ cir.return %1 : !cir.float
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_float() -> !cir.float
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<-1.100000e+00> : !cir.float
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.float
+
+ cir.func @unary_minus_double() -> !cir.double {
+ %0 = cir.const #cir.fp<1.100000e+00> : !cir.double
+ %1 = cir.unary(minus, %0) : !cir.double, !cir.double
+ cir.return %1 : !cir.double
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_double() -> !cir.double
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<-1.100000e+00> : !cir.double
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.double
+
+ cir.func @unary_minus_nan() -> !cir.float {
+ %0 = cir.const #cir.fp<0x7FC00000> : !cir.float // positive quiet NaN; minus only flips the sign bit
+ %1 = cir.unary(minus, %0) : !cir.float, !cir.float
+ cir.return %1 : !cir.float
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_nan() -> !cir.float
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<0xFFC00000> : !cir.float
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.float
+
+ cir.func @unary_minus_neg_nan() -> !cir.float {
+ %0 = cir.const #cir.fp<0xFFC00000> : !cir.float // negative quiet NaN; minus only flips the sign bit
+ %1 = cir.unary(minus, %0) : !cir.float, !cir.float
+ cir.return %1 : !cir.float
+ }
+ // CHECK: cir.func{{.*}} @unary_minus_neg_nan() -> !cir.float
+ // CHECK-NEXT: %[[CONST:.*]] = cir.const #cir.fp<0x7FC00000> : !cir.float
+ // CHECK-NEXT: cir.return %[[CONST]] : !cir.float
+
cir.func @cast1(%arg0: !cir.bool) -> !cir.bool {
%0 = cir.cast bool_to_int %arg0 : !cir.bool -> !s32i
%1 = cir.cast int_to_bool %0 : !s32i -> !cir.bool
>From a28a2307ed81c8c1a5cbab5e0e95a541b7a398ac Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Wed, 7 Jan 2026 16:05:20 -0800
Subject: [PATCH 2/3] Address review feedback
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index e80d858e39e43..f3d04b9a8e49c 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -2488,8 +2488,9 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
val.flipAllBits();
return cir::IntAttr::get(getType(), val);
}
- assert(mlir::isa<cir::BoolType>(srcConst.getType()));
- return cir::BoolAttr::get(getContext(), !srcConst.getBoolValue());
+ if (mlir::isa<cir::BoolType>(srcConst.getType()))
+ return cir::BoolAttr::get(getContext(), !srcConst.getBoolValue());
+ break;
case cir::UnaryOpKind::Plus:
return srcConst.getResult();
case cir::UnaryOpKind::Minus:
@@ -2503,8 +2504,9 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
val.negate();
return cir::IntAttr::get(getType(), val);
}
- assert(mlir::isa<cir::BoolType>(srcConst.getType()));
- return srcConst.getResult();
+ if (mlir::isa<cir::BoolType>(srcConst.getType()))
+ return srcConst.getResult();
+ break;
default:
return {};
}
>From 70ece07a1712f1a5ceef2ee0b81904b91fb5e861 Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Thu, 8 Jan 2026 10:10:02 -0800
Subject: [PATCH 3/3] Rewrite fold using input from adaptor
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 92 +++++++++++++++++--------
1 file changed, 63 insertions(+), 29 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f3d04b9a8e49c..5031f89461b93 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -27,6 +27,7 @@
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
using namespace mlir;
@@ -2479,37 +2480,70 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
if (isBoolNot(previous))
return previous.getInput();
- // Fold constant unary operations.
+ // Avoid introducing unnecessary duplicate constants in cases where we are
+ // just folding the operation to its input value. If we return the
+ // input attribute from the adaptor, a new constant is materialized, but
+ // if we return the input value directly, it avoids that.
if (auto srcConst = getInput().getDefiningOp<cir::ConstantOp>()) {
- switch (getKind()) {
- case cir::UnaryOpKind::Not:
- if (mlir::isa<cir::IntType>(srcConst.getType())) {
- APInt val = srcConst.getIntValue();
- val.flipAllBits();
- return cir::IntAttr::get(getType(), val);
- }
- if (mlir::isa<cir::BoolType>(srcConst.getType()))
- return cir::BoolAttr::get(getContext(), !srcConst.getBoolValue());
- break;
- case cir::UnaryOpKind::Plus:
+ if (getKind() == cir::UnaryOpKind::Plus ||
+ (mlir::isa<cir::BoolType>(srcConst.getType()) &&
+ getKind() == cir::UnaryOpKind::Minus))
return srcConst.getResult();
- case cir::UnaryOpKind::Minus:
- if (mlir::isa<cir::FPTypeInterface>(srcConst.getType())) {
- APFloat val = srcConst.getFloatValue();
- val.changeSign();
- return cir::FPAttr::get(getType(), val);
- }
- if (mlir::isa<cir::IntType>(srcConst.getType())) {
- APInt val = srcConst.getIntValue();
- val.negate();
- return cir::IntAttr::get(getType(), val);
- }
- if (mlir::isa<cir::BoolType>(srcConst.getType()))
- return srcConst.getResult();
- break;
- default:
- return {};
- }
+ }
+
+ // Fold unary operations with constant inputs. If the input is a ConstantOp,
+ // it "folds" to its value attribute. If it was some other operation that
+ // was folded, it will be an mlir::Attribute that hasn't yet been
+ // materialized. If it was a value that couldn't be folded, it will be null.
+ if (mlir::Attribute attr = adaptor.getInput()) {
+ // For now, we only attempt to fold simple scalar values.
+ OpFoldResult result =
+ llvm::TypeSwitch<mlir::Attribute, OpFoldResult>(attr)
+ .Case<cir::IntAttr>([&](cir::IntAttr attrT) {
+ switch (getKind()) {
+ case cir::UnaryOpKind::Not: {
+ APInt val = attrT.getValue();
+ val.flipAllBits();
+ return cir::IntAttr::get(getType(), val);
+ }
+ case cir::UnaryOpKind::Plus:
+ return attrT;
+ case cir::UnaryOpKind::Minus: {
+ APInt val = attrT.getValue();
+ val.negate();
+ return cir::IntAttr::get(getType(), val);
+ }
+ default:
+ return cir::IntAttr{};
+ }
+ })
+ .Case<cir::FPAttr>([&](cir::FPAttr attrT) {
+ switch (getKind()) {
+ case cir::UnaryOpKind::Plus:
+ return attrT;
+ case cir::UnaryOpKind::Minus: {
+ APFloat val = attrT.getValue();
+ val.changeSign();
+ return cir::FPAttr::get(getType(), val);
+ }
+ default:
+ return cir::FPAttr{};
+ }
+ })
+ .Case<cir::BoolAttr>([&](cir::BoolAttr attrT) {
+ switch (getKind()) {
+ case cir::UnaryOpKind::Not:
+ return cir::BoolAttr::get(getContext(), !attrT.getValue());
+ case cir::UnaryOpKind::Plus:
+ case cir::UnaryOpKind::Minus:
+ return attrT;
+ default:
+ return cir::BoolAttr{};
+ }
+ })
+ .Default([&](auto attrT) { return mlir::Attribute{}; });
+ if (result)
+ return result;
}
return {};
More information about the cfe-commits
mailing list