[Mlir-commits] [mlir] [MLIR][LLVM] Always print variadic callee type (PR #99291)

Tobias Gysi llvmlistbot at llvm.org
Wed Jul 17 01:40:34 PDT 2024


https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/99291

>From 7e25f4d4b1b6540f79df83d29420d08c168170d9 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 17 Jul 2024 08:02:11 +0000
Subject: [PATCH] [MLIR][LLVM] Always print variadic callee type

This commit updates the LLVM dialect CallOp and InvokeOp to always print
the calleeType if present. An additional verifier checks that only
variadic calls have a non-null calleeType, and the builders are adapted
accordingly to only set the calleeType for variadic calls.

The motivation for this change is to prevent CallOp and InvokeOp from
carrying hidden state that is not pretty-printed but is used, for
example, during the export to LLVM IR. Such hidden state triggered
downstream bugs where a call looked correct in MLIR but had a completely
different result type after exporting to LLVM IR. This change ensures
that the calleeType is present only when necessary, reducing the amount
of redundant state, and that it is always printed if present, avoiding
any kind of hidden state.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 11 +--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 77 +++++++++++----------
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 31 +++++++--
 3 files changed, 74 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f0dec69a5032a..18857e1cdefe0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -617,11 +617,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
     start with a function name (`@`-prefixed) and indirect calls start with an
     SSA value (`%`-prefixed). The direct callee, if present, is stored as a
     function attribute `callee`. For indirect calls, the callee is of `!llvm.ptr` type
-    and is stored as the first value in `callee_operands`. If the callee is a variadic
-    function, then the `callee_type` attribute must carry the function type. The
-    trailing type list contains the optional indirect callee type and the MLIR
-    function type, which differs from the LLVM function type that uses a explicit
-    void type to model functions that do not return a value.
+    and is stored as the first value in `callee_operands`. If and only if the
+    callee is a variadic function, then the `callee_type` attribute must carry
+    the variadic LLVM function type. The trailing type list contains the
+    optional indirect callee type and the MLIR function type, which differs from
+    the LLVM function type that uses an explicit void type to model functions
+    that do not return a value.
 
     Examples:
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9372caf6e32a7..f45968b315bf0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -974,8 +974,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
-        callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -996,8 +996,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                    ValueRange args) {
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+  auto varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        callee, args, /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1005,8 +1007,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), /*callee=*/nullptr, args,
+  auto varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1016,8 +1020,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
   auto calleeType = func.getFunctionType();
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
+  auto varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1080,6 +1086,11 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
+  // If the callee type attribute is present, it must be variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    if (!calleeType->isVarArg())
+      return emitOpError("expected variadic callee type attribute");
+
   // Type for the callee, we'll get it differently depending if it is a direct
   // or indirect call.
   Type fnType;
@@ -1168,14 +1179,6 @@ void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1195,8 +1198,9 @@ void CallOp::print(OpAsmPrinter &p) {
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the callee type if the call is variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    p << " vararg(" << *calleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                           {getCConvAttrName(), "callee", "callee_type",
@@ -1333,9 +1337,11 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
   auto calleeType = func.getFunctionType();
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, nullptr, normal, unwind);
+  auto varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        SymbolRefAttr::get(func), ops, normalOps, unwindOps, nullptr, nullptr,
+        normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1343,17 +1349,18 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        /*callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
+        nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+  auto varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        callee, ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1393,6 +1400,11 @@ LogicalResult InvokeOp::verify() {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
+  // If the callee type attribute is present, it must be variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    if (!calleeType->isVarArg())
+      return emitOpError("expected variadic callee type attribute");
+
   Block *unwindDest = getUnwindDest();
   if (unwindDest->empty())
     return emitError("must have at least one operation in unwind destination");
@@ -1409,14 +1421,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1435,8 +1439,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " unwind ";
   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the callee type if the invoke is variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    p << " vararg(" << *calleeType << ")";
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {InvokeOp::getOperandSegmentSizeAttr(), "callee",
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 39f8e70b9fb7b..e932187a8e614 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1415,6 +1415,29 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
 
 // -----
 
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32)  {
+  // expected-error @below {{expected variadic callee type attribute}}
+  llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32)  {
+  // expected-error @below {{expected variadic callee type attribute}}
+  llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
 llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
@@ -1445,14 +1468,14 @@ llvm.func @foo(%arg: !llvm.ptr) {
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 // -----
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{im2col offsets must be 2 less than number of coordinates}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1460,7 +1483,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1469,7 +1492,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 



More information about the Mlir-commits mailing list