[Mlir-commits] [mlir] [MLIR][LLVM] Always print variadic callee type (PR #99293)

Tobias Gysi llvmlistbot at llvm.org
Wed Jul 17 06:06:10 PDT 2024


https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/99293

>From 860bfac8b6e2b40bf7340ff9515037088564117d Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 17 Jul 2024 08:02:11 +0000
Subject: [PATCH] [MLIR][LLVM] Always print variadic callee type

This commit updates the LLVM dialect CallOp and InvokeOp to always print
the variadic callee type (previously the callee type) if present. An
additional verifier check ensures that only variadic calls have a
non-null variadic callee type, and the builders are adapted accordingly
to set the variadic callee type only for variadic calls.

The motivation for this change is to ensure that CallOp and InvokeOp have
no hidden state that is not pretty-printed but is used during the export
to LLVM IR. Previously, a call could look correct in MLIR while its
return type changed after exporting to LLVM IR (since the type was taken
from the hidden callee type attribute). With this change, that is no
longer possible, since the variadic callee type is always printed if
present.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  15 +--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 103 +++++++++++---------
 mlir/test/Dialect/LLVMIR/invalid.mlir       |  31 +++++-
 3 files changed, 94 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f0dec69a5032a..7adbe8a4e5b31 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -560,7 +560,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
                       DeclareOpInterfaceMethods<BranchWeightOpInterface>,
                       Terminator]> {
   let arguments = (ins
-                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+                   OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                    OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     start with a function name (`@`-prefixed) and indirect calls start with an
     SSA value (`%`-prefixed). The direct callee, if present, is stored as a
     function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
-    and is stored as the first value in `callee_operands`. If the callee is a variadic
-    function, then the `callee_type` attribute must carry the function type. The
-    trailing type list contains the optional indirect callee type and the MLIR
-    function type, which differs from the LLVM function type that uses a explicit
-    void type to model functions that do not return a value.
+    and is stored as the first value in `callee_operands`. If and only if the
+    callee is a variadic function, then the `var_callee_type` attribute must
+    carry the variadic LLVM function type. The trailing type list contains the
+    optional indirect callee type and the MLIR function type, which differs from
+    the LLVM function type that uses an explicit void type to model functions
+    that do not return a value.
 
     Examples:
 
@@ -644,7 +645,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     ```
   }];
 
-  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$callee_type,
+  dag args = (ins OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                   OptionalAttr<FlatSymbolRefAttr>:$callee,
                   Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9372caf6e32a7..405e2aa76d610 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
   return results;
 }
 
+/// Gets the variadic callee type for a LLVMFunctionType.
+static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
+  return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+}
+
 /// Constructs a LLVMFunctionType from MLIR `results` and `args`.
 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
                                         ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
-        callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                    ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+        getCallOpVarCalleeType(calleeType), callee, args,
+        /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), /*callee=*/nullptr, args,
+        getCallOpVarCalleeType(calleeType),
+        /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
+        getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1080,6 +1087,12 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
+  // Verify the variadic callee type is a variadic function type.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    if (!varCalleeType->isVarArg())
+      return emitOpError(
+          "expected var_callee_type to be a variadic function type");
+
   // Type for the callee, we'll get it differently depending if it is a direct
   // or indirect call.
   Type fnType;
@@ -1120,7 +1133,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (!funcType)
     return emitOpError("callee does not have a functional type: ") << fnType;
 
-  if (funcType.isVarArg() && !getCalleeType())
+  if (funcType.isVarArg() && !getVarCalleeType())
     return emitOpError() << "missing callee type attribute for vararg call";
 
   // Verify that the operand and result types match the callee.
@@ -1168,14 +1181,6 @@ void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1195,11 +1200,13 @@ void CallOp::print(OpAsmPrinter &p) {
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the variadic callee type if the call is variadic.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    p << " vararg(" << *varCalleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
-                          {getCConvAttrName(), "callee", "callee_type",
+                          {getCConvAttrName(), "callee",
+                           getVarCalleeTypeAttrName(),
                            getTailCallKindAttrName()});
 
   p << " : ";
@@ -1270,11 +1277,11 @@ static ParseResult parseOptionalCallFuncPtr(
 
 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
 //                             `(` ssa-use-list `)`
-//                             ( `vararg(` var-arg-func-type `)` )?
+//                             ( `vararg(` var-callee-type `)` )?
 //                             attribute-dict? `:` (type `,`)? function-type
 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
   SymbolRefAttr funcAttr;
-  TypeAttr calleeType;
+  TypeAttr varCalleeType;
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
 
   // Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1312,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 
   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
   if (isVarArg) {
+    StringAttr varCalleeTypeAttrName =
+        CallOp::getVarCalleeTypeAttrName(result.name);
     if (parser.parseLParen().failed() ||
-        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+        parser
+            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+                            result.attributes)
             .failed() ||
         parser.parseRParen().failed())
       return failure();
@@ -1320,8 +1331,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 LLVMFunctionType CallOp::getCalleeFunctionType() {
-  if (getCalleeType())
-    return *getCalleeType();
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    return *varCalleeType;
   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
 }
 
@@ -1334,8 +1345,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                      Block *unwind, ValueRange unwindOps) {
   auto calleeType = func.getFunctionType();
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, nullptr, normal, unwind);
+        getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
+        normalOps, unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1343,8 +1354,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
+        nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1352,8 +1363,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
   build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+        getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps,
+        nullptr, nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1393,6 +1404,12 @@ LogicalResult InvokeOp::verify() {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
+  // Verify the variadic callee type is a variadic function type.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    if (!varCalleeType->isVarArg())
+      return emitOpError(
+          "expected var_callee_type to be a variadic function type");
+
   Block *unwindDest = getUnwindDest();
   if (unwindDest->empty())
     return emitError("must have at least one operation in unwind destination");
@@ -1409,14 +1426,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1435,12 +1444,14 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " unwind ";
   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the variadic callee type if the invoke is variadic.
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    p << " vararg(" << *varCalleeType << ")";
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {InvokeOp::getOperandSegmentSizeAttr(), "callee",
-                           "callee_type", InvokeOp::getCConvAttrName()});
+                           InvokeOp::getVarCalleeTypeAttrName(),
+                           InvokeOp::getCConvAttrName()});
 
   p << " : ";
   if (!isDirect)
@@ -1453,12 +1464,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
 //                  `(` ssa-use-list `)`
 //                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
 //                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
-//                  ( `vararg(` var-arg-func-type `)` )?
+//                  ( `vararg(` var-callee-type `)` )?
 //                  attribute-dict? `:` (type `,`)? function-type
 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
   SymbolRefAttr funcAttr;
-  TypeAttr calleeType;
+  TypeAttr varCalleeType;
   Block *normalDest, *unwindDest;
   SmallVector<Value, 4> normalOperands, unwindOperands;
   Builder &builder = parser.getBuilder();
@@ -1488,8 +1499,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
 
   bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
   if (isVarArg) {
+    StringAttr varCalleeTypeAttrName =
+        InvokeOp::getVarCalleeTypeAttrName(result.name);
     if (parser.parseLParen().failed() ||
-        parser.parseAttribute(calleeType, "callee_type", result.attributes)
+        parser
+            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
+                            result.attributes)
             .failed() ||
         parser.parseRParen().failed())
       return failure();
@@ -1515,8 +1530,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 LLVMFunctionType InvokeOp::getCalleeFunctionType() {
-  if (getCalleeType())
-    return *getCalleeType();
+  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
+    return *varCalleeType;
   return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 39f8e70b9fb7b..c5953087ba991 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1415,6 +1415,29 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
 
 // -----
 
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to be a variadic function type}}
+  llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32)  {
+  // expected-error at below {{expected var_callee_type to be a variadic function type}}
+  llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
 llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
@@ -1445,14 +1468,14 @@ llvm.func @foo(%arg: !llvm.ptr) {
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 // -----
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{im2col offsets must be 2 less than number of coordinates}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1460,7 +1483,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1469,7 +1492,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error at +1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 



More information about the Mlir-commits mailing list