[clang] [CIR] Upstream TryCallOp (PR #165303)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Oct 27 12:29:53 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clangir
Author: Amr Hesham (AmrDeveloper)
<details>
<summary>Changes</summary>
Upstream the TryCall op as a prerequisite for the try/catch work.
Issue https://github.com/llvm/llvm-project/issues/154992
---
Full diff: https://github.com/llvm/llvm-project/pull/165303.diff
5 Files Affected:
- (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+1)
- (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+93-1)
- (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+191-6)
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+19-5)
- (added) clang/test/CIR/IR/try-call.cir (+31)
``````````diff
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e91537186df59..34df9af7fc06d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+    // Name of the inherent attribute holding per-segment operand counts for
+    // ops with AttrSizedOperandSegments (elided when printing calls).
+    static llvm::StringRef getOperandSegmentSizesAttrName() {
+      return "operandSegmentSizes";
+    }
void registerAttributes();
void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2b361ed0982c6..8f3e25b3c9737 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
}
//===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
//===----------------------------------------------------------------------===//
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
];
}
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  DeclareOpInterfaceMethods<BranchOpInterface>,
+  Terminator, AttrSizedOperandSegments
+]> {
+  let summary = "try_call operation";
+
+  let description = [{
+    Mostly similar to cir.call but requires two destination
+    branches: one (the landing pad) taken when the callee throws an
+    exception, and one to follow on regular control flow.
+
+    The operand segments are ordered as: continue-destination operands,
+    landing-pad operands, then the call operands (for an indirect call the
+    first call operand is the callee pointer).
+
+    Example:
+
+    ```mlir
+    // Direct call
+    %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad
+      : (!s32i, !s32i) -> !s32i
+    ```
+  }];
+
+  let arguments = !con((ins
+    Variadic<CIR_AnyType>:$contOperands,
+    Variadic<CIR_AnyType>:$landingPadOperands
+  ), commonArgs);
+
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
+      "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+      CArg<"mlir::ValueRange", "{}">:$operands,
+      CArg<"mlir::ValueRange", "{}">:$contOperands,
+      CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+      CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Operands must be appended in ODS declaration order so that the
+      // segment sizes below describe them correctly:
+      // cont operands, landing_pad operands, then the call operands.
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      $_state.addOperands(operands);
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(operands.size())
+        }),
+        $_state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>,
+    OpBuilder<(ins "mlir::Value":$ind_target,
+      "FuncType":$fn_type,
+      "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+      CArg<"mlir::ValueRange", "{}">:$operands,
+      CArg<"mlir::ValueRange", "{}">:$contOperands,
+      CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+      CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      // The callee pointer is passed as the first call operand.
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(operands.begin(), operands.end());
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Append operands in ODS declaration order (cont operands, landing_pad
+      // operands, call operands) so operandSegmentSizes matches them.
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      $_state.addOperands(finalCallOperands);
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(finalCallOperands.size())
+        }),
+        $_state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 2d2ef422bfaef..11074af3ef127 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
return this->getOperation()->getNumOperands();
}
+/// Parse the two destinations of a try_call:
+///   ^cont [`(` operands `:` types `)`] `,` ^landing_pad [`(` operands `:` types `)`]
+/// optionally followed by an attribute dictionary. The parsed successors are
+/// added to `result`; their operands/types are returned through the out-params.
+static mlir::ParseResult
+parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &continueOperands,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &landingPadOperands,
+                     llvm::SmallVectorImpl<mlir::Type> &continueTypes,
+                     llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
+                     llvm::SMLoc &continueOperandsLoc,
+                     llvm::SMLoc &landingPadOperandsLoc) {
+  // Parses the optional `(operands : types)` suffix that may follow a
+  // successor; leaves everything untouched when no `(` is present.
+  auto parseSuccessorArgs =
+      [&](llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+          llvm::SmallVectorImpl<mlir::Type> &types,
+          llvm::SMLoc &loc) -> mlir::ParseResult {
+    if (mlir::failed(parser.parseOptionalLParen()))
+      return mlir::success();
+    loc = parser.getCurrentLocation();
+    if (parser.parseOperandList(operands) || parser.parseColon() ||
+        parser.parseTypeList(types) || parser.parseRParen())
+      return mlir::failure();
+    return mlir::success();
+  };
+
+  mlir::Block *normalDest = nullptr;
+  mlir::Block *unwindDest = nullptr;
+  if (parser.parseSuccessor(normalDest) ||
+      parseSuccessorArgs(continueOperands, continueTypes,
+                         continueOperandsLoc) ||
+      parser.parseComma() || parser.parseSuccessor(unwindDest) ||
+      parseSuccessorArgs(landingPadOperands, landingPadTypes,
+                         landingPadOperandsLoc) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return mlir::failure();
+
+  result.addSuccessors(normalDest);
+  result.addSuccessors(unwindDest);
+  return mlir::success();
+}
+
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
+ mlir::OperationState &result,
+ bool hasDestinationBlocks = false) {
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
mlir::FlatSymbolRefAttr calleeAttr;
llvm::ArrayRef<mlir::Type> allResultTypes;
+ // TryCall control flow related
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
+ llvm::SMLoc continueOperandsLoc;
+ llvm::SmallVector<mlir::Type, 1> continueTypes;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
+ llvm::SMLoc landingPadOperandsLoc;
+ llvm::SmallVector<mlir::Type, 1> landingPadTypes;
+
// If we cannot parse a string callee, it means this is an indirect call.
if (!parser
.parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.parseRParen())
return mlir::failure();
+ if (hasDestinationBlocks &&
+ parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
+ continueTypes, landingPadTypes, continueOperandsLoc,
+ landingPadOperandsLoc)
+ .failed()) {
+ return ::mlir::failure();
+ }
+
if (parser.parseOptionalKeyword("nothrow").succeeded())
result.addAttribute(CIRDialect::getNoThrowAttrName(),
mlir::UnitAttr::get(parser.getContext()));
@@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
return mlir::failure();
+ if (hasDestinationBlocks) {
+ // The TryCall ODS layout is: cont, landing_pad, operands.
+ llvm::copy(::llvm::ArrayRef<int32_t>(
+ {static_cast<int32_t>(continueOperands.size()),
+ static_cast<int32_t>(landingPadOperands.size()),
+ static_cast<int32_t>(ops.size())}),
+ result.getOrAddProperties<cir::TryCallOp::Properties>()
+ .operandSegmentSizes.begin());
+
+ if (parser.resolveOperands(continueOperands, continueTypes,
+ continueOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(landingPadOperands, landingPadTypes,
+ landingPadOperandsLoc, result.operands))
+ return ::mlir::failure();
+ }
+
return mlir::success();
}
@@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op,
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer, bool isNothrow,
- cir::SideEffect sideEffect) {
+ cir::SideEffect sideEffect,
+ mlir::Block *cont = nullptr,
+ mlir::Block *landingPad = nullptr) {
printer << ' ';
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
assert(indirectCallee);
printer << indirectCallee;
}
+
printer << "(" << ops << ")";
+ if (cont) {
+ assert(landingPad && "expected two successors");
+ auto tryCall = dyn_cast<cir::TryCallOp>(op);
+ assert(tryCall && "regular calls do not branch");
+ printer << ' ' << tryCall.getCont();
+ if (!tryCall.getContOperands().empty()) {
+ printer << "(";
+ printer << tryCall.getContOperands();
+ printer << ' ' << ":";
+ printer << ' ';
+ printer << tryCall.getContOperands().getTypes();
+ printer << ")";
+ }
+ printer << ",";
+ printer << ' ';
+ printer << tryCall.getLandingPad();
+ if (!tryCall.getLandingPadOperands().empty()) {
+ printer << "(";
+ printer << tryCall.getLandingPadOperands();
+ printer << ' ' << ":";
+ printer << ' ';
+ printer << tryCall.getLandingPadOperands().getTypes();
+ printer << ")";
+ }
+ }
+
if (isNothrow)
printer << " nothrow";
@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
printer << ")";
}
- printer.printOptionalAttrDict(op->getAttrs(),
- {CIRDialect::getCalleeAttrName(),
- CIRDialect::getNoThrowAttrName(),
- CIRDialect::getSideEffectAttrName()});
+ llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
+ CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+ CIRDialect::getSideEffectAttrName(),
+ CIRDialect::getOperandSegmentSizesAttrName()};
+ printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
printer << " : ";
printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}
+//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+/// Return the call arguments, excluding the callee pointer of an indirect
+/// call (which is stored as the first element of the `args` segment).
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  mlir::OperandRange callArgs = getArgs();
+  return isIndirect() ? callArgs.drop_front(1) : callArgs;
+}
+
+/// Mutable view of the call arguments, skipping the callee pointer of an
+/// indirect call.
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange argRange = getArgsMutable();
+  if (!isIndirect())
+    return argRange;
+  // Drop the leading indirect callee operand.
+  return argRange.slice(1, argRange.size() - 1);
+}
+
+/// Return the callee pointer of an indirect try_call. It is the first element
+/// of the `args` segment, which is not necessarily operand 0: the successor
+/// operand segments (cont, landing_pad) precede the call operands.
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getArgs().front();
+}
+
+/// Return the call argument at index 'i' (the callee pointer of an indirect
+/// call is not counted). Indexes into the call-argument segment rather than
+/// the raw operand list, which also contains the successor operands.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  return getArgOperands()[i];
+}
+
+/// Return the number of call arguments, excluding the callee pointer of an
+/// indirect call and any successor (cont/landing_pad) operands.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  return static_cast<unsigned>(getArgOperands().size());
+}
+
+/// Symbol-use verification; shares the same helper as cir.call.
+mlir::LogicalResult
+cir::TryCallOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
+  return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+/// Parse a cir.try_call: the common call syntax, with the normal and
+/// landing-pad destinations parsed right after the argument list.
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+                                        mlir::OperationState &result) {
+  constexpr bool hasDestinationBlocks = true;
+  return parseCallCommon(parser, result, hasDestinationBlocks);
+}
+
+/// Print a cir.try_call using the common call printer plus its two
+/// destinations.
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  // Indirect calls carry the callee as a value instead of a symbol.
+  mlir::Value indirectCallee;
+  if (isIndirect())
+    indirectCallee = getIndirectCall();
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  getSideEffect(), getCont(), getLandingPad());
+}
+
+/// BranchOpInterface: operands forwarded to successor `index`.
+/// Successor 0 is the normal continuation and successor 1 the landing pad;
+/// the op declares no other successors, so the call arguments are never
+/// successor operands.
+mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(getContOperandsMutable());
+  return SuccessorOperands(getLandingPadOperandsMutable());
+}
+
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a6193fa8d840..12f3db01c77d8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1385,7 +1385,9 @@ static mlir::LogicalResult
rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
- mlir::FlatSymbolRefAttr calleeAttr) {
+ mlir::FlatSymbolRefAttr calleeAttr,
+ mlir::Block *continueBlock = nullptr,
+ mlir::Block *landingPadBlock = nullptr) {
llvm::SmallVector<mlir::Type, 8> llvmResults;
mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
auto call = cast<cir::CIRCallOpInterface>(op);
@@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
llvmFnTy = converter->convertType<mlir::LLVM::LLVMFunctionType>(
fn.getFunctionType());
assert(llvmFnTy && "Failed to convert function type");
- } else if (auto alias = mlir::cast<mlir::LLVM::AliasOp>(callee)) {
+ } else if (auto alias = mlir::dyn_cast<mlir::LLVM::AliasOp>(callee)) {
// If the callee was an alias. In that case,
// we need to prepend the address of the alias to the operands. The
// way aliases work in the LLVM dialect is a little counter-intuitive.
@@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
converter->convertType(calleeFuncTy));
}
- assert(!cir::MissingFeatures::opCallLandingPad());
- assert(!cir::MissingFeatures::opCallContinueBlock());
assert(!cir::MissingFeatures::opCallCallConv());
+ if (landingPadBlock) {
+ rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
+ op, llvmFnTy, calleeAttr, callOperands, continueBlock,
+ mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
+ return mlir::success();
+ }
+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op, llvmFnTy, calleeAttr, callOperands);
if (memoryEffects)
newOp.setMemoryEffectsAttr(memoryEffects);
newOp.setNoUnwind(noUnwind);
newOp.setWillReturn(willReturn);
-
return mlir::success();
}
@@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
getTypeConverter(), op.getCalleeAttr());
}
+/// Lower cir.try_call to llvm.invoke via the shared call/invoke rewriter.
+mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite(
+    cir::TryCallOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // rewriteCallOrInvoke forwards empty operand lists to both invoke
+  // destinations, so refuse to lower a try_call whose destinations take
+  // operands instead of silently dropping them.
+  if (!op.getContOperands().empty() || !op.getLandingPadOperands().empty())
+    return rewriter.notifyMatchFailure(
+        op, "try_call with successor operands is not supported yet");
+
+  // Only the `args` segment (callee pointer for indirect calls, then the
+  // arguments) forms the call operands; the successor operands live in
+  // separate segments and must not be passed to the callee.
+  return rewriteCallOrInvoke(op.getOperation(), adaptor.getArgs(), rewriter,
+                             getTypeConverter(), op.getCalleeAttr(),
+                             op.getCont(), op.getLandingPad());
+}
+
mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite(
cir::ReturnAddrOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000000000..6c23d3add15c8
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,31 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+  %a = cir.const #cir.int<1> : !s32i
+  %b = cir.const #cir.int<2> : !s32i
+  %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i
+  ^continue:
+    cir.br ^landing_pad
+  ^landing_pad:
+    cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT:   %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT:   %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT:   %[[CALL:.*]] = cir.try_call @division(%[[CONST_0]], %[[CONST_1]]) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[CONTINUE]]:
+// CHECK-NEXT:   cir.br ^[[LANDING_PAD]]
+// CHECK-NEXT: ^[[LANDING_PAD]]:
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/165303
More information about the cfe-commits
mailing list