[Mlir-commits] [mlir] [MLIR][LLVM] Add vararg support in LLVM::LLVMFuncOp (PR #67274)

Ivan R. Ivanov llvmlistbot at llvm.org
Mon Sep 25 20:08:02 PDT 2023


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/67274

>From ee20d089575575c420d6f475cdd1040b1169774b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 23 Sep 2023 17:46:17 +0900
Subject: [PATCH 1/4] Build LLVM::CallOp with callee type info

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  9 +++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 51 ++++++++++++++++++---
 2 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2b4c8b609cfdd4f..63deb00802446b9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     ```
   }];
 
-  dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
+  dag args = (ins TypeAttrOf<LLVM_FunctionType>:$callee_type,
+                  OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
@@ -630,6 +631,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
                    CArg<"ValueRange", "{}">:$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "FlatSymbolRefAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index fd0d2b3fb3c1a08..2def142b1d2d457 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,6 +950,16 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
 // CallOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+static TypeRange getCallOpResults(LLVMFunctionType calleeType) {
+  SmallVector<Type> results;
+  Type resultType = calleeType.getReturnType();
+  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+    results.push_back(resultType);
+  return results;
+}
+} // namespace
+
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    StringRef callee, ValueRange args) {
   build(builder, state, results, builder.getStringAttr(callee), args);
@@ -962,7 +972,38 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
-  build(builder, state, results, callee, args, /*fastmathFlags=*/nullptr,
+  Type resultType;
+  if (results.empty())
+    resultType = LLVMVoidType::get(builder.getContext());
+  else
+    resultType = results.front();
+  std::vector<Type> argTypes(args.getTypes().begin(), args.getTypes().end());
+  auto calleeType =
+      LLVMFunctionType::get(resultType, argTypes, /*isVariadic*/ false);
+  build(builder, state, results, TypeAttr::get(calleeType), callee, args,
+        /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
+        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
+        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, StringRef callee,
+                   ValueRange args) {
+  build(builder, state, calleeType, builder.getStringAttr(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, StringAttr callee,
+                   ValueRange args) {
+  build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+                   LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
+                   ValueRange args) {
+  build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+        callee, args, /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -970,11 +1011,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
-  SmallVector<Type> results;
-  Type resultType = func.getFunctionType().getReturnType();
-  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
-    results.push_back(resultType);
-  build(builder, state, results, SymbolRefAttr::get(func), args,
+  auto calleeType = func.getFunctionType();
+  build(builder, state, getCallOpResults(calleeType), TypeAttr::get(calleeType),
+        SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,

>From 6f95cba5452169d46679a9adfe6296bc1899605a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sat, 23 Sep 2023 17:54:54 +0900
Subject: [PATCH 2/4] Fix Conversion of CallOp func type

---
 .../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp      | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 9b438090c84cacc..c27bc1b6f29076b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,9 +211,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       call = builder.CreateCall(
           moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
     } else {
-      call = builder.CreateCall(getCalleeFunctionType(callOp.getResultTypes(),
-                                                      callOp.getArgOperands()),
-                                operandsRef.front(), operandsRef.drop_front());
+      call = builder.CreateCall(
+          llvm::cast<llvm::FunctionType>(
+              moduleTranslation.convertType(callOp.getCalleeType())),
+          operandsRef.front(), operandsRef.drop_front());
     }
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);

>From 5c54d75f18134f4867217751c1be605bef7a7f3b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Mon, 25 Sep 2023 10:37:30 +0900
Subject: [PATCH 3/4] Print vararg info

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 36 ++++++++++++++++---
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  2 +-
 3 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 63deb00802446b9..2f65077dddf5aaf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -615,7 +615,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     ```
   }];
 
-  dag args = (ins TypeAttrOf<LLVM_FunctionType>:$callee_type,
+  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2def142b1d2d457..87138d24f7fc09a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1172,6 +1172,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
+  auto calleeType = *getCalleeType();
+  bool isVarArg = calleeType.isVarArg();
 
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
@@ -1181,9 +1183,13 @@ void CallOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
+  if (isVarArg)
+    p << " vararg " << getCalleeType() << ' ';
+
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
-  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
+  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
+                          {"callee", "callee_type"});
 
   p << " : ";
   if (!isDirect)
@@ -1197,7 +1203,7 @@ void CallOp::print(OpAsmPrinter &p) {
 /// succeeds. Returns failure otherwise.
 static ParseResult parseCallTypeAndResolveOperands(
     OpAsmParser &parser, OperationState &result, bool isDirect,
-    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
+    ArrayRef<OpAsmParser::UnresolvedOperand> operands, bool isVarArg) {
   SMLoc trailingTypesLoc = parser.getCurrentLocation();
   SmallVector<Type> types;
   if (parser.parseColonTypeList(types))
@@ -1233,6 +1239,18 @@ static ParseResult parseCallTypeAndResolveOperands(
   if (funcType.getNumResults() != 0)
     result.addTypes(funcType.getResults());
 
+  if (!isVarArg) {
+    Type returnType;
+    if (funcType.getNumResults() == 0)
+      returnType = LLVM::LLVMVoidType::get(result.getContext());
+    else
+      returnType = funcType.getResult(0);
+    result.addAttribute(
+        "callee_type",
+        TypeAttr::get(LLVM::LLVMFunctionType::get(
+            returnType, funcType.getInputs(), /*isVarArg*/ false)));
+  }
+
   return success();
 }
 
@@ -1251,10 +1269,12 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (function-id | ssa-use)`(` ssa-use-list `)`
+// <operation> ::= `llvm.call` (function-id | ssa-use) var-arg-func-type?
+//                             `(` ssa-use-list `)`
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
+  TypeAttr calleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
   // Parse a function pointer for indirect calls.
@@ -1267,13 +1287,18 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
       return failure();
 
+  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+  if (isVarArg)
+    parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
+
   // Parse the function arguments.
   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
       parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the operands.
-  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
+  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                         isVarArg);
 }
 
 ///===---------------------------------------------------------------------===//
@@ -1388,7 +1413,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   // Parse the trailing type list and resolve the function operands.
-  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
+  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
+                                      /*isVarArg*/ false))
     return failure();
 
   result.addSuccessors({normalDest, unwindDest});
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index c27bc1b6f29076b..cdd22f116a4ef9d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -213,7 +213,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     } else {
       call = builder.CreateCall(
           llvm::cast<llvm::FunctionType>(
-              moduleTranslation.convertType(callOp.getCalleeType())),
+              moduleTranslation.convertType(*callOp.getCalleeType())),
           operandsRef.front(), operandsRef.drop_front());
     }
     moduleTranslation.setAccessGroupsMetadata(callOp, call);

>From 66d8e4fdcf9e55934e7ab1858bac8eac44728346 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Tue, 26 Sep 2023 12:03:03 +0900
Subject: [PATCH 4/4] Fixes

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 49 ++++++++++++-------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        | 13 +++--
 2 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 87138d24f7fc09a..39bc824d25b25bc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -950,15 +950,14 @@ static void printStoreType(OpAsmPrinter &printer, Operation *op,
 // CallOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-static TypeRange getCallOpResults(LLVMFunctionType calleeType) {
-  SmallVector<Type> results;
+/// Returns the call op result types of `calleeType` (empty for void).
+static SmallVector<Type, 1> getCallOpResults(LLVMFunctionType calleeType) {
+  SmallVector<Type, 1> results;
   Type resultType = calleeType.getReturnType();
-  if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
+  if (!isa<LLVM::LLVMVoidType>(resultType))
     results.push_back(resultType);
   return results;
 }
-} // namespace
 
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    StringRef callee, ValueRange args) {
@@ -977,9 +976,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
     resultType = LLVMVoidType::get(builder.getContext());
   else
     resultType = results.front();
-  std::vector<Type> argTypes(args.getTypes().begin(), args.getTypes().end());
-  auto calleeType =
-      LLVMFunctionType::get(resultType, argTypes, /*isVariadic*/ false);
+  auto calleeType = LLVMFunctionType::get(
+      resultType, llvm::to_vector(args.getTypes()), /*isVarArg=*/false);
   build(builder, state, results, TypeAttr::get(calleeType), callee, args,
         /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr,
@@ -1172,8 +1170,17 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
-  auto calleeType = *getCalleeType();
-  bool isVarArg = calleeType.isVarArg();
+
+  LLVMFunctionType calleeType;
+  bool isVarArg;
+
+  std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType();
+  if (optionalCalleeType.has_value()) {
+    calleeType = *optionalCalleeType;
+    isVarArg = calleeType.isVarArg();
+  } else {
+    isVarArg = false;
+  }
 
   // Print the direct callee if present as a function attribute, or an indirect
   // callee (first operand) otherwise.
@@ -1183,11 +1190,12 @@ void CallOp::print(OpAsmPrinter &p) {
   else
     p << getOperand(0);
 
-  if (isVarArg)
-    p << " vararg " << getCalleeType() << ' ';
-
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
+
+  if (isVarArg)
+    p << " vararg(" << calleeType << ")";
+
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                           {"callee", "callee_type"});
 
@@ -1269,8 +1277,8 @@ static ParseResult parseOptionalCallFuncPtr(
   return success();
 }
 
-// <operation> ::= `llvm.call` (function-id | ssa-use) var-arg-func-type?
-//                             `(` ssa-use-list `)`
+// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
+//                             (`vararg` `(` function-type `)`)?
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
@@ -1287,15 +1295,23 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
       return failure();
 
-  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
-  if (isVarArg)
-    parser.parseOptionalAttribute(calleeType, "callee_type", result.attributes);
-
   // Parse the function arguments.
-  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
-      parser.parseOptionalAttrDict(result.attributes))
+  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse the optional `vararg(<callee type>)` clause. It must precede the
+  // attribute dictionary, matching the order used by CallOp::print.
+  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
+  if (isVarArg) {
+    if (parser.parseLParen() ||
+        parser.parseAttribute(calleeType, "callee_type", result.attributes) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   // Parse the trailing type list and resolve the operands.
   return parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
                                          isVarArg);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index cdd22f116a4ef9d..61f9088d7e30f73 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -211,10 +211,15 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       call = builder.CreateCall(
           moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
     } else {
-      call = builder.CreateCall(
-          llvm::cast<llvm::FunctionType>(
-              moduleTranslation.convertType(*callOp.getCalleeType())),
-          operandsRef.front(), operandsRef.drop_front());
+      llvm::FunctionType *calleeType;
+      if (callOp.getCalleeType().has_value())
+        calleeType = llvm::cast<llvm::FunctionType>(
+            moduleTranslation.convertType(*callOp.getCalleeType()));
+      else
+        calleeType = getCalleeFunctionType(callOp.getResultTypes(),
+                                           callOp.getArgOperands());
+      call = builder.CreateCall(calleeType, operandsRef.front(),
+                                operandsRef.drop_front());
     }
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);



More information about the Mlir-commits mailing list