[clang] Fix cir vec cmp fold (PR #202502)
Aayush Shrivastava via cfe-commits
cfe-commits at lists.llvm.org
Wed Jun 10 03:13:05 PDT 2026
https://github.com/iamaayushrivastava updated https://github.com/llvm/llvm-project/pull/202502
>From 35e9c8c39ef48f68d4d72340a3a7a10e0618a905 Mon Sep 17 00:00:00 2001
From: iamaayushrivastava <iamaayushrivastava at gmail.com>
Date: Tue, 9 Jun 2026 09:37:34 +0530
Subject: [PATCH 1/2] [CIR] Fix constant folding of vector comparisons for
unsigned types and result values
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 48 +++++++++----
clang/test/CIR/Transforms/vector-cmp-fold.cir | 67 +++++++++++++++----
2 files changed, 89 insertions(+), 26 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index cf07fc4f0833a..987b605193296 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -3436,15 +3436,21 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
SmallVector<mlir::Attribute, 16> elements(vecSize);
bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
+ bool isUnsignedInt =
+ isIntAttr && mlir::cast<cir::IntType>(inputElemTy).isUnsigned();
for (uint64_t i = 0; i < vecSize; i++) {
mlir::Attribute lhsAttr = lhsVecElhs[i];
mlir::Attribute rhsAttr = rhsVecElhs[i];
- int cmpResult = 0;
+ bool cmpResult = false;
switch (opKind) {
case cir::CmpOpKind::lt: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
- mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ if (isUnsignedInt)
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getUInt() <
+ mlir::cast<cir::IntAttr>(rhsAttr).getUInt();
+ else
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3453,8 +3459,12 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::le: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
- mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ if (isUnsignedInt)
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getUInt() <=
+ mlir::cast<cir::IntAttr>(rhsAttr).getUInt();
+ else
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3463,8 +3473,12 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::gt: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
- mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ if (isUnsignedInt)
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getUInt() >
+ mlir::cast<cir::IntAttr>(rhsAttr).getUInt();
+ else
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3473,8 +3487,12 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::ge: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
- mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ if (isUnsignedInt)
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getUInt() >=
+ mlir::cast<cir::IntAttr>(rhsAttr).getUInt();
+ else
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3483,8 +3501,8 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::eq: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
- mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getValue() ==
+ mlir::cast<cir::IntAttr>(rhsAttr).getValue();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3493,8 +3511,8 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::ne: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
- mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getValue() !=
+ mlir::cast<cir::IntAttr>(rhsAttr).getValue();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3517,7 +3535,9 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
}
- elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
+ // Vector comparison results are 0 (false) or -1 / all-ones (true).
+ elements[i] =
+ cir::IntAttr::get(getType().getElementType(), cmpResult ? -1LL : 0LL);
}
return cir::ConstVectorAttr::get(
diff --git a/clang/test/CIR/Transforms/vector-cmp-fold.cir b/clang/test/CIR/Transforms/vector-cmp-fold.cir
index f3486bd26fe1b..4d7c81f81cccb 100644
--- a/clang/test/CIR/Transforms/vector-cmp-fold.cir
+++ b/clang/test/CIR/Transforms/vector-cmp-fold.cir
@@ -29,8 +29,8 @@ module {
}
// CHECK: cir.func{{.*}} @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
- // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
- // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s32i, #cir.int<-1> : !s32i,
+ // CHECK-SAME: #cir.int<-1> : !s32i, #cir.int<-1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
@@ -47,8 +47,8 @@ module {
}
// CHECK: cir.func{{.*}} @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
- // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
- // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s32i, #cir.int<-1> : !s32i,
+ // CHECK-SAME: #cir.int<-1> : !s32i, #cir.int<-1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
@@ -65,8 +65,8 @@ module {
}
// CHECK: cir.func{{.*}} @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
- // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
- // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s32i, #cir.int<-1> : !s32i,
+ // CHECK-SAME: #cir.int<-1> : !s32i, #cir.int<-1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
@@ -141,8 +141,8 @@ module {
}
// CHECK: cir.func{{.*}} @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
- // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
- // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s32i, #cir.int<-1> : !s32i,
+ // CHECK-SAME: #cir.int<-1> : !s32i, #cir.int<-1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
@@ -161,8 +161,8 @@ module {
}
// CHECK: cir.func{{.*}} @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
- // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
- // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s32i, #cir.int<-1> : !s32i,
+ // CHECK-SAME: #cir.int<-1> : !s32i, #cir.int<-1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
@@ -181,8 +181,8 @@ module {
}
// CHECK: cir.func{{.*}} @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
- // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
- // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s32i, #cir.int<-1> : !s32i,
+ // CHECK-SAME: #cir.int<-1> : !s32i, #cir.int<-1> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
@@ -225,3 +225,46 @@ module {
// CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
}
+
+// -----
+
+// Test unsigned integer comparisons: result must use unsigned ordering and
+// produce 0 or -1 (all-ones) per the SIMD convention.
+
+!u8i = !cir.int<u, 8>
+!s8i = !cir.int<s, 8>
+
+module {
+ cir.func @fold_cmp_vector_unsigned_gt() -> !cir.vector<4 x !s8i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<255> : !u8i, #cir.int<0> : !u8i, #cir.int<1> : !u8i, #cir.int<254> : !u8i]> : !cir.vector<4 x !u8i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<254> : !u8i, #cir.int<255> : !u8i, #cir.int<255> : !u8i, #cir.int<255> : !u8i]> : !cir.vector<4 x !u8i>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !u8i>, !cir.vector<4 x !s8i>
+ cir.return %new_vec : !cir.vector<4 x !s8i>
+ }
+
+ // 255>254 (T), 0>255 (F), 1>255 (F), 254>255 (F) -> [-1, 0, 0, 0]
+ // CHECK: cir.func{{.*}} @fold_cmp_vector_unsigned_gt() -> !cir.vector<4 x !s8i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s8i, #cir.int<0> : !s8i,
+ // CHECK-SAME: #cir.int<0> : !s8i, #cir.int<0> : !s8i]> : !cir.vector<4 x !s8i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s8i>
+}
+
+// -----
+
+!u8i = !cir.int<u, 8>
+!s8i = !cir.int<s, 8>
+
+module {
+ cir.func @fold_cmp_vector_unsigned_lt() -> !cir.vector<4 x !s8i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<0> : !u8i, #cir.int<255> : !u8i, #cir.int<128> : !u8i, #cir.int<127> : !u8i]> : !cir.vector<4 x !u8i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<255> : !u8i, #cir.int<0> : !u8i, #cir.int<129> : !u8i, #cir.int<128> : !u8i]> : !cir.vector<4 x !u8i>
+ %new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !u8i>, !cir.vector<4 x !s8i>
+ cir.return %new_vec : !cir.vector<4 x !s8i>
+ }
+
+ // 0<255 (T), 255<0 (F), 128<129 (T), 127<128 (T) -> [-1, 0, -1, -1]
+ // CHECK: cir.func{{.*}} @fold_cmp_vector_unsigned_lt() -> !cir.vector<4 x !s8i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !s8i, #cir.int<0> : !s8i,
+ // CHECK-SAME: #cir.int<-1> : !s8i, #cir.int<-1> : !s8i]> : !cir.vector<4 x !s8i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s8i>
+}
>From 27196909cf36eaf1fa9bc6436da71a866b6e1712 Mon Sep 17 00:00:00 2001
From: iamaayushrivastava <iamaayushrivastava at gmail.com>
Date: Wed, 10 Jun 2026 15:42:36 +0530
Subject: [PATCH 2/2] Address review comments on vector comparisons for
unsigned types
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 12 ++++++----
clang/test/CIR/Transforms/vector-cmp-fold.cir | 24 +++++++++++++++++++
2 files changed, 31 insertions(+), 5 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 987b605193296..d6fee0da92ad1 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -3501,8 +3501,8 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::eq: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getValue() ==
- mlir::cast<cir::IntAttr>(rhsAttr).getValue();
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3511,8 +3511,8 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
case cir::CmpOpKind::ne: {
if (isIntAttr) {
- cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getValue() !=
- mlir::cast<cir::IntAttr>(rhsAttr).getValue();
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
mlir::cast<cir::FPAttr>(rhsAttr).getValue();
@@ -3535,7 +3535,9 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
}
- // Vector comparison results are 0 (false) or -1 / all-ones (true).
+ // A true result is all bits set (-1 in two's complement), and a false
+ // result is all bits clear. For a 1-bit element type these are the same
+ // bit pattern as 1 and 0, respectively.
elements[i] =
cir::IntAttr::get(getType().getElementType(), cmpResult ? -1LL : 0LL);
}
diff --git a/clang/test/CIR/Transforms/vector-cmp-fold.cir b/clang/test/CIR/Transforms/vector-cmp-fold.cir
index 4d7c81f81cccb..9198db396c1e8 100644
--- a/clang/test/CIR/Transforms/vector-cmp-fold.cir
+++ b/clang/test/CIR/Transforms/vector-cmp-fold.cir
@@ -268,3 +268,27 @@ module {
// CHECK-SAME: #cir.int<-1> : !s8i, #cir.int<-1> : !s8i]> : !cir.vector<4 x !s8i>
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s8i>
}
+
+// -----
+
+// Test folding when the result element type is a 1-bit signed integer. The
+// all-ones (true) bit pattern is printed as -1 and all-zeros (false) as 0,
+// since -1 and 1 share the same bit pattern at width 1.
+
+!s32i = !cir.int<s, 32>
+!s1i = !cir.int<s, 1>
+
+module {
+ cir.func @fold_cmp_vector_op_test_i1_result() -> !cir.vector<4 x !s1i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s1i>
+ cir.return %new_vec : !cir.vector<4 x !s1i>
+ }
+
+ // 1<2, 3<4, 5<6, 7<8 are all true -> [-1, -1, -1, -1]
+ // CHECK: cir.func{{.*}} @fold_cmp_vector_op_test_i1_result() -> !cir.vector<4 x !cir.int<s, 1>> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<-1> : !cir.int<s, 1>, #cir.int<-1> : !cir.int<s, 1>,
+ // CHECK-SAME: #cir.int<-1> : !cir.int<s, 1>, #cir.int<-1> : !cir.int<s, 1>]> : !cir.vector<4 x !cir.int<s, 1>>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !cir.int<s, 1>>
+}
More information about the cfe-commits
mailing list