[clang] [CIR] Implement folder for VecTernaryOp (PR #142946)
Amr Hesham via cfe-commits
cfe-commits at lists.llvm.org
Thu Jun 5 14:28:45 PDT 2025
https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/142946
>From ac8277b48d0affa78f5e5e943e0179c27dd033ec Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Thu, 5 Jun 2025 13:08:57 +0200
Subject: [PATCH 1/2] [CIR] Implement folder for VecTernaryOp
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 ++
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 35 +++++++++++++++++++
.../Dialect/Transforms/CIRCanonicalize.cpp | 6 ++--
.../CIR/Transforms/vector-ternary-fold.cir | 20 +++++++++++
4 files changed, 60 insertions(+), 3 deletions(-)
create mode 100644 clang/test/CIR/Transforms/vector-ternary-fold.cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 00878f7dd8ed7..eb439f7aa1527 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2228,7 +2228,9 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
`(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,`
qualified(type($lhs)) attr-dict
}];
+
let hasVerifier = 1;
+ let hasFolder = 1;
}
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index fa7fb592a3cd6..f585254d3340b 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1638,6 +1638,41 @@ LogicalResult cir::VecTernaryOp::verify() {
return success();
}
+OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute cond = adaptor.getCond();
+ mlir::Attribute lhs = adaptor.getLhs();
+ mlir::Attribute rhs = adaptor.getRhs();
+
+ if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
+ mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
+ mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
+ auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
+ auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
+ auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
+
+ mlir::ArrayAttr condElts = condVec.getElts();
+
+ SmallVector<mlir::Attribute, 16> elements;
+ elements.reserve(condElts.size());
+
+ for (const auto &[idx, condAttr] :
+ llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
+ if (condAttr.getSInt()) {
+ elements.push_back(lhsVec.getElts()[idx]);
+ continue;
+ }
+
+ elements.push_back(rhsVec.getElts()[idx]);
+ }
+
+ cir::VectorType vecTy = getLhs().getType();
+ return cir::ConstVectorAttr::get(
+ vecTy, mlir::ArrayAttr::get(getContext(), elements));
+ }
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 7d03e374c27e8..aa3e97033cdda 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -138,10 +138,10 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
- // CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform
- // a manual `fold` in applyOpPatternsGreedily.
+ // CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are
+ // here to perform a manual `fold` in applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
- VecExtractOp, VecShuffleDynamicOp>(op))
+ VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/vector-ternary-fold.cir b/clang/test/CIR/Transforms/vector-ternary-fold.cir
new file mode 100644
index 0000000000000..f2e18576da74b
--- /dev/null
+++ b/clang/test/CIR/Transforms/vector-ternary-fold.cir
@@ -0,0 +1,20 @@
+// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
+ %cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ %lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+ %rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %res : !cir.vector<4 x !s32i>
+ }
+
+ // [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] will be folded to [1, 6, 3, 8]
+ // CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+
>From 9263706f56b05cfc9c106355e65de9b2886d2e98 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Thu, 5 Jun 2025 23:28:09 +0200
Subject: [PATCH 2/2] Address code review comments
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 39 +++++++++----------
.../Dialect/Transforms/CIRCanonicalize.cpp | 11 ++++--
2 files changed, 25 insertions(+), 25 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f585254d3340b..b353a32ea07a8 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1643,34 +1643,31 @@ OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
mlir::Attribute lhs = adaptor.getLhs();
mlir::Attribute rhs = adaptor.getRhs();
- if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
- mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
- mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
- auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
- auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
- auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
-
- mlir::ArrayAttr condElts = condVec.getElts();
+ if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
+ return {};
+ auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
+ auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
+ auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
- SmallVector<mlir::Attribute, 16> elements;
- elements.reserve(condElts.size());
+ mlir::ArrayAttr condElts = condVec.getElts();
- for (const auto &[idx, condAttr] :
- llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
- if (condAttr.getSInt()) {
- elements.push_back(lhsVec.getElts()[idx]);
- continue;
- }
+ SmallVector<mlir::Attribute, 16> elements;
+ elements.reserve(condElts.size());
+ for (const auto &[idx, condAttr] :
+ llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
+ if (condAttr.getSInt()) {
+ elements.push_back(lhsVec.getElts()[idx]);
+ } else {
elements.push_back(rhsVec.getElts()[idx]);
}
-
- cir::VectorType vecTy = getLhs().getType();
- return cir::ConstVectorAttr::get(
- vecTy, mlir::ArrayAttr::get(getContext(), elements));
}
- return {};
+ cir::VectorType vecTy = getLhs().getType();
+ return cir::ConstVectorAttr::get(
+ vecTy, mlir::ArrayAttr::get(getContext(), elements));
}
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index aa3e97033cdda..ccd808116dbbf 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -138,10 +138,13 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
- // CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are
- // here to perform a manual `fold` in applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
- VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
+
+ if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, SelectOp>(op))
+ ops.push_back(op);
+
+ // Operations collected so applyOpPatternsGreedily runs their `fold` manually.
+ if (isa<CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(
+ op))
ops.push_back(op);
});
More information about the cfe-commits
mailing list