[Mlir-commits] [mlir] [MLIR][LLVM] Always print variadic callee type (PR #99291)

Tobias Gysi llvmlistbot at llvm.org
Wed Jul 17 01:18:52 PDT 2024


https://github.com/gysit created https://github.com/llvm/llvm-project/pull/99291

This commit updates the LLVM dialect CallOp and InvokeOp to always print the calleeType if present. An additional verifier checks that only variadic calls have a non-null calleeType, and the builders are adapted accordingly to only set the calleeType for variadic calls.

The motivation for this change is to prevent CallOp and InvokeOp from carrying hidden state that is not pretty-printed but is used, for example, during the export to LLVM IR. This triggered downstream bugs where a call looked correct in MLIR but had a completely different result type after exporting to LLVM IR. This change ensures that the calleeType is present only when necessary, reducing redundant state, and that it is always printed when present, eliminating hidden state.

>From c81c18ff1e60432ca5eca4c89a14f8b12887d3c4 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 17 Jul 2024 08:02:11 +0000
Subject: [PATCH] [MLIR][LLVM] Always print variadic callee type

This commit updates the LLVM dialect CallOp and InvokeOp to always print
the calleeType if present. An additional verifier checks that only
variadic calls have a non-null calleeType, and the builders are adapted
accordingly to only set the calleeType for variadic calls.

The motivation for this change is to prevent CallOp and InvokeOp from
carrying hidden state that is not pretty-printed but is used, for
example, during the export to LLVM IR. This triggered downstream bugs
where a call looked correct in MLIR but had a completely different
result type after exporting to LLVM IR. This change ensures that the
calleeType is present only when necessary, reducing redundant state,
and that it is always printed when present, eliminating hidden state.
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 77 ++++++++++++----------
 mlir/test/Dialect/LLVMIR/invalid.mlir      | 31 +++++++--
 2 files changed, 68 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9372caf6e32a7..1bfb944ee477b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -974,8 +974,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    FlatSymbolRefAttr callee, ValueRange args) {
   assert(callee && "expected non-null callee in direct call builder");
   build(builder, state, results,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
-        callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
+        /*callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
+        /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
         /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -996,8 +996,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                    ValueRange args) {
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
+  TypeAttr varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        callee, args, /*fastmathFlags=*/nullptr,
         /*branch_weights=*/nullptr, /*CConv=*/nullptr,
         /*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
         /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
@@ -1005,8 +1007,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 
 void CallOp::build(OpBuilder &builder, OperationState &state,
                    LLVMFunctionType calleeType, ValueRange args) {
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), /*callee=*/nullptr, args,
+  TypeAttr varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        /*callee=*/nullptr, args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1016,8 +1020,10 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                    ValueRange args) {
   auto calleeType = func.getFunctionType();
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
+  TypeAttr varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        SymbolRefAttr::get(func), args,
         /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
         /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
         /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
@@ -1080,6 +1086,11 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
+  // If the callee type attribute is present, it must be variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    if (!calleeType->isVarArg())
+      return emitOpError("expected variadic callee type attribute");
+
   // Type for the callee, we'll get it differently depending if it is a direct
   // or indirect call.
   Type fnType;
@@ -1168,14 +1179,6 @@ void CallOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1195,8 +1198,9 @@ void CallOp::print(OpAsmPrinter &p) {
   auto args = getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the callee type if the call is variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    p << " vararg(" << *calleeType << ")";
 
   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                           {getCConvAttrName(), "callee", "callee_type",
@@ -1333,9 +1337,11 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
   auto calleeType = func.getFunctionType();
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
-        unwindOps, nullptr, nullptr, normal, unwind);
+  TypeAttr varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        SymbolRefAttr::get(func), ops, normalOps, unwindOps, nullptr, nullptr,
+        normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1343,17 +1349,18 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                      ValueRange normalOps, Block *unwind,
                      ValueRange unwindOps) {
   build(builder, state, tys,
-        TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
-        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
+        /*callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr,
+        nullptr, normal, unwind);
 }
 
 void InvokeOp::build(OpBuilder &builder, OperationState &state,
                      LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                      ValueRange ops, Block *normal, ValueRange normalOps,
                      Block *unwind, ValueRange unwindOps) {
-  build(builder, state, getCallOpResultTypes(calleeType),
-        TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
-        nullptr, normal, unwind);
+  TypeAttr varArgCalleeType =
+      calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
+  build(builder, state, getCallOpResultTypes(calleeType), varArgCalleeType,
+        callee, ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
 }
 
 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1393,6 +1400,11 @@ LogicalResult InvokeOp::verify() {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
+  // If the callee type attribute is present, it must be variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    if (!calleeType->isVarArg())
+      return emitOpError("expected variadic callee type attribute");
+
   Block *unwindDest = getUnwindDest();
   if (unwindDest->empty())
     return emitError("must have at least one operation in unwind destination");
@@ -1409,14 +1421,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
   auto callee = getCallee();
   bool isDirect = callee.has_value();
 
-  LLVMFunctionType calleeType;
-  bool isVarArg = false;
-
-  if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType()) {
-    calleeType = *optionalCalleeType;
-    isVarArg = calleeType.isVarArg();
-  }
-
   p << ' ';
 
   // Print calling convention.
@@ -1435,8 +1439,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
   p << " unwind ";
   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
 
-  if (isVarArg)
-    p << " vararg(" << calleeType << ")";
+  // Print the callee type if the invoke is variadic.
+  if (std::optional<LLVMFunctionType> calleeType = getCalleeType())
+    p << " vararg(" << *calleeType << ")";
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {InvokeOp::getOperandSegmentSizeAttr(), "callee",
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 39f8e70b9fb7b..e932187a8e614 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1415,6 +1415,29 @@ func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
 
 // -----
 
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32)  {
+  // expected-error @below {{expected variadic callee type attribute}}
+  llvm.call @non_variadic(%arg) vararg(!llvm.func<void (i32)>) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @non_variadic(%arg: i32)
+
+llvm.func @invalid_callee_type(%arg: i32)  {
+  // expected-error @below {{expected variadic callee type attribute}}
+  llvm.invoke @non_variadic(%arg) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (i32)>) : (i32) -> ()
+^bb1:
+  llvm.return
+^bb2:
+  llvm.return
+}
+
+// -----
+
 llvm.func @variadic(...)
 
 llvm.func @invalid_variadic_call(%arg: i32)  {
@@ -1445,14 +1468,14 @@ llvm.func @foo(%arg: !llvm.ptr) {
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{to use im2col mode, the tensor has to be at least 3-dimensional}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 // -----
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{im2col offsets must be 2 less than number of coordinates}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1460,7 +1483,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 
@@ -1469,7 +1492,7 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
 
 func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
   // expected-error @+1 {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr  
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd0,%crd1,%crd2,%crd3]: !llvm.ptr<3>, !llvm.ptr
   return
 }
 



More information about the Mlir-commits mailing list