[clang] [CIR] Implement folder for VecCmpOp (PR #143322)
Amr Hesham via cfe-commits
cfe-commits at lists.llvm.org
Fri Jun 13 13:25:31 PDT 2025
https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/143322
>From 0cb345f7a2241eb4f94036b2e9b8f55555807185 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Sun, 8 Jun 2025 21:08:14 +0200
Subject: [PATCH 1/4] [CIR] Implement folder for VecCmpOp
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 +
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 103 ++++++++
.../Dialect/Transforms/CIRCanonicalize.cpp | 4 +-
clang/test/CIR/Transforms/vector-cmp-fold.cir | 227 ++++++++++++++++++
4 files changed, 334 insertions(+), 2 deletions(-)
create mode 100644 clang/test/CIR/Transforms/vector-cmp-fold.cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 194153caa9271..119e5a5622fbc 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2155,6 +2155,8 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
`(` $kind `,` $lhs `,` $rhs `)` `:` qualified(type($lhs)) `,`
qualified(type($result)) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 8ed0ee92574dc..ca02e1ea6dc42 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1589,6 +1589,109 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
return elements[index];
}
+//===----------------------------------------------------------------------===//
+// VecCmpOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute lhs = adaptor.getLhs();
+ mlir::Attribute rhs = adaptor.getRhs();
+ if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
+ return {};
+
+ auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs);
+ auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs);
+
+ auto inputElemTy =
+ mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
+ if (!mlir::isa<cir::IntType>(inputElemTy) &&
+ !mlir::isa<cir::CIRFPTypeInterface>(inputElemTy))
+ return {};
+
+ cir::CmpOpKind opKind = adaptor.getKind();
+ mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
+ mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
+ uint64_t vecSize = lhsVecElhs.size();
+
+ auto resultVecTy = mlir::cast<cir::VectorType>(getType());
+
+ SmallVector<mlir::Attribute, 16> elements(vecSize);
+ for (uint64_t i = 0; i < vecSize; i++) {
+ mlir::Attribute lhsAttr = lhsVecElhs[i];
+ mlir::Attribute rhsAttr = rhsVecElhs[i];
+
+ int cmpResult = 0;
+ switch (opKind) {
+ case cir::CmpOpKind::lt: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::le: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::gt: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::ge: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::eq: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::ne: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ }
+
+ elements[i] = cir::IntAttr::get(resultVecTy.getElementType(), cmpResult);
+ }
+
+ return cir::ConstVectorAttr::get(
+ getType(), mlir::ArrayAttr::get(getContext(), elements));
+}
+
//===----------------------------------------------------------------------===//
// VecShuffleOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 6f8a64ce0251e..99683e8d66290 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -142,8 +142,8 @@ void CIRCanonicalizePass::runOnOperation() {
// Many operations are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
- VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp,
- VecTernaryOp>(op))
+ VecCreateOp, VecCmpOp, VecExtractOp, VecShuffleOp,
+ VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/vector-cmp-fold.cir b/clang/test/CIR/Transforms/vector-cmp-fold.cir
new file mode 100644
index 0000000000000..b207fc08748e2
--- /dev/null
+++ b/clang/test/CIR/Transforms/vector-cmp-fold.cir
@@ -0,0 +1,227 @@
+// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(eq, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(ne, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(le, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(eq, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(ne, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(le, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(ge, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
>From dfd6b65dbf55255ec9161abfe6a8931890b22475 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 9 Jun 2025 20:36:31 +0200
Subject: [PATCH 2/4] Address code review comments
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index ca02e1ea6dc42..0884c9a1963de 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1603,10 +1603,9 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs);
auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs);
- auto inputElemTy =
+ mlir::Type inputElemTy =
mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
- if (!mlir::isa<cir::IntType>(inputElemTy) &&
- !mlir::isa<cir::CIRFPTypeInterface>(inputElemTy))
+ if (!isAnyIntegerOrFloatingPointType(inputElemTy))
return {};
cir::CmpOpKind opKind = adaptor.getKind();
>From ccaff3bcbc26405caa8fbb10db9ecf1f14289fd9 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Tue, 10 Jun 2025 20:54:10 +0200
Subject: [PATCH 3/4] Address code review comments
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 30 +++++++++++--------------
1 file changed, 13 insertions(+), 17 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 0884c9a1963de..d9fe0f0d48ca6 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1594,15 +1594,13 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
- mlir::Attribute lhs = adaptor.getLhs();
- mlir::Attribute rhs = adaptor.getRhs();
- if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
- !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
+ auto lhsVecAttr =
+ mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
+ auto rhsVecAttr =
+ mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
+ if (!lhsVecAttr || !rhsVecAttr)
return {};
- auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs);
- auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs);
-
mlir::Type inputElemTy =
mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
if (!isAnyIntegerOrFloatingPointType(inputElemTy))
@@ -1613,17 +1611,15 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
uint64_t vecSize = lhsVecElhs.size();
- auto resultVecTy = mlir::cast<cir::VectorType>(getType());
-
SmallVector<mlir::Attribute, 16> elements(vecSize);
+ bool isIntAttr = vecSize ? mlir::isa<cir::IntAttr>(lhsVecElhs[0]) : false;
for (uint64_t i = 0; i < vecSize; i++) {
mlir::Attribute lhsAttr = lhsVecElhs[i];
mlir::Attribute rhsAttr = rhsVecElhs[i];
-
int cmpResult = 0;
switch (opKind) {
case cir::CmpOpKind::lt: {
- if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
@@ -1633,7 +1629,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
break;
}
case cir::CmpOpKind::le: {
- if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
@@ -1643,7 +1639,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
break;
}
case cir::CmpOpKind::gt: {
- if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
@@ -1653,7 +1649,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
break;
}
case cir::CmpOpKind::ge: {
- if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
@@ -1663,7 +1659,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
break;
}
case cir::CmpOpKind::eq: {
- if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
@@ -1673,7 +1669,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
break;
}
case cir::CmpOpKind::ne: {
- if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ if (isIntAttr) {
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
} else {
@@ -1684,7 +1680,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
}
}
- elements[i] = cir::IntAttr::get(resultVecTy.getElementType(), cmpResult);
+ elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
}
return cir::ConstVectorAttr::get(
>From 6560137c1900ea14602d1f16f02b736178aa0d52 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Tue, 10 Jun 2025 22:32:34 +0200
Subject: [PATCH 4/4] Address code review comment
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index d9fe0f0d48ca6..bf12d26c15f4e 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1612,7 +1612,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
uint64_t vecSize = lhsVecElhs.size();
SmallVector<mlir::Attribute, 16> elements(vecSize);
- bool isIntAttr = vecSize ? mlir::isa<cir::IntAttr>(lhsVecElhs[0]) : false;
+ bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]);
for (uint64_t i = 0; i < vecSize; i++) {
mlir::Attribute lhsAttr = lhsVecElhs[i];
mlir::Attribute rhsAttr = rhsVecElhs[i];
More information about the cfe-commits
mailing list