[clang] [CIR] Upstream TryCallOp (PR #165303)

via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 27 12:29:53 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

<details>
<summary>Changes</summary>

Upstream the TryCallOp as a prerequisite for the try-catch work.

Issue https://github.com/llvm/llvm-project/issues/154992

---
Full diff: https://github.com/llvm/llvm-project/pull/165303.diff


5 Files Affected:

- (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+1) 
- (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+93-1) 
- (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+191-6) 
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+19-5) 
- (added) clang/test/CIR/IR/try-call.cir (+31) 


``````````diff
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e91537186df59..34df9af7fc06d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
     static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
     static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
     static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+    static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
 
     void registerAttributes();
     void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2b361ed0982c6..8f3e25b3c9737 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
 }
 
 //===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
 //===----------------------------------------------------------------------===//
 
 def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
   ];
 }
 
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  DeclareOpInterfaceMethods<BranchOpInterface>,
+  Terminator, AttrSizedOperandSegments
+]> {
+  let summary = "try_call operation";
+
+  let description = [{
+    Mostly similar to cir.call but requires two destination
+    branches: one that handles the exception if one is thrown, and
+    another that continues the regular control flow.
+
+    Example:
+
+    ```mlir
+    // Direct call
+    %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad 
+      : (f32, f32) -> f32
+    ```
+  }];
+
+  let arguments = !con((ins
+    Variadic<CIR_AnyType>:$contOperands,
+    Variadic<CIR_AnyType>:$landingPadOperands
+  ), commonArgs);
+
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      $_state.addOperands(operands);
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(operands.size())
+        }),
+        odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>,
+    OpBuilder<(ins "mlir::Value":$ind_target,
+               "FuncType":$fn_type,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands, 
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(operands.begin(), operands.end());
+      $_state.addOperands(finalCallOperands);
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(finalCallOperands.size())
+        }),
+        odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 2d2ef422bfaef..11074af3ef127 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
   return this->getOperation()->getNumOperands();
 }
 
+static mlir::ParseResult
+parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &continueOperands,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &landingPadOperands,
+                     llvm::SmallVectorImpl<mlir::Type> &continueTypes,
+                     llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
+                     llvm::SMLoc &continueOperandsLoc,
+                     llvm::SMLoc &landingPadOperandsLoc) {
+  mlir::Block *continueSuccessor = nullptr;
+  mlir::Block *landingPadSuccessor = nullptr;
+
+  if (parser.parseSuccessor(continueSuccessor))
+    return mlir::failure();
+
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    continueOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(continueOperands))
+      return mlir::failure();
+    if (parser.parseColon())
+      return mlir::failure();
+
+    if (parser.parseTypeList(continueTypes))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
+  if (parser.parseComma())
+    return mlir::failure();
+
+  if (parser.parseSuccessor(landingPadSuccessor))
+    return mlir::failure();
+
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    landingPadOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(landingPadOperands))
+      return mlir::failure();
+    if (parser.parseColon())
+      return mlir::failure();
+
+    if (parser.parseTypeList(landingPadTypes))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return mlir::failure();
+
+  result.addSuccessors(continueSuccessor);
+  result.addSuccessors(landingPadSuccessor);
+  return mlir::success();
+}
+
 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
-                                         mlir::OperationState &result) {
+                                         mlir::OperationState &result,
+                                         bool hasDestinationBlocks = false) {
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
   llvm::SMLoc opsLoc;
   mlir::FlatSymbolRefAttr calleeAttr;
   llvm::ArrayRef<mlir::Type> allResultTypes;
 
+  // TryCall control flow related
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
+  llvm::SMLoc continueOperandsLoc;
+  llvm::SmallVector<mlir::Type, 1> continueTypes;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
+  llvm::SMLoc landingPadOperandsLoc;
+  llvm::SmallVector<mlir::Type, 1> landingPadTypes;
+
   // If we cannot parse a string callee, it means this is an indirect call.
   if (!parser
            .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.parseRParen())
     return mlir::failure();
 
+  if (hasDestinationBlocks &&
+      parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
+                           continueTypes, landingPadTypes, continueOperandsLoc,
+                           landingPadOperandsLoc)
+          .failed()) {
+    return ::mlir::failure();
+  }
+
   if (parser.parseOptionalKeyword("nothrow").succeeded())
     result.addAttribute(CIRDialect::getNoThrowAttrName(),
                         mlir::UnitAttr::get(parser.getContext()));
@@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
     return mlir::failure();
 
+  if (hasDestinationBlocks) {
+    // The TryCall ODS layout is: cont, landing_pad, operands.
+    llvm::copy(::llvm::ArrayRef<int32_t>(
+                   {static_cast<int32_t>(continueOperands.size()),
+                    static_cast<int32_t>(landingPadOperands.size()),
+                    static_cast<int32_t>(ops.size())}),
+               result.getOrAddProperties<cir::TryCallOp::Properties>()
+                   .operandSegmentSizes.begin());
+
+    if (parser.resolveOperands(continueOperands, continueTypes,
+                               continueOperandsLoc, result.operands))
+      return ::mlir::failure();
+
+    if (parser.resolveOperands(landingPadOperands, landingPadTypes,
+                               landingPadOperandsLoc, result.operands))
+      return ::mlir::failure();
+  }
+
   return mlir::success();
 }
 
@@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op,
                             mlir::FlatSymbolRefAttr calleeSym,
                             mlir::Value indirectCallee,
                             mlir::OpAsmPrinter &printer, bool isNothrow,
-                            cir::SideEffect sideEffect) {
+                            cir::SideEffect sideEffect,
+                            mlir::Block *cont = nullptr,
+                            mlir::Block *landingPad = nullptr) {
   printer << ' ';
 
   auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
     assert(indirectCallee);
     printer << indirectCallee;
   }
+
   printer << "(" << ops << ")";
 
+  if (cont) {
+    assert(landingPad && "expected two successors");
+    auto tryCall = dyn_cast<cir::TryCallOp>(op);
+    assert(tryCall && "regular calls do not branch");
+    printer << ' ' << tryCall.getCont();
+    if (!tryCall.getContOperands().empty()) {
+      printer << "(";
+      printer << tryCall.getContOperands();
+      printer << ' ' << ":";
+      printer << ' ';
+      printer << tryCall.getContOperands().getTypes();
+      printer << ")";
+    }
+    printer << ",";
+    printer << ' ';
+    printer << tryCall.getLandingPad();
+    if (!tryCall.getLandingPadOperands().empty()) {
+      printer << "(";
+      printer << tryCall.getLandingPadOperands();
+      printer << ' ' << ":";
+      printer << ' ';
+      printer << tryCall.getLandingPadOperands().getTypes();
+      printer << ")";
+    }
+  }
+
   if (isNothrow)
     printer << " nothrow";
 
@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
     printer << ")";
   }
 
-  printer.printOptionalAttrDict(op->getAttrs(),
-                                {CIRDialect::getCalleeAttrName(),
-                                 CIRDialect::getNoThrowAttrName(),
-                                 CIRDialect::getSideEffectAttrName()});
+  llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
+      CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+      CIRDialect::getSideEffectAttrName(),
+      CIRDialect::getOperandSegmentSizesAttrName()};
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   printer << " : ";
   printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyCallCommInSymbolUses(*this, symbolTable);
 }
 
+//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  if (isIndirect())
+    return getArgs().drop_front(1);
+  return getArgs();
+}
+
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange args = getArgsMutable();
+  if (isIndirect())
+    return args.slice(1, args.size() - 1);
+  return args;
+}
+
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getOperand(0);
+}
+
+/// Return the operand at index 'i'.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  if (isIndirect())
+    ++i;
+  return getOperand(i);
+}
+
+/// Return the number of operands.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  if (isIndirect())
+    return this->getOperation()->getNumOperands() - 1;
+  return this->getOperation()->getNumOperands();
+}
+
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+                                        mlir::OperationState &result) {
+  return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
+  cir::SideEffect sideEffect = getSideEffect();
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  sideEffect, getCont(), getLandingPad());
+}
+
+mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(getContOperandsMutable());
+  if (index == 1)
+    return SuccessorOperands(getLandingPadOperandsMutable());
+
+  // index == 2
+  return SuccessorOperands(getArgOperandsMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a6193fa8d840..12f3db01c77d8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1385,7 +1385,9 @@ static mlir::LogicalResult
 rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
                     mlir::ConversionPatternRewriter &rewriter,
                     const mlir::TypeConverter *converter,
-                    mlir::FlatSymbolRefAttr calleeAttr) {
+                    mlir::FlatSymbolRefAttr calleeAttr,
+                    mlir::Block *continueBlock = nullptr,
+                    mlir::Block *landingPadBlock = nullptr) {
   llvm::SmallVector<mlir::Type, 8> llvmResults;
   mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
   auto call = cast<cir::CIRCallOpInterface>(op);
@@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
       llvmFnTy = converter->convertType<mlir::LLVM::LLVMFunctionType>(
           fn.getFunctionType());
       assert(llvmFnTy && "Failed to convert function type");
-    } else if (auto alias = mlir::cast<mlir::LLVM::AliasOp>(callee)) {
+    } else if (auto alias = mlir::dyn_cast<mlir::LLVM::AliasOp>(callee)) {
       // If the callee was an alias. In that case,
       // we need to prepend the address of the alias to the operands. The
       // way aliases work in the LLVM dialect is a little counter-intuitive.
@@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
         converter->convertType(calleeFuncTy));
   }
 
-  assert(!cir::MissingFeatures::opCallLandingPad());
-  assert(!cir::MissingFeatures::opCallContinueBlock());
   assert(!cir::MissingFeatures::opCallCallConv());
 
+  if (landingPadBlock) {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
+        op, llvmFnTy, calleeAttr, callOperands, continueBlock,
+        mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
+    return mlir::success();
+  }
+
   auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
       op, llvmFnTy, calleeAttr, callOperands);
   if (memoryEffects)
     newOp.setMemoryEffectsAttr(memoryEffects);
   newOp.setNoUnwind(noUnwind);
   newOp.setWillReturn(willReturn);
-
   return mlir::success();
 }
 
@@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
                              getTypeConverter(), op.getCalleeAttr());
 }
 
+mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite(
+    cir::TryCallOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
+                             getTypeConverter(), op.getCalleeAttr(),
+                             op.getCont(), op.getLandingPad());
+}
+
 mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite(
     cir::ReturnAddrOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000000000..6c23d3add15c8
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,31 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+   %a = cir.const #cir.int<1> : !s32i
+   %b = cir.const #cir.int<2> : !s32i
+   %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i
+ ^continue:
+   cir.br ^landing_pad
+ ^landing_pad:
+   cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT:   %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT:   %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT:   %[[CALL:.*]] = cir.try_call @division(%0, %1) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[CONTINUE]]:
+// CHECK-NEXT:   cir.br ^[[LANDING_PAD]]
+// CHECK-NEXT: ^[[LANDING_PAD]]:
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/165303


More information about the cfe-commits mailing list