[Mlir-commits] [mlir] [mlir][linalg] Restrict fill initial value type to output element type (PR #169567)

Jakub Kuderski llvmlistbot at llvm.org
Tue Nov 25 13:16:03 PST 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/169567

>From 65fcd86fdc3c5a73e1e32426c2e0853b01275a80 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 25 Nov 2025 15:10:50 -0500
Subject: [PATCH 1/3] [mlir][linalg] Restrict fill initial value type to output
 element type

Disallow implicit casting, which is surprising and, in my experience, usually
indicative of copy-paste errors.

Because the initial value must be a scalar, I don't expect this to
affect any data movement.
---
 mlir/docs/Dialects/Linalg/OpDSL.md            | 17 +++---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 18 ++----
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 56 ++++++++++++++-----
 .../linalg/opdsl/ops/core_named_ops.py        |  8 +--
 .../Affine/value-bounds-reification.mlir      |  4 +-
 .../Linalg/fusion-elementwise-ops.mlir        | 26 +--------
 .../generalize-named-polymorphic-ops.mlir     |  8 +--
 mlir/test/Dialect/Linalg/invalid.mlir         | 18 ++++++
 .../Linalg/CPU/test-matmul-masked-vec.mlir    |  4 +-
 .../Dialect/Transform/match_matmul.mlir       |  4 +-
 .../integration/dialects/linalg/opsrun.py     | 26 ++++-----
 11 files changed, 103 insertions(+), 86 deletions(-)

diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index b892bbe427a18..37604fc17dd9b 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:
 
 ```python
 @linalg_structured_op
-def fill(value=ScalarDef(T1),
-         O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, value)
+def fill(value=ScalarDef(T),
+         O=TensorDef(T, output=True)):
+  O[None] = value
 ```
 
-The operation sets the elements of the output tensor `O` to `value`. All
-operands are either scalars or rank zero tensors that are accessed using the
-index `None`. The operation thus performs a scalar computation that trivially
-extends to a multi-dimensional pointwise computation. As a result, we may use
-`fill` with arbitrary ranked output tensors:
+The operation sets the elements of the output tensor `O` to `value`. The value
+type must match the element type of the output tensor. All operands are either
+scalars or rank zero tensors that are accessed using the index `None`. The
+operation thus performs a scalar computation that trivially extends to a
+multi-dimensional pointwise computation. As a result, we may use `fill` with
+arbitrary ranked output tensors:
 
 ```python
 tensor_2d = tensor.EmptyOp([4, 8], f32)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9aae1b850c3a0..521afc991063f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
   doc: |-
     Fills the output tensor with the given value.
 
-    Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    Works for arbitrary ranked output tensors since the operation performs scalar
+    accesses only and is thus rank polymorphic. The value type must match the
+    element type of the output tensor or memref.
   implements:
   - LinalgFillOpInterface
   defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: value
     kind: scalar
-    type_var: T1
+    type_var: T
   - !LinalgOperandDefConfig
     name: O
     kind: output_tensor
-    type_var: U
+    type_var: T
     shape_map: affine_map<() -> ()>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      scalar_fn:
-        kind: type
-        fn_name: cast_signed
-        type_var: U
-        operands:
-        - !ScalarExpression
-          scalar_arg: value
+      scalar_arg: value
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9e997ea..3b0b7d8c492c8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,35 +1057,65 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
 
+namespace mlir::linalg::detail {
 enum class MatchFillResult {
   Success = 0,
   NotLinalgOp,
   WrongNumOperands,
-  NotScalarInput
+  NotScalarInput,
+  TypeMismatch
 };
 
-static MatchFillResult isFillInterfaceImpl(Operation *op) {
+struct FillInterfaceResult {
+  MatchFillResult result = MatchFillResult::Success;
+  Type scalarType;
+  Type outputElementType;
+};
+static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
+  FillInterfaceResult fillResult = {};
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return MatchFillResult::NotLinalgOp;
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return MatchFillResult::WrongNumOperands;
+  if (!linalgOp) {
+    fillResult.result = MatchFillResult::NotLinalgOp;
+    return fillResult;
+  }
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) {
+    fillResult.result = MatchFillResult::WrongNumOperands;
+    return fillResult;
+  }
 
   OpOperand *value = linalgOp.getDpsInputOperand(0);
-  if (!linalgOp.isScalar(value))
-    return MatchFillResult::NotScalarInput;
+  if (!linalgOp.isScalar(value)) {
+    fillResult.result = MatchFillResult::NotScalarInput;
+    return fillResult;
+  }
+
+  // Check that the scalar input type matches the output element type.
+  OpOperand *output = linalgOp.getDpsInitOperand(0);
+  Type scalarType = value->get().getType();
+  Type outputElementType = getElementTypeOrSelf(output->get().getType());
+  if (scalarType != outputElementType) {
+    fillResult.result = MatchFillResult::TypeMismatch;
+    fillResult.scalarType = scalarType;
+    fillResult.outputElementType = outputElementType;
+    return fillResult;
+  }
 
-  return MatchFillResult::Success;
+  return fillResult;
 }
+} // namespace mlir::linalg::detail
 
 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
-  auto res = isFillInterfaceImpl(op);
-  if (res == MatchFillResult::NotLinalgOp)
+  auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);
+  if (result == MatchFillResult::NotLinalgOp)
     return op->emitError("expected a LinalgOp");
-  if (res == MatchFillResult::WrongNumOperands)
+  if (result == MatchFillResult::WrongNumOperands)
     return op->emitError("expected op with 1 input and 1 output");
-  if (res == MatchFillResult::NotScalarInput)
+  if (result == MatchFillResult::NotScalarInput)
     return op->emitError("expected op with scalar input");
+  if (result == MatchFillResult::TypeMismatch)
+    return op->emitOpError("expected fill value type (")
+           << scalarType << ") to match output element type ("
+           << outputElementType << ")";
 
   return success();
 }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fd4a5a848f1e3..9c24f94fcf612 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(
 
 
 @linalg_structured_op
-def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
     """Fills the output tensor with the given value.
 
     Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    accesses only and is thus rank polymorphic. The value type must match the
+    element type of the output tensor or memref.
     """
     implements(FillOpInterface)
     defines(Canonicalizer)
-    O[None] = TypeFn.cast_signed(U, value)
+    O[None] = value
 
 
 @linalg_structured_op
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 817614be50533..2e801028057a1 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
 //       CHECK:   "test.some_use"(%[[c5]])
 //       CHECK:   %[[c5:.*]] = arith.constant 5 : index
 //       CHECK:   "test.some_use"(%[[c5]])
-func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
+func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
   scf.for %iv = %c0 to %ub step %c4 {
     %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
     %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
-    %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+    %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
 
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..6f1a422324e08 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @fold_fill_generic_different_dtype
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
-//   CHECK-NOT: linalg.fill
-//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
-//  CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
-  %1 = tensor.empty(%0) : tensor<?xf16>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
-  %3 = tensor.empty(%0) : tensor<?xf16>
-  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
-  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
-    %5 = arith.addf  %arg1, %arg2 : f16
-        linalg.yield %5 : f16
-  } -> tensor<?xf16>
-  return %4 : tensor<?xf16>
-}
-
-// -----
-
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 //   CHECK-NOT: linalg.fill
 //       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
 // CHECK-NOT:     linalg.generic
 // CHECK:         tensor.expand_shape
 // CHECK:         linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 290c6c7c36f76..4526dc90fad2e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
 
 // -----
 
-func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
-  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
   return %0: tensor<f32>
 }
 
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
-  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
+func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
+  linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fabc8e610612d..1f554e6c45da7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
 
 // -----
 
+func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
+  %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
+  %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
 {
   // expected-error @+1 {{expected op with scalar input}}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 8fa32d7aeb586..bbda8d4e99d04 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -27,8 +27,8 @@ func.func @main() {
   %A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
   %B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
 
-  %c0_i32 = arith.constant  0 : i32
-  %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %c0_f32 = arith.constant 0.0 : f32
+  %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   %res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
             outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index a374d9a611258..e3fee917cdeaa 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
 }
 
 func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
-  %cst = arith.constant 0.0 : f64
+  %cst = arith.constant 0.0 : f32
   %empty = tensor.empty() : tensor<10x15xf32>
 
   // expected-remark @below {{fill}}
-  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
 
   %real_lhs = linalg.mul
     ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 8f202318146ee..8eff573f98ad3 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,13 +25,13 @@ def log(*args):
   %O1 = memref.alloc() : memref<16xi32>
   %O2 = memref.alloc() : memref<4x16xi32>
 
-  %val0 = arith.constant 1.0 : f32
-  %val1 = arith.constant 2.0 : f32
-  %val2 = arith.constant 3.0 : f32
+  %val0 = arith.constant 1 : i32
+  %val1 = arith.constant 2 : i32
+  %val2 = arith.constant 3 : i32
 
-  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
-  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
-  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+  call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
+  call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
+  call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
 def test_fill_builtin():
     with Context() as ctx, Location.unknown():
         module = Module.create()
-        f32 = F32Type.get()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
             def fill_0d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
             def fill_1d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
             def fill_2d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
@@ -184,19 +183,18 @@ def fill_2d_on_buffers(value, out):
 def test_fill_generic():
     with Context() as ctx, Location.unknown():
         module = Module.create()
-        f32 = F32Type.get()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
             def fill_0d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
             def fill_1d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
             def fill_2d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 

>From e0767d9576c3212f729b0e9799d40873236cc6ae Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 25 Nov 2025 15:46:42 -0500
Subject: [PATCH 2/3] Use anonymous namespace

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3b0b7d8c492c8..f768bd14295f9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,7 +1057,7 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-namespace mlir::linalg::detail {
+namespace {
 enum class MatchFillResult {
   Success = 0,
   NotLinalgOp,
@@ -1071,6 +1071,8 @@ struct FillInterfaceResult {
   Type scalarType;
   Type outputElementType;
 };
+} // namespace
+
 static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
   FillInterfaceResult fillResult = {};
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1102,7 +1104,6 @@ static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
 
   return fillResult;
 }
-} // namespace mlir::linalg::detail
 
 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
   auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);

>From 09e730550e168c8a3f2a0c93e0475bfc7fe46ea9 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 25 Nov 2025 16:15:52 -0500
Subject: [PATCH 3/3] Simplify error message

---
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 52 +++++++------------
 1 file changed, 20 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f768bd14295f9..b4b1347493529 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1065,58 +1065,46 @@ enum class MatchFillResult {
   NotScalarInput,
   TypeMismatch
 };
-
-struct FillInterfaceResult {
-  MatchFillResult result = MatchFillResult::Success;
-  Type scalarType;
-  Type outputElementType;
-};
 } // namespace
 
-static FillInterfaceResult isFillInterfaceImpl(Operation *op) {
-  FillInterfaceResult fillResult = {};
+static MatchFillResult isFillInterfaceImpl(Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp) {
-    fillResult.result = MatchFillResult::NotLinalgOp;
-    return fillResult;
-  }
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) {
-    fillResult.result = MatchFillResult::WrongNumOperands;
-    return fillResult;
-  }
+  if (!linalgOp)
+    return MatchFillResult::NotLinalgOp;
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return MatchFillResult::WrongNumOperands;
 
   OpOperand *value = linalgOp.getDpsInputOperand(0);
-  if (!linalgOp.isScalar(value)) {
-    fillResult.result = MatchFillResult::NotScalarInput;
-    return fillResult;
-  }
+  if (!linalgOp.isScalar(value))
+    return MatchFillResult::NotScalarInput;
 
   // Check that the scalar input type matches the output element type.
   OpOperand *output = linalgOp.getDpsInitOperand(0);
   Type scalarType = value->get().getType();
   Type outputElementType = getElementTypeOrSelf(output->get().getType());
-  if (scalarType != outputElementType) {
-    fillResult.result = MatchFillResult::TypeMismatch;
-    fillResult.scalarType = scalarType;
-    fillResult.outputElementType = outputElementType;
-    return fillResult;
-  }
+  if (scalarType != outputElementType)
+    return MatchFillResult::TypeMismatch;
 
-  return fillResult;
+  return MatchFillResult::Success;
 }
 
 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
-  auto [result, scalarType, outputElementType] = isFillInterfaceImpl(op);
-  if (result == MatchFillResult::NotLinalgOp)
+  MatchFillResult res = isFillInterfaceImpl(op);
+  if (res == MatchFillResult::NotLinalgOp)
     return op->emitError("expected a LinalgOp");
-  if (result == MatchFillResult::WrongNumOperands)
+  if (res == MatchFillResult::WrongNumOperands)
     return op->emitError("expected op with 1 input and 1 output");
-  if (result == MatchFillResult::NotScalarInput)
+  if (res == MatchFillResult::NotScalarInput)
     return op->emitError("expected op with scalar input");
-  if (result == MatchFillResult::TypeMismatch)
+  if (res == MatchFillResult::TypeMismatch) {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+    Type outputElementType =
+        getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
     return op->emitOpError("expected fill value type (")
            << scalarType << ") to match output element type ("
            << outputElementType << ")";
+  }
 
   return success();
 }



More information about the Mlir-commits mailing list