[Mlir-commits] [mlir] [mlir][linalg] Restrict fill initial value type to output element type (PR #169567)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Nov 25 13:16:03 PST 2025
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/169567
From 65fcd86fdc3c5a73e1e32426c2e0853b01275a80 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 25 Nov 2025 15:10:50 -0500
Subject: [PATCH 1/3] [mlir][linalg] Restrict fill initial value type to output
element type
Disallow implicit casting, which is surprising and, in my experience,
usually indicative of copy-paste errors.
Because the initial value must be a scalar, I don't expect this to
affect any data movement.
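For illustration, a minimal sketch of what this change rejects and what
remains valid (types and SSA names chosen only for the example):

```mlir
// Now rejected by the verifier: the f64 fill value does not match the
// f32 output element type. Diagnostic (per the new invalid.mlir tests):
//   'linalg.fill' op expected fill value type ('f64') to match output
//   element type ('f32')
%bad = linalg.fill ins(%cst_f64 : f64) outs(%t : tensor<4x8xf32>) -> tensor<4x8xf32>

// Still valid: the fill value type matches the output element type.
%ok = linalg.fill ins(%cst_f32 : f32) outs(%t : tensor<4x8xf32>) -> tensor<4x8xf32>
```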
---
mlir/docs/Dialects/Linalg/OpDSL.md | 17 +++---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 18 ++----
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 56 ++++++++++++++-----
.../linalg/opdsl/ops/core_named_ops.py | 8 +--
.../Affine/value-bounds-reification.mlir | 4 +-
.../Linalg/fusion-elementwise-ops.mlir | 26 +--------
.../generalize-named-polymorphic-ops.mlir | 8 +--
mlir/test/Dialect/Linalg/invalid.mlir | 18 ++++++
.../Linalg/CPU/test-matmul-masked-vec.mlir | 4 +-
.../Dialect/Transform/match_matmul.mlir | 4 +-
.../integration/dialects/linalg/opsrun.py | 26 ++++-----
11 files changed, 103 insertions(+), 86 deletions(-)
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index b892bbe427a18..37604fc17dd9b 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:
```python
@linalg_structured_op
-def fill(value=ScalarDef(T1),
- O=TensorDef(U, output=True)):
- O[None] = TypeFn.cast_signed(U, value)
+def fill(value=ScalarDef(T),
+ O=TensorDef(T, output=True)):
+ O[None] = value
```
-The operation sets the elements of the output tensor `O` to `value`. All
-operands are either scalars or rank zero tensors that are accessed using the
-index `None`. The operation thus performs a scalar computation that trivially
-extends to a multi-dimensional pointwise computation. As a result, we may use
-`fill` with arbitrary ranked output tensors:
+The operation sets the elements of the output tensor `O` to `value`. The value
+type must match the element type of the output tensor. All operands are either
+scalars or rank zero tensors that are accessed using the index `None`. The
+operation thus performs a scalar computation that trivially extends to a
+multi-dimensional pointwise computation. As a result, we may use `fill` with
+arbitrary ranked output tensors:
```python
tensor_2d = tensor.EmptyOp([4, 8], f32)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9aae1b850c3a0..521afc991063f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
doc: |-
Fills the output tensor with the given value.
- Works for arbitrary ranked output tensors since the operation performs scalar
- accesses only and is thus rank polymorphic. Numeric casting is performed on
- the value operand, promoting it to the same data type as the output.
+ Works for arbitrary ranked output tensors since the operation performs
+ scalar accesses only and is thus rank polymorphic. The value operand
+ type must match the element type of the output.
implements:
- LinalgFillOpInterface
defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
- !LinalgOperandDefConfig
name: value
kind: scalar
- type_var: T1
+ type_var: T
- !LinalgOperandDefConfig
name: O
kind: output_tensor
- type_var: U
+ type_var: T
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: value
+ scalar_arg: value
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9e997ea..3b0b7d8c492c8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,35 +1057,65 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
+namespace mlir::linalg::detail {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
- NotScalarInput
+ NotScalarInput,
+ TypeMismatch
};
-static MatchFillResult isFillInterfaceImpl(Operation *op) {
+struct FillInterfaceResult {
+ MatchFillResult result = MatchFillResult::Success;
+ Type scalarType;
+ Type outputElementType;
+};
+static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
+ FillInterfaceResult fillResult = {};
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (!linalgOp)
- return MatchFillResult::NotLinalgOp;
- if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
- return MatchFillResult::WrongNumOperands;
+ if (!linalgOp) {
+ fillResult.result = MatchFillResult::NotLinalgOp;
+ return fillResult;
+ }
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) {
+ fillResult.result = MatchFillResult::WrongNumOperands;
+ return fillResult;
+ }
OpOperand *value = linalgOp.getDpsInputOperand(0);
- if (!linalgOp.isScalar(value))
- return MatchFillResult::NotScalarInput;
+ if (!linalgOp.isScalar(value)) {
+ fillResult.result = MatchFillResult::NotScalarInput;
+ return fillResult;
+ }
+
+ // Check that the scalar input type matches the output element type.
+ OpOperand *output = linalgOp.getDpsInitOperand(0);
+ Type scalarType = value->get().getType();
+ Type outputElementType = getElementTypeOrSelf(output->get().getType());
+ if (scalarType != outputElementType) {
+ fillResult.result = MatchFillResult::TypeMismatch;
+ fillResult.scalarType = scalarType;
+ fillResult.outputElementType = outputElementType;
+ return fillResult;
+ }
- return MatchFillResult::Success;
+ return fillResult;
}
+} // namespace mlir::linalg::detail
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
- auto res = isFillInterfaceImpl(op);
- if (res == MatchFillResult::NotLinalgOp)
+ auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);
+ if (result == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
- if (res == MatchFillResult::WrongNumOperands)
+ if (result == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
- if (res == MatchFillResult::NotScalarInput)
+ if (result == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
+ if (result == MatchFillResult::TypeMismatch)
+ return op->emitOpError("expected fill value type (")
+ << scalarType << ") to match output element type ("
+ << outputElementType << ")";
return success();
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fd4a5a848f1e3..9c24f94fcf612 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(
@linalg_structured_op
-def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
"""Fills the output tensor with the given value.
Works for arbitrary ranked output tensors since the operation performs scalar
- accesses only and is thus rank polymorphic. Numeric casting is performed on
- the value operand, promoting it to the same data type as the output.
+ accesses only and is thus rank polymorphic. The value type must match the
+ element type of the output tensor or memref.
"""
implements(FillOpInterface)
defines(Canonicalizer)
- O[None] = TypeFn.cast_signed(U, value)
+ O[None] = value
@linalg_structured_op
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 817614be50533..2e801028057a1 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
// CHECK: "test.some_use"(%[[c5]])
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: "test.some_use"(%[[c5]])
-func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
+func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
scf.for %iv = %c0 to %ub step %c4 {
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
- %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+ %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..6f1a422324e08 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// -----
-// CHECK-LABEL: func @fold_fill_generic_different_dtype
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 7.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
- %1 = tensor.empty(%0) : tensor<?xf16>
- %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
- %3 = tensor.empty(%0) : tensor<?xf16>
- %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
- ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
- %5 = arith.addf %arg1, %arg2 : f16
- linalg.yield %5 : f16
- } -> tensor<?xf16>
- return %4 : tensor<?xf16>
-}
-
-// -----
-
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 290c6c7c36f76..4526dc90fad2e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
// -----
-func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
- %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
+ %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
return %0: tensor<f32>
}
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
// -----
-func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
- linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
+func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
+ linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fabc8e610612d..1f554e6c45da7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
// -----
+func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
+{
+ // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
+ %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
+{
+ // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
+ %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
+ return %0 : tensor<2xi32>
+}
+
+// -----
+
func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
{
// expected-error @+1 {{expected op with scalar input}}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 8fa32d7aeb586..bbda8d4e99d04 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -27,8 +27,8 @@ func.func @main() {
%A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
%B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
- %c0_i32 = arith.constant 0 : i32
- %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %c0_f32 = arith.constant 0.0 : f32
+ %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
%res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index a374d9a611258..e3fee917cdeaa 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
}
func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
- %cst = arith.constant 0.0 : f64
+ %cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<10x15xf32>
// expected-remark @below {{fill}}
- %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
%real_lhs = linalg.mul
ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 8f202318146ee..8eff573f98ad3 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,13 +25,13 @@ def log(*args):
%O1 = memref.alloc() : memref<16xi32>
%O2 = memref.alloc() : memref<4x16xi32>
- %val0 = arith.constant 1.0 : f32
- %val1 = arith.constant 2.0 : f32
- %val2 = arith.constant 3.0 : f32
+ %val0 = arith.constant 1 : i32
+ %val1 = arith.constant 2 : i32
+ %val2 = arith.constant 3 : i32
- call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
- call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
- call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+ call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
+ call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
+ call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
%c0 = arith.constant 0 : index
%res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
def test_fill_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
- f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out])
- @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out])
- @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out])
@@ -184,19 +183,18 @@ def fill_2d_on_buffers(value, out):
def test_fill_generic():
with Context() as ctx, Location.unknown():
module = Module.create()
- f32 = F32Type.get()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
def fill_0d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
- @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
def fill_1d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
- @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+ @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
def fill_2d_on_buffers(value, out):
linalg.fill(value, outs=[out], emit_generic=True)
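As a usage sketch mirroring the updated opsrun.py tests above (assuming the
in-tree mlir Python bindings are on PYTHONPATH), a fill built from Python now
takes a value whose type matches the output element type:

```python
from mlir.ir import Context, InsertionPoint, IntegerType, Location, MemRefType, Module
from mlir.dialects import func, linalg

with Context(), Location.unknown():
    module = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):
        # The fill value type (i32) must now match the memref element type (i32);
        # the old tests passed an f32 value and relied on the implicit cast.
        @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
        def fill_1d_on_buffers(value, out):
            linalg.fill(value, outs=[out])
```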
From e0767d9576c3212f729b0e9799d40873236cc6ae Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 25 Nov 2025 15:46:42 -0500
Subject: [PATCH 2/3] Use anonymous namespace
---
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3b0b7d8c492c8..f768bd14295f9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,7 +1057,7 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
-namespace mlir::linalg::detail {
+namespace {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
@@ -1071,6 +1071,8 @@ struct FillInterfaceResult {
Type scalarType;
Type outputElementType;
};
+} // namespace
+
static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
FillInterfaceResult fillResult = {};
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1102,7 +1104,6 @@ static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
return fillResult;
}
-} // namespace mlir::linalg::detail
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);
From 09e730550e168c8a3f2a0c93e0475bfc7fe46ea9 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 25 Nov 2025 16:15:52 -0500
Subject: [PATCH 3/3] Simplify error message
---
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 52 +++++++------------
1 file changed, 20 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f768bd14295f9..b4b1347493529 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1065,58 +1065,46 @@ enum class MatchFillResult {
NotScalarInput,
TypeMismatch
};
-
-struct FillInterfaceResult {
- MatchFillResult result = MatchFillResult::Success;
- Type scalarType;
- Type outputElementType;
-};
} // namespace
-static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
- FillInterfaceResult fillResult = {};
+static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- if (!linalgOp) {
- fillResult.result = MatchFillResult::NotLinalgOp;
- return fillResult;
- }
- if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) {
- fillResult.result = MatchFillResult::WrongNumOperands;
- return fillResult;
- }
+ if (!linalgOp)
+ return MatchFillResult::NotLinalgOp;
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+ return MatchFillResult::WrongNumOperands;
OpOperand *value = linalgOp.getDpsInputOperand(0);
- if (!linalgOp.isScalar(value)) {
- fillResult.result = MatchFillResult::NotScalarInput;
- return fillResult;
- }
+ if (!linalgOp.isScalar(value))
+ return MatchFillResult::NotScalarInput;
// Check that the scalar input type matches the output element type.
OpOperand *output = linalgOp.getDpsInitOperand(0);
Type scalarType = value->get().getType();
Type outputElementType = getElementTypeOrSelf(output->get().getType());
- if (scalarType != outputElementType) {
- fillResult.result = MatchFillResult::TypeMismatch;
- fillResult.scalarType = scalarType;
- fillResult.outputElementType = outputElementType;
- return fillResult;
- }
+ if (scalarType != outputElementType)
+ return MatchFillResult::TypeMismatch;
- return fillResult;
+ return MatchFillResult::Success;
}
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
- auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);
- if (result == MatchFillResult::NotLinalgOp)
+ MatchFillResult res = isFillInterfaceImpl(op);
+ if (res == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
- if (result == MatchFillResult::WrongNumOperands)
+ if (res == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
- if (result == MatchFillResult::NotScalarInput)
+ if (res == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
- if (result == MatchFillResult::TypeMismatch)
+ if (res == MatchFillResult::TypeMismatch) {
+ auto linalgOp = cast<linalg::LinalgOp>(op);
+ Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+ Type outputElementType =
+ getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
return op->emitOpError("expected fill value type (")
<< scalarType << ") to match output element type ("
<< outputElementType << ")";
+ }
return success();
}
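For downstream IR that hits the new verifier error, one possible migration (a
sketch assuming the common f64-to-f32 case; names are illustrative) is to cast
the fill value explicitly before the op instead of relying on the removed
implicit cast:

```mlir
// Explicitly truncate the fill value to the output element type.
%cst_f32 = arith.truncf %cst_f64 : f64 to f32
%filled = linalg.fill ins(%cst_f32 : f32) outs(%t : tensor<4x8xf32>) -> tensor<4x8xf32>
```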