[Mlir-commits] [mlir] [mlir][func] Move return-type verification from ReturnOp to FuncOp (PR #184153)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 2 07:24:39 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
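Move the operand count and type checks for func.return from ReturnOp::verify() into a new FuncOp::verify(). The verifier skips external declarations and otherwise iterates over all blocks in the function body, checking the operands of every terminator that implements RegionBranchTerminatorOpInterface against the function's result types. This covers func.return as well as other return-like terminators (e.g. llvm.return or test.return that may appear during dialect conversion); empty blocks and blocks ending in any other kind of terminator are skipped.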
Fix several invalid-IR tests that had func.func return types inconsistent with the actual func.return operands. Previously these mismatches were silent because block verification stopped at an earlier expected error before reaching the func.return; now that FuncOp::verify() runs before body verification, the return types must be consistent.
---
Patch is 33.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184153.diff
18 Files Affected:
- (modified) mlir/include/mlir/Dialect/Func/IR/FuncOps.td (+1-1)
- (modified) mlir/lib/Dialect/Func/IR/FuncOps.cpp (+31-17)
- (modified) mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir (-7)
- (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir (+1-1)
- (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+1-1)
- (modified) mlir/test/Dialect/EmitC/invalid_types.mlir (+1-1)
- (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+5-5)
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+15-15)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+3-3)
- (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+3-3)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+1-1)
- (modified) mlir/test/Transforms/print-op-graph.mlir (+1-1)
- (modified) mlir/test/Transforms/test-dialect-conversion-pdll.mlir (+3-3)
- (modified) mlir/test/Transforms/test-legalizer.mlir (+5-3)
- (modified) mlir/test/Transforms/test-merge-blocks.mlir (+1-1)
- (modified) mlir/test/Transforms/test-pattern-selective-replacement.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 06ce4f16c867d..ce2b7228cb954 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -357,6 +357,7 @@ def FuncOp : Func_Op<"func", [
bool isDeclaration() { return isExternal(); }
}];
let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -389,7 +390,6 @@ def ReturnOp : Func_Op<"return", [Pure, HasParent<"FuncOp">,
}]>];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
- let hasVerifier = 1;
}
#endif // MLIR_DIALECT_FUNC_IR_FUNCOPS_TD
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index 0e0863be9b476..8ee2c9956d2d6 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -284,23 +284,37 @@ FuncOp FuncOp::clone() {
// ReturnOp
//===----------------------------------------------------------------------===//
-LogicalResult ReturnOp::verify() {
- auto function = cast<FuncOp>((*this)->getParentOp());
-
- // The operand number and types must match the function signature.
- const auto &results = function.getFunctionType().getResults();
- if (getNumOperands() != results.size())
- return emitOpError("has ")
- << getNumOperands() << " operands, but enclosing function (@"
- << function.getName() << ") returns " << results.size();
-
- for (unsigned i = 0, e = results.size(); i != e; ++i)
- if (getOperand(i).getType() != results[i])
- return emitError() << "type of return operand " << i << " ("
- << getOperand(i).getType()
- << ") doesn't match function result type ("
- << results[i] << ")"
- << " in function @" << function.getName();
+LogicalResult FuncOp::verify() {
+ // External declarations have no body to check.
+ if (isDeclaration())
+ return success();
+ // Hoist the result types once; they are the same for every return site.
+ auto resultTypes = getFunctionType().getResults();
+ for (Block &block : getBody()) {
+ if (block.empty())
+ continue;
+ // Check func.return or other return-like terminators ops (e.g.
+ // llvm.return, test.return).
+ auto returnOp = dyn_cast<RegionBranchTerminatorOpInterface>(&block.back());
+ if (!returnOp)
+ continue;
+
+ if (returnOp->getNumOperands() != resultTypes.size())
+ return returnOp->emitOpError("has ")
+ << returnOp->getNumOperands()
+ << " operands, but enclosing function (@" << getName()
+ << ") returns " << resultTypes.size();
+
+ for (auto [i, opType] :
+ llvm::enumerate(llvm::zip(returnOp->getOperandTypes(), resultTypes))) {
+ auto [opTy, resTy] = opType;
+ if (opTy != resTy)
+ return returnOp->emitError()
+ << "type of return operand " << i << " (" << opTy
+ << ") doesn't match function result type (" << resTy
+ << ") in function @" << getName();
+ }
+ }
return success();
}
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index bdb0092d155be..22ebbf8618bde 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -87,13 +87,6 @@ func.func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
return %res : memref<20xi8>
}
-// BAREPTR-LABEL: func @check_return(
-// BAREPTR-SAME: %{{.*}}: memref<?xi8>) -> memref<?xi8>
-func.func @check_return(%in : memref<?xi8>) -> memref<?xi8> {
- // BAREPTR: llvm.return {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- return %in : memref<?xi8>
-}
-
// BAREPTR-LABEL: func @unconvertible_multiresult
// BAREPTR-SAME: %{{.*}}: memref<?xf32>, %{{.*}}: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>)
func.func @unconvertible_multiresult(%arg0: memref<?xf32> , %arg1: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
index 8bcbdad1437ab..1bc7acfe46a4a 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
@@ -67,7 +67,7 @@ module attributes {transform.with_named_sequence} {
{index_bitwidth = 32, use_opaque_pointers = true}
} {
legal_dialects = ["llvm", "memref", "nvvm"],
- legal_ops = ["func.func", "gpu.module", "gpu.yield"],
+ legal_ops = ["gpu.module", "gpu.yield"],
illegal_dialects = ["gpu"],
illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
"llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ac091664fe7da..4837800488e86 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1031,7 +1031,7 @@ module attributes {transform.with_named_sequence} {
use_bare_ptr_call_conv = false}
} {
legal_dialects = ["llvm", "memref", "nvvm", "test"],
- legal_ops = ["func.func", "gpu.module", "gpu.yield"],
+ legal_ops = ["gpu.module", "gpu.yield"],
illegal_dialects = ["gpu"],
illegal_ops = ["llvm.copysign", "llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
"llvm.ffloor", "llvm.frem", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index c39a881ff26ad..f3f998ef56e0f 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -84,7 +84,7 @@ func.func @illegal_array_with_lvalue_element_type(
// -----
-func.func @illegal_integer_type(%arg0: i11, %arg1: i11) -> i11 {
+func.func @illegal_integer_type(%arg0: i11, %arg1: i11) {
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer, index or opaque type supported by EmitC, but got 'i11'}}
%mul = "emitc.mul" (%arg0, %arg1) : (i11, i11) -> i11
return
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index b736cde7689ed..5068ddc42e1e5 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -687,7 +687,7 @@ func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3
func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
+ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
// expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -699,7 +699,7 @@ func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
func.func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
+ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)> {
// expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
@@ -711,7 +711,7 @@ func.func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func.func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
- %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
+ %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
// expected-error@+1 {{op requires attribute 'layoutA'}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{shape = #nvvm.shape<m = 8, n = 8, k = 4>}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -722,7 +722,7 @@ func.func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func.func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
// expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -732,7 +732,7 @@ func.func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
%b0 : i32,
- %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// expected-error@+1 {{op requires b1Op attribute}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 579f0ac3ccad1..6b7417b4b82bc 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -143,7 +143,7 @@ func.func @llvm_nvvm_bar_warp_sync(%mask : i32) {
// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
func.func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
// CHECK: nvvm.mma.sync
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -154,7 +154,7 @@ func.func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf
// CHECK-LABEL: @nvvm_mma_m8n8k4_f16_f16
func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -164,7 +164,7 @@ func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
// CHECK-LABEL: @nvvm_mma_m16n8k8_bf16_bf16
func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -176,7 +176,7 @@ func.func @nvvm_mma_m16n8k8_bf16_bf16(%a0 : i32, %a1 : i32, %b0 : i32,
// CHECK-LABEL: @nvvm_mma_m16n8k16_bf16_bf16
func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -187,7 +187,7 @@ func.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i3
// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
- %c0 : i32, %c1 : i32) {
+ %c0 : i32, %c1 : i32) -> !llvm.struct<(i32, i32)> {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
%0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -200,7 +200,7 @@ func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
// CHECK-LABEL: @nvvm_mma_m16n8k8_f16_f16
func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
// CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -212,7 +212,7 @@ func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -224,7 +224,7 @@ func.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -235,7 +235,7 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32
func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
%b0 : i32,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 4>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -246,7 +246,7 @@ func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
- %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
@@ -259,7 +259,7 @@ func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8
func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
%b0 : i32,
- %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+ %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/184153
More information about the Mlir-commits mailing list