[clang] [CIR] Implement cir.cmp3way Operation (PR #186294)
via cfe-commits
cfe-commits at lists.llvm.org
Thu Mar 12 19:20:49 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Yeongu Choe (YeonguChoe)
<details>
<summary>Changes</summary>
A three-way comparison determines the ordering of two operands (less, equal, greater, or — for partial orderings — unordered) in a single operation. I implemented `cir.cmp3way` as specified in the ClangIR dialect documentation and formatted the code with `clang-format`.
Reference
- https://llvm.github.io/clangir/Dialect/ops.html#circmp3way-circmpthreewayop
- https://en.wikipedia.org/wiki/Three-way_comparison
- https://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Full diff: https://github.com/llvm/llvm-project/pull/186294.diff
6 Files Affected:
- (modified) clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h (+7)
- (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+34)
- (modified) clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp (+42-4)
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+82)
- (added) clang/test/CIR/CodeGenCXX/three-way-comparison.cpp (+22)
- (added) clang/test/CIR/Lowering/cmp3way.cir (+32)
``````````diff
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index e60288c40132f..2a81da408121c 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -669,6 +669,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::CmpOp::create(*this, loc, kind, lhs, rhs);
}
+  /// Build a `cir.cmp3way` operation that orders `lhs` against `rhs`.
+  /// `info` must be one of the cmp3way info attributes (strong or partial);
+  /// it supplies the integer produced for each outcome. `resultTy` is the
+  /// signed integer type of the produced value.
+  cir::CmpThreeWayOp createThreeWayComparison(mlir::Location loc,
+                                              mlir::Type resultTy,
+                                              mlir::Value lhs, mlir::Value rhs,
+                                              mlir::Attribute info) {
+    auto cmp3way =
+        cir::CmpThreeWayOp::create(*this, loc, resultTy, lhs, rhs, info);
+    return cmp3way;
+  }
+
cir::VecCmpOp createVecCompare(mlir::Location loc, cir::CmpOpKind kind,
mlir::Value lhs, mlir::Value rhs) {
VectorType vecCast = mlir::cast<VectorType>(lhs.getType());
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index a9b98b1f43b3f..24758a0c946f6 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2071,6 +2071,40 @@ def CIR_CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {
let hasCXXABILowering = true;
}
+//===----------------------------------------------------------------------===//
+// CmpThreeWayOp
+//===----------------------------------------------------------------------===//
+
+// Result values for an ordering with exactly three outcomes: the integers
+// produced when lhs is less than, equal to, or greater than rhs.
+def CIR_CmpThreeWayStrongInfoAttr
+    : CIR_Attr<"CmpThreeWayStrongInfo", "cmp3way_strong_info"> {
+  let parameters = (ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt);
+  let assemblyFormat =
+      [{`<` `strong` `,` `lt` `=` $lt `,` `eq` `=` $eq `,` `gt` `=` $gt `>`}];
+}
+// Result values for a partial ordering: like the strong variant, plus the
+// integer produced when the operands are unordered.
+def CIR_CmpThreeWayPartialInfoAttr
+    : CIR_Attr<"CmpThreeWayPartialInfo", "cmp3way_partial_info"> {
+  let parameters = (ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt,
+                        "int64_t":$unordered);
+  let assemblyFormat =
+      [{`<` `partial` `,` `lt` `=` $lt `,` `eq` `=` $eq `,` `gt` `=` $gt `,` `unordered` `=` $unordered `>` }];
+}
+def CIR_CmpThreeWayInfoAttr : AnyAttrOf<[CIR_CmpThreeWayStrongInfoAttr,
+                                         CIR_CmpThreeWayPartialInfoAttr]>;
+
+// `Pure` already implies speculatability (it includes AlwaysSpeculatable),
+// so no separate ConditionallySpeculatable trait is listed.
+def CIR_CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
+  let summary = "Performs three-way comparison.";
+  let description = [{
+    The `cir.cmp3way` operation compares two operands of the same type and
+    produces a signed integer that encodes their ordering, as used for the
+    C++ `<=>` operator on built-in types.
+
+    The `info` attribute gives the integer produced for each outcome.
+    `#cir.cmp3way_strong_info` specifies the values for less, equal and
+    greater. `#cir.cmp3way_partial_info` additionally specifies the value
+    produced when the operands are unordered (e.g. a floating-point NaN).
+
+    Example:
+
+    ```mlir
+    %2 = cir.cmp3way(%0 : !s32i, %1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i
+    ```
+  }];
+  let arguments = (ins CIR_AnyType:$lhs, CIR_AnyType:$rhs,
+                       CIR_CmpThreeWayInfoAttr:$info);
+  let results = (outs CIR_AnySIntType:$result);
+  let assemblyFormat = [{
+    `(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)` `:` type($result) attr-dict
+  }];
+}
+
//===----------------------------------------------------------------------===//
// BinOpOverflowOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
index 9f390fec97613..19e844664f53c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
@@ -301,9 +301,7 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
cgf.cgm.errorNYI(e->getSourceRange(),
"AggExprEmitter: VisitSubstNonTypeTemplateParmExpr");
}
- void VisitConstantExpr(ConstantExpr *e) {
- cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitConstantExpr");
- }
+  /// Emit an aggregate ConstantExpr by visiting the expression it wraps.
+  void VisitConstantExpr(ConstantExpr *e) {
+    Visit(e->getSubExpr());
+  }
void VisitMemberExpr(MemberExpr *e) { emitAggLoadOfLValue(e); }
void VisitUnaryDeref(UnaryOperator *e) { emitAggLoadOfLValue(e); }
void VisitStringLiteral(StringLiteral *e) { emitAggLoadOfLValue(e); }
@@ -326,7 +324,47 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
Visit(e->getRHS());
}
void VisitBinCmp(const BinaryOperator *e) {
- cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitBinCmp");
+    assert(cgf.getContext().hasSameType(e->getLHS()->getType(),
+                                        e->getRHS()->getType()) &&
+           "three-way comparison operands must have the same type");
+    const ComparisonCategoryInfo &cmpInfo =
+        cgf.getContext().CompCategories.getInfoForType(e->getType());
+
+    QualType argTy = e->getLHS()->getType();
+    if (!argTy->isIntegralOrEnumerationType() && !argTy->isRealFloatingType() &&
+        !argTy->isNullPtrType() && !argTy->isPointerType() &&
+        !argTy->isMemberPointerType()) {
+      cgf.cgm.errorNYI(e->getSourceRange(),
+                       "AggExprEmitter: unsupported operand type");
+      return;
+    }
+
+    mlir::Location loc = cgf.getLoc(e->getSourceRange());
+    mlir::Value lhs = cgf.emitScalarExpr(e->getLHS());
+    mlir::Value rhs = cgf.emitScalarExpr(e->getRHS());
+
+    // Integer value the comparison category assigns to one outcome.
+    auto valueOf = [](const ComparisonCategoryInfo::ValueInfo *vi) {
+      return vi->getIntValue().getSExtValue();
+    };
+    mlir::MLIRContext *ctx = cgf.getBuilder().getContext();
+    mlir::Attribute info;
+    // Only partial orderings have an "unordered" value. Every other category
+    // (strong and weak) is described by the three-valued info attribute.
+    if (cmpInfo.isPartial())
+      info = cir::CmpThreeWayPartialInfoAttr::get(
+          ctx, valueOf(cmpInfo.getLess()), valueOf(cmpInfo.getEqualOrEquiv()),
+          valueOf(cmpInfo.getGreater()), valueOf(cmpInfo.getUnordered()));
+    else
+      info = cir::CmpThreeWayStrongInfoAttr::get(
+          ctx, valueOf(cmpInfo.getLess()), valueOf(cmpInfo.getEqualOrEquiv()),
+          valueOf(cmpInfo.getGreater()));
+
+    // The result is stored into the first field of the comparison category
+    // record. Produce it directly in that field's type, so the store below
+    // does not mismatch (the field is not necessarily `int`).
+    const FieldDecl *field = *cmpInfo.Record->field_begin();
+    mlir::Type resultTy = cgf.convertType(field->getType());
+    mlir::Value result = cgf.getBuilder().createThreeWayComparison(
+        loc, resultTy, lhs, rhs, info);
+
+    ensureDest(loc, e->getType());
+    LValue destLValue = cgf.makeAddrLValue(dest.getAddress(), e->getType());
+    LValue fieldLValue = cgf.emitLValueForFieldInitialization(
+        destLValue, field, field->getName());
+    cgf.emitStoreThroughLValue(RValue::get(result), fieldLValue,
+                               /*isInit=*/true);
}
void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *e) {
cgf.cgm.errorNYI(e->getSourceRange(),
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 9c68248d5dede..53086ecb4b669 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -3155,6 +3155,88 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
}
+/// Lower cir.cmp3way to a chain of LLVM compares and selects that picks the
+/// integer configured in the op's info attribute for the observed ordering.
+mlir::LogicalResult CIRToLLVMCmpThreeWayOpLowering::matchAndRewrite(
+    cir::CmpThreeWayOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Location loc = op.getLoc();
+  mlir::Type resultTy = getTypeConverter()->convertType(op.getType());
+  mlir::Value lhs = adaptor.getLhs();
+  mlir::Value rhs = adaptor.getRhs();
+  mlir::Type operandTy = lhs.getType();
+
+  // Materialize one result constant. The attribute is built with the
+  // converted result type so that it matches the constant's own type; the
+  // result may be narrower than 64 bits (e.g. i32 or i8).
+  auto makeResultConst = [&](int64_t value) -> mlir::Value {
+    return mlir::LLVM::ConstantOp::create(
+        rewriter, loc, resultTy, rewriter.getIntegerAttr(resultTy, value));
+  };
+
+  // unorderedValue is only set for partial orderings; it stays null for
+  // strong orderings.
+  mlir::Value ltValue, eqValue, gtValue, unorderedValue;
+  mlir::Attribute info = op.getInfo();
+  if (auto strongInfo = mlir::dyn_cast<cir::CmpThreeWayStrongInfoAttr>(info)) {
+    ltValue = makeResultConst(strongInfo.getLt());
+    eqValue = makeResultConst(strongInfo.getEq());
+    gtValue = makeResultConst(strongInfo.getGt());
+  } else if (auto partialInfo =
+                 mlir::dyn_cast<cir::CmpThreeWayPartialInfoAttr>(info)) {
+    ltValue = makeResultConst(partialInfo.getLt());
+    eqValue = makeResultConst(partialInfo.getEq());
+    gtValue = makeResultConst(partialInfo.getGt());
+    unorderedValue = makeResultConst(partialInfo.getUnordered());
+  } else {
+    return op.emitError("unsupported comparison info attribute");
+  }
+
+  // Builds select(isLess, lt, select(isEqual, eq, gt)).
+  auto selectOrdered = [&](mlir::Value isLess,
+                           mlir::Value isEqual) -> mlir::Value {
+    mlir::Value eqOrGt = mlir::LLVM::SelectOp::create(rewriter, loc, isEqual,
+                                                      eqValue, gtValue);
+    return mlir::LLVM::SelectOp::create(rewriter, loc, isLess, ltValue,
+                                        eqOrGt);
+  };
+
+  if (mlir::isa<mlir::IntegerType, mlir::LLVM::LLVMPointerType>(operandTy)) {
+    // Only signed CIR integers use a signed predicate. Everything else that
+    // reaches here orders as unsigned: unsigned integers, pointers, and
+    // !cir.bool (if it lowers to i1, a signed compare would read "true" as -1
+    // and order true < false).
+    bool isSigned = false;
+    if (auto cirIntTy = mlir::dyn_cast<cir::IntType>(op.getLhs().getType()))
+      isSigned = cirIntTy.isSigned();
+    mlir::LLVM::ICmpPredicate ltPred = isSigned
+                                           ? mlir::LLVM::ICmpPredicate::slt
+                                           : mlir::LLVM::ICmpPredicate::ult;
+
+    mlir::Value ltCmp =
+        mlir::LLVM::ICmpOp::create(rewriter, loc, ltPred, lhs, rhs);
+    mlir::Value eqCmp = mlir::LLVM::ICmpOp::create(
+        rewriter, loc, mlir::LLVM::ICmpPredicate::eq, lhs, rhs);
+    rewriter.replaceOp(op, selectOrdered(ltCmp, eqCmp));
+    return mlir::success();
+  }
+
+  if (mlir::isa<mlir::FloatType>(operandTy)) {
+    if (!unorderedValue)
+      return op.emitError("strong ordering not supported for float operands");
+
+    mlir::Value ltCmp = mlir::LLVM::FCmpOp::create(
+        rewriter, loc, mlir::LLVM::FCmpPredicate::olt, lhs, rhs);
+    mlir::Value eqCmp = mlir::LLVM::FCmpOp::create(
+        rewriter, loc, mlir::LLVM::FCmpPredicate::oeq, lhs, rhs);
+    mlir::Value orderedResult = selectOrdered(ltCmp, eqCmp);
+
+    // fcmp uno is true when either operand is NaN; that overrides the
+    // ordered result with the "unordered" value.
+    mlir::Value unorderedCmp = mlir::LLVM::FCmpOp::create(
+        rewriter, loc, mlir::LLVM::FCmpPredicate::uno, lhs, rhs);
+    mlir::Value result = mlir::LLVM::SelectOp::create(
+        rewriter, loc, unorderedCmp, unorderedValue, orderedResult);
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+
+  return op.emitError("unsupported operand type for three-way comparison");
+}
+
mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite(
cir::BinOpOverflowOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp b/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp
new file mode 100644
index 0000000000000..80c408597f0d2
--- /dev/null
+++ b/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s
+
+#include "Inputs/std-compare.h"
+
+// The cmp3way result type is matched as any signed integer width, so these
+// checks do not depend on which signed integer type codegen picks for it.
+
+int test_int_spaceship(int a, int b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !s32i, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s{{[0-9]+}}i
+  return (result < 0) ? -1 : (result > 0) ? 1 : 0;
+}
+
+unsigned int test_uint_spaceship(unsigned int a, unsigned int b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !u32i, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s{{[0-9]+}}i
+  return (result < 0) ? 0 : (result > 0) ? 2 : 1;
+}
+
+float test_float_spaceship(float a, float b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !cir.float, %{{.*}}, #cir.cmp3way_partial_info<partial, lt = -1, eq = 0, gt = 1, unordered = -127>)
+  return (result < 0) ? -1.0f : (result > 0) ? 1.0f : 0.0f;
+}
\ No newline at end of file
diff --git a/clang/test/CIR/Lowering/cmp3way.cir b/clang/test/CIR/Lowering/cmp3way.cir
new file mode 100644
index 0000000000000..fe3740b6fd30d
--- /dev/null
+++ b/clang/test/CIR/Lowering/cmp3way.cir
@@ -0,0 +1,32 @@
+// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
+// RUN: FileCheck --input-file=%t.mlir %s
+
+// Lowering of cir.cmp3way to the LLVM dialect. Each function compares two
+// constants of one operand kind and looks for the compare predicate(s) the
+// lowering is expected to emit for that kind.
+
+!s32i = !cir.int<s, 32>
+!u32i = !cir.int<u, 32>
+module {
+ // Signed integer operands: expects a signed "slt" compare.
+ cir.func @test_signed() -> !s32i {
+ %0 = cir.const #cir.int<5> : !s32i
+ %1 = cir.const #cir.int<3> : !s32i
+ %2 = cir.cmp3way(%0 : !s32i, %1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i
+ // CHECK: llvm.icmp "slt"
+ cir.return %2 : !s32i
+ }
+
+ // Unsigned integer operands: expects an unsigned "ult" compare.
+ cir.func @test_unsigned() -> !s32i {
+ %0 = cir.const #cir.int<5> : !u32i
+ %1 = cir.const #cir.int<3> : !u32i
+ %2 = cir.cmp3way(%0 : !u32i, %1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i
+ // CHECK: llvm.icmp "ult"
+ cir.return %2 : !s32i
+ }
+
+ // Float operands with a partial ordering: expects ordered "olt" and "oeq"
+ // compares, plus a "uno" compare that selects the unordered (NaN) value.
+ cir.func @test_float() -> !s32i {
+ %0 = cir.const #cir.fp<1.5> : !cir.float
+ %1 = cir.const #cir.fp<2.5> : !cir.float
+ %2 = cir.cmp3way(%0 : !cir.float, %1, #cir.cmp3way_partial_info<partial, lt = -1, eq = 0, gt = 1, unordered = 2>) : !s32i
+ // CHECK: llvm.fcmp "olt"
+ // CHECK: llvm.fcmp "oeq"
+ // CHECK: llvm.fcmp "uno"
+ cir.return %2 : !s32i
+ }
+}
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/186294
More information about the cfe-commits
mailing list