[clang] [CIR] Upstream TernaryOp (PR #137184)
Morris Hafner via cfe-commits
cfe-commits at lists.llvm.org
Thu Apr 24 08:08:33 PDT 2025
https://github.com/mmha updated https://github.com/llvm/llvm-project/pull/137184
>From 1eed90e3859c2ad8d703708f89976cad8f0faeec Mon Sep 17 00:00:00 2001
From: Morris Hafner <mhafner at nvidia.com>
Date: Thu, 24 Apr 2025 16:12:37 +0200
Subject: [PATCH 1/2] [CIR] Upstream TernaryOp
This patch adds TernaryOp to CIR plus a pass that flattens the operator in FlattenCFG.
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 57 +++++++++++++++-
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 42 ++++++++++++
.../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 60 ++++++++++++++--
clang/test/CIR/IR/ternary.cir | 30 ++++++++
clang/test/CIR/Lowering/ternary.cir | 30 ++++++++
clang/test/CIR/Transforms/ternary.cir | 68 +++++++++++++++++++
6 files changed, 280 insertions(+), 7 deletions(-)
create mode 100644 clang/test/CIR/IR/ternary.cir
create mode 100644 clang/test/CIR/Lowering/ternary.cir
create mode 100644 clang/test/CIR/Transforms/ternary.cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 81b447f31feca..76ad5c3666c1b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [
//===----------------------------------------------------------------------===//
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
- ParentOneOf<["IfOp", "ScopeOp", "WhileOp",
- "ForOp", "DoWhileOp"]>]> {
+ ParentOneOf<["IfOp", "TernaryOp", "ScopeOp",
+ "WhileOp", "ForOp", "DoWhileOp"]>]> {
let summary = "Represents the default branching behaviour of a region";
let description = [{
The `cir.yield` operation terminates regions on different CIR operations,
@@ -1246,6 +1246,59 @@ def SelectOp : CIR_Op<"select", [Pure,
}];
}
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+def TernaryOp : CIR_Op<"ternary",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
+  let summary = "The `cond ? a : b` C/C++ ternary operation";
+  let description = [{
+    The `cir.ternary` operation represents C/C++ ternary, much like a `select`
+    operation. The first argument is a `cir.bool` condition to evaluate, followed
+    by two regions to execute (true or false). This is different from `cir.if`
+    since each region is one block sized and the `cir.yield` closing the block
+    scope should have at most one argument (the result is optional).
+
+    Example:
+
+    ```mlir
+    // x = cond ? a : b;
+
+    %x = cir.ternary(%cond, true {
+      ...
+      cir.yield %a : i32
+    }, false {
+      ...
+      cir.yield %b : i32
+    }) : (!cir.bool) -> i32
+    ```
+  }];
+  let arguments = (ins CIR_BoolType:$cond);
+  let regions = (region AnyRegion:$trueRegion,
+                        AnyRegion:$falseRegion);
+  let results = (outs Optional<CIR_AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$cond,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
+      >
+  ];
+
+  // All constraints already verified elsewhere.
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    `(` $cond `,`
+      `true` $trueRegion `,`
+      `false` $falseRegion
+    `)` `:` functional-type(operands, results) attr-dict
+  }];
+}
+
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 89daf20c5f478..e80d243cb396f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1058,6 +1058,48 @@ LogicalResult cir::BinOp::verify() {
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+/// Given the branch point `point` (one of this op's regions, or the parent
+/// operation itself), populate `regions` with the possible successors, i.e.
+/// the regions that may be selected during the flow of control. Either region
+/// branches back to the parent op's results; from the parent, both the true
+/// and the false region are reported as possible successors.
+void cir::TernaryOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `true` and the `false` region branch back to the parent operation.
+  if (!point.isParent()) {
+    regions.push_back(RegionSuccessor(this->getODSResults(0)));
+    return;
+  }
+
+  // The condition is not inspected here, so either region may be entered.
+  regions.push_back(RegionSuccessor(&getTrueRegion()));
+  regions.push_back(RegionSuccessor(&getFalseRegion()));
+}
+
+/// Builds a `cir.ternary` on `cond`, filling each region via its callback.
+void cir::TernaryOp::build(
+    OpBuilder &builder, OperationState &result, Value cond,
+    function_ref<void(OpBuilder &, Location)> trueBuilder,
+    function_ref<void(OpBuilder &, Location)> falseBuilder) {
+  result.addOperands(cond);
+  OpBuilder::InsertionGuard guard(builder);
+  Region *trueRegion = result.addRegion();
+  Block *trueBlock = builder.createBlock(trueRegion);
+  trueBuilder(builder, result.location);
+  Region *falseRegion = result.addRegion();
+  builder.createBlock(falseRegion);
+  falseBuilder(builder, result.location);
+  // Infer the optional result type from the yield ending the true region.
+  auto yield = cast<YieldOp>(trueBlock->getTerminator());
+  assert(yield.getNumOperands() <= 1 && "expected zero or one result type");
+  if (yield.getNumOperands() == 1)
+    result.addTypes(TypeRange{yield.getOperandTypes().front()});
+}
+
//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 72ccfa8d4e14e..295fa748b1624 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
}
};
+/// Flattens `cir.ternary` into explicit control flow: a conditional branch into
+/// the inlined true/false blocks, which both branch to a common merge block.
+class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
+public:
+  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::TernaryOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Block *condBlock = rewriter.getInsertionBlock();
+    Block::iterator opPosition = rewriter.getInsertionPoint();
+    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+    llvm::SmallVector<mlir::Location, 2> locs;
+    // Ternary result is optional, make sure to populate the location only
+    // when relevant.
+    if (!op->getResultTypes().empty())
+      locs.push_back(loc);
+    auto *continueBlock =
+        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
+    rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
+
+    Region &trueRegion = op.getTrueRegion();
+    Block *trueBlock = &trueRegion.front();
+    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&trueRegion.back());
+    // Both regions are assumed to end in `cir.yield`; cast<> asserts on that.
+    auto trueYieldOp = cast<cir::YieldOp>(trueTerminator);
+
+    rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(trueRegion, continueBlock);
+
+    Region &falseRegion = op.getFalseRegion();
+    Block *falseBlock = &falseRegion.front();
+    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&falseRegion.back());
+    cir::YieldOp falseYieldOp = cast<cir::YieldOp>(falseTerminator);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(falseRegion, continueBlock);
+
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
+
+    rewriter.replaceOp(op, continueBlock->getArguments());
+
+    return mlir::success();
+  }
+};
+
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
- patterns
- .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
- patterns.getContext());
+ patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
+ CIRScopeOpFlattening, CIRTernaryOpFlattening>(
+ patterns.getContext());
}
void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
assert(!cir::MissingFeatures::ifOp());
assert(!cir::MissingFeatures::switchOp());
- assert(!cir::MissingFeatures::ternaryOp());
assert(!cir::MissingFeatures::tryOp());
- if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
+ if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir
new file mode 100644
index 0000000000000..3827dc77726df
--- /dev/null
+++ b/clang/test/CIR/IR/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-opt %s | cir-opt | FileCheck %s
+!u32i = !cir.int<u, 32>
+
+module {
+ cir.func @blue(%arg0: !cir.bool) -> !u32i {
+ %0 = cir.ternary(%arg0, true {
+ %a = cir.const #cir.int<0> : !u32i
+ cir.yield %a : !u32i
+ }, false {
+ %b = cir.const #cir.int<1> : !u32i
+ cir.yield %b : !u32i
+ }) : (!cir.bool) -> !u32i
+ cir.return %0 : !u32i
+ }
+}
+
+// CHECK: module {
+
+// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
+// CHECK: %0 = cir.ternary(%arg0, true {
+// CHECK: %1 = cir.const #cir.int<0> : !u32i
+// CHECK: cir.yield %1 : !u32i
+// CHECK: }, false {
+// CHECK: %1 = cir.const #cir.int<1> : !u32i
+// CHECK: cir.yield %1 : !u32i
+// CHECK: }) : (!cir.bool) -> !u32i
+// CHECK: cir.return %0 : !u32i
+// CHECK: }
+
+// CHECK: }
diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir
new file mode 100644
index 0000000000000..247c6ae3a1e17
--- /dev/null
+++ b/clang/test/CIR/Lowering/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
+
+!u32i = !cir.int<u, 32>
+
+module {
+ cir.func @blue(%arg0: !cir.bool) -> !u32i {
+ %0 = cir.ternary(%arg0, true {
+ %a = cir.const #cir.int<0> : !u32i
+ cir.yield %a : !u32i
+ }, false {
+ %b = cir.const #cir.int<1> : !u32i
+ cir.yield %b : !u32i
+ }) : (!cir.bool) -> !u32i
+ cir.return %0 : !u32i
+ }
+}
+
+// LLVM-LABEL: define i32 {{.*}}@blue(
+// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
+// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
+// LLVM: [[B1]]:
+// LLVM: br label %[[M:[[:alnum:]]+]]
+// LLVM: [[B2]]:
+// LLVM: br label %[[M]]
+// LLVM: [[M]]:
+// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
+// LLVM: br label %[[B3:[[:alnum:]]+]]
+// LLVM: [[B3]]:
+// LLVM: ret i32 [[R]]
diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir
new file mode 100644
index 0000000000000..67ef7f95a6b52
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @foo(%arg0: !s32i) -> !s32i {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ %3 = cir.const #cir.int<0> : !s32i
+ %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+ %5 = cir.ternary(%4, true {
+ %7 = cir.const #cir.int<3> : !s32i
+ cir.yield %7 : !s32i
+ }, false {
+ %7 = cir.const #cir.int<5> : !s32i
+ cir.yield %7 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+ %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+ cir.return %6 : !s32i
+ }
+
+// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
+// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+// CHECK: %3 = cir.const #cir.int<0> : !s32i
+// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+// CHECK: cir.brcond %4 ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %5 = cir.const #cir.int<3> : !s32i
+// CHECK: cir.br ^bb3(%5 : !s32i)
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK: %6 = cir.const #cir.int<5> : !s32i
+// CHECK: cir.br ^bb3(%6 : !s32i)
+// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2
+// CHECK: cir.br ^bb4
+// CHECK: ^bb4: // pred: ^bb3
+// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+// CHECK: cir.return %8 : !s32i
+// CHECK: }
+
+ cir.func @foo2(%arg0: !cir.bool) {
+ cir.ternary(%arg0, true {
+ cir.yield
+ }, false {
+ cir.yield
+ }) : (!cir.bool) -> ()
+ cir.return
+ }
+
+// CHECK: cir.func @foo2(%arg0: !cir.bool) {
+// CHECK: cir.brcond %arg0 ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: cir.br ^bb3
+// CHECK: ^bb2: // pred: ^bb0
+// CHECK: cir.br ^bb3
+// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
+// CHECK: cir.br ^bb4
+// CHECK: ^bb4: // pred: ^bb3
+// CHECK: cir.return
+// CHECK: }
+
+}
>From 3e9d9d35b52c0b69ac9950f53cad044a958a81d4 Mon Sep 17 00:00:00 2001
From: Morris Hafner <mhafner at nvidia.com>
Date: Thu, 24 Apr 2025 17:08:17 +0200
Subject: [PATCH 2/2] Reorder YieldOp parents lexicographically
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 76ad5c3666c1b..760149636b23b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [
//===----------------------------------------------------------------------===//
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
- ParentOneOf<["IfOp", "TernaryOp", "ScopeOp",
- "WhileOp", "ForOp", "DoWhileOp"]>]> {
+ ParentOneOf<["DoWhileOp", "ForOp", "WhileOp",
+ "IfOp", "ScopeOp", "TernaryOp"]>]> {
let summary = "Represents the default branching behaviour of a region";
let description = [{
The `cir.yield` operation terminates regions on different CIR operations,
More information about the cfe-commits
mailing list