[clang] [CIR] Upstream TernaryOp for VectorType (PR #142393)
Amr Hesham via cfe-commits
cfe-commits at lists.llvm.org
Mon Jun 2 13:48:07 PDT 2025
https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/142393
>From 030af8ea55d123d4b32d6a935c6288ea76973897 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 2 Jun 2025 15:11:17 +0200
Subject: [PATCH 1/2] [CIR] Upstream TernaryOp for VectorType
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 36 +++++++++++++++++++
clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 30 ++++++++++++++++
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 18 ++++++++++
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 17 ++++++++-
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 10 ++++++
clang/test/CIR/CodeGen/vector-ext.cpp | 15 ++++++++
clang/test/CIR/CodeGen/vector.cpp | 17 ++++++++-
7 files changed, 141 insertions(+), 2 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 237daed32532a..d6a9bda0c04ea 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2194,4 +2194,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// VecTernaryOp
+//===----------------------------------------------------------------------===//
+
+def VecTernaryOp : CIR_Op<"vec.ternary",
+ [Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> {
+ let summary = "The `cond ? a : b` ternary operator for vector types";
+ let description = [{
+ The `cir.vec.ternary` operation represents the C/C++ ternary operator,
+ `?:`, for vector types, which does a `select` on individual elements of the
+ vectors. Unlike a regular `?:` operator, there is no short circuiting. All
+ three arguments are always evaluated. Because there is no short
+ circuiting, there are no regions in this operation, unlike cir.ternary.
+
+ The first argument is a vector of integral type. The second and third
+ arguments are vectors of the same type and have the same number of elements
+ as the first argument.
+
+ The result is a vector of the same type as the second and third arguments.
+ Each element of the result is `(bool)cond[n] ? vec1[n] : vec2[n]`.
+ }];
+
+ let arguments = (ins
+ IntegerVector:$cond,
+ CIR_VectorType:$vec1,
+ CIR_VectorType:$vec2
+ );
+
+ let results = (outs CIR_VectorType:$result);
+ let assemblyFormat = [{
+ `(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
+ qualified(type($vec1)) attr-dict
+ }];
+ let hasVerifier = 1;
+}
+
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index b33bb71c99c90..94e331d4ef652 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -193,6 +193,36 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
e->getSourceRange().getBegin());
}
+ mlir::Value
+ VisitAbstractConditionalOperator(const AbstractConditionalOperator *e) {
+ mlir::Location loc = cgf.getLoc(e->getSourceRange());
+ Expr *condExpr = e->getCond();
+ Expr *lhsExpr = e->getTrueExpr();
+ Expr *rhsExpr = e->getFalseExpr();
+
+ // OpenCL: If the condition is a vector, we can treat this condition like
+ // the select function.
+ if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
+ condExpr->getType()->isExtVectorType()) {
+ cgf.getCIRGenModule().errorNYI(loc,
+ "TernaryOp OpenCL VectorType condition");
+ return {};
+ }
+
+ if (condExpr->getType()->isVectorType() ||
+ condExpr->getType()->isSveVLSBuiltinType()) {
+ assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI");
+ mlir::Value condValue = Visit(condExpr);
+ mlir::Value lhsValue = Visit(lhsExpr);
+ mlir::Value rhsValue = Visit(rhsExpr);
+ return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue,
+ rhsValue);
+ }
+
+ cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types");
+ return {};
+ }
+
mlir::Value VisitMemberExpr(MemberExpr *e);
mlir::Value VisitInitListExpr(InitListExpr *e);
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 36f050de9f8bb..1236c455304a9 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1589,6 +1589,24 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// VecTernaryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::VecTernaryOp::verify() {
+ // Verify that the condition operand has the same number of elements as the
+ // other operands. (The automatic verification already checked that all
+ // operands are vector types and that the second and third operands are the
+ // same type.)
+ if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
+ getVec1().getType().getSize()) {
+ return emitOpError() << ": the number of elements in "
+ << getCond().getType() << " and "
+ << getVec1().getType() << " don't match";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index b07e61638c3b4..e5a26260dc8cc 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1708,7 +1708,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecExtractOpLowering,
CIRToLLVMVecInsertOpLowering,
CIRToLLVMVecCmpOpLowering,
- CIRToLLVMVecShuffleDynamicOpLowering
+ CIRToLLVMVecShuffleDynamicOpLowering,
+ CIRToLLVMVecTernaryOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -1916,6 +1917,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
+ cir::VecTernaryOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ // Convert `cond` into a vector of i1, then use that in a `select` op.
+ mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
+ op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
+ rewriter.create<mlir::LLVM::ZeroOp>(
+ op.getCond().getLoc(),
+ typeConverter->convertType(op.getCond().getType())));
+ rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
+ op, bitVec, adaptor.getVec1(), adaptor.getVec2());
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 6b8862db2c8be..ed369ff15a710 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -363,6 +363,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecTernaryOpLowering
+ : public mlir::OpConversionPattern<cir::VecTernaryOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecTernaryOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 8a0479fc1d088..53258845c2169 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -1091,3 +1091,18 @@ void foo17() {
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+
+void foo20() {
+ vi4 a;
+ vi4 b;
+ vi4 c;
+ vi4 r = c ? a : b;
+}
+
+// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+
+// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
+
+// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4c50f68a56162..49f142d110a81 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -1069,4 +1069,19 @@ void foo17() {
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
-// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
\ No newline at end of file
+// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+
+void foo20() {
+ vi4 a;
+ vi4 b;
+ vi4 c;
+ vi4 r = c ? a : b;
+}
+
+// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+
+// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
+
+// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
>From 4aa56e8e4b4dfcf5835c0d35d99462a352efbeae Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 2 Jun 2025 22:47:35 +0200
Subject: [PATCH 2/2] Update with main & Address code review comments
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 +-
clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 20 ++++++++++++--------
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 3 +--
clang/test/CIR/CodeGen/vector-ext.cpp | 15 ---------------
4 files changed, 14 insertions(+), 26 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index d6a9bda0c04ea..746583985d4c5 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2217,7 +2217,7 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
}];
let arguments = (ins
- IntegerVector:$cond,
+ CIR_VectorOfIntType:$cond,
CIR_VectorType:$vec1,
CIR_VectorType:$vec2
);
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 94e331d4ef652..fd2718c57dbca 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -200,18 +200,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
Expr *lhsExpr = e->getTrueExpr();
Expr *rhsExpr = e->getFalseExpr();
+ QualType condType = condExpr->getType();
+
// OpenCL: If the condition is a vector, we can treat this condition like
// the select function.
- if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
- condExpr->getType()->isExtVectorType()) {
- cgf.getCIRGenModule().errorNYI(loc,
- "TernaryOp OpenCL VectorType condition");
+ if ((cgf.getLangOpts().OpenCL && condType->isVectorType()) ||
+ condType->isExtVectorType()) {
+ cgf.cgm.errorNYI(loc, "TernaryOp OpenCL VectorType condition");
return {};
}
- if (condExpr->getType()->isVectorType() ||
- condExpr->getType()->isSveVLSBuiltinType()) {
- assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI");
+ if (condType->isVectorType() || condType->isSveVLSBuiltinType()) {
+ if (!condType->isVectorType()) { // Only an SVE VLS builtin condition reaches here; not yet supported.
+ cgf.cgm.errorNYI(loc, "TernaryOp for SVE vector");
+ return {};
+ }
+
mlir::Value condValue = Visit(condExpr);
mlir::Value lhsValue = Visit(lhsExpr);
mlir::Value rhsValue = Visit(rhsExpr);
@@ -219,7 +223,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
rhsValue);
}
- cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types");
+ cgf.cgm.errorNYI(loc, "TernaryOp for non vector types");
return {};
}
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 1236c455304a9..7dd4a8ec0c7ef 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1598,8 +1598,7 @@ LogicalResult cir::VecTernaryOp::verify() {
// other operands. (The automatic verification already checked that all
// operands are vector types and that the second and third operands are the
// same type.)
- if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
- getVec1().getType().getSize()) {
+ if (getCond().getType().getSize() != getVec1().getType().getSize()) {
return emitOpError() << ": the number of elements in "
<< getCond().getType() << " and "
<< getVec1().getType() << " don't match";
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 53258845c2169..8a0479fc1d088 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -1091,18 +1091,3 @@ void foo17() {
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
-
-void foo20() {
- vi4 a;
- vi4 b;
- vi4 c;
- vi4 r = c ? a : b;
-}
-
-// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
-
-// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
-// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
-
-// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
-// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
More information about the cfe-commits
mailing list