[Mlir-commits] [mlir] 0bd2f12 - [mlir][linalg] Restrict fill initial value type to output element type (#169567)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 06:51:42 PST 2025


Author: Jakub Kuderski
Date: 2025-11-30T09:51:37-05:00
New Revision: 0bd2f12753604cd072ae0935820ba9a23bb17ccc

URL: https://github.com/llvm/llvm-project/commit/0bd2f12753604cd072ae0935820ba9a23bb17ccc
DIFF: https://github.com/llvm/llvm-project/commit/0bd2f12753604cd072ae0935820ba9a23bb17ccc.diff

LOG: [mlir][linalg] Restrict fill initial value type to output element type (#169567)

Disallow implicit casting, which is surprising and, in my experience, usually
indicative of copy-paste errors.

Because the initial value must be a scalar, I don't expect this to
affect any data movement.
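
A minimal sketch of the new constraint, mirroring the test updates in this
patch: the fill value must now have exactly the output element type, so a
mixed-type fill is rejected by the verifier. The explicit `arith.truncf`
adaptation at the end is a hypothetical way callers might update their IR,
not something introduced by this patch.

  // Rejected after this change: value type does not match the output element type.
  //   %0 = linalg.fill ins(%cst_f64 : f64) outs(%t : tensor<2xf32>) -> tensor<2xf32>

  // Accepted: the fill value already has the output element type.
  %cst = arith.constant 0.0 : f32
  %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<2xf32>) -> tensor<2xf32>

  // Hypothetical adaptation for code that relied on the implicit cast:
  // cast the scalar explicitly before the fill.
  %cst_f64 = arith.constant 0.0 : f64
  %cst_f32 = arith.truncf %cst_f64 : f64 to f32
  %1 = linalg.fill ins(%cst_f32 : f32) outs(%t : tensor<2xf32>) -> tensor<2xf32>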

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Affine/value-bounds-reification.mlir
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
    mlir/test/Integration/Dialect/Transform/match_matmul.mlir
    mlir/test/python/integration/dialects/linalg/opsrun.py

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index b892bbe427a18..37604fc17dd9b 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:
 
 ```python
 @linalg_structured_op
-def fill(value=ScalarDef(T1),
-         O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, value)
+def fill(value=ScalarDef(T),
+         O=TensorDef(T, output=True)):
+  O[None] = value
 ```
 
-The operation sets the elements of the output tensor `O` to `value`. All
-operands are either scalars or rank zero tensors that are accessed using the
-index `None`. The operation thus performs a scalar computation that trivially
-extends to a multi-dimensional pointwise computation. As a result, we may use
-`fill` with arbitrary ranked output tensors:
+The operation sets the elements of the output tensor `O` to `value`. The value
+type must match the element type of the output tensor. All operands are either
+scalars or rank zero tensors that are accessed using the index `None`. The
+operation thus performs a scalar computation that trivially extends to a
+multi-dimensional pointwise computation. As a result, we may use `fill` with
+arbitrary ranked output tensors:
 
 ```python
 tensor_2d = tensor.EmptyOp([4, 8], f32)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9aae1b850c3a0..521afc991063f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
   doc: |-
     Fills the output tensor with the given value.
 
-    Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    Works for arbitrary ranked output tensors since the operation performs
+    scalar accesses only and is thus rank polymorphic. The value operand
+    type must match the element type of the output.
   implements:
   - LinalgFillOpInterface
   defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: value
     kind: scalar
-    type_var: T1
+    type_var: T
   - !LinalgOperandDefConfig
     name: O
     kind: output_tensor
-    type_var: U
+    type_var: T
     shape_map: affine_map<() -> ()>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      scalar_fn:
-        kind: type
-        fn_name: cast_signed
-        type_var: U
-        operands:
-        - !ScalarExpression
-          scalar_arg: value
+      scalar_arg: value
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9e997ea..b4b1347493529 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
 
+namespace {
 enum class MatchFillResult {
   Success = 0,
   NotLinalgOp,
   WrongNumOperands,
-  NotScalarInput
+  NotScalarInput,
+  TypeMismatch
 };
+} // namespace
 
 static MatchFillResult isFillInterfaceImpl(Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
   if (!linalgOp.isScalar(value))
     return MatchFillResult::NotScalarInput;
 
+  // Check that the scalar input type matches the output element type.
+  OpOperand *output = linalgOp.getDpsInitOperand(0);
+  Type scalarType = value->get().getType();
+  Type outputElementType = getElementTypeOrSelf(output->get().getType());
+  if (scalarType != outputElementType)
+    return MatchFillResult::TypeMismatch;
+
   return MatchFillResult::Success;
 }
 
 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
-  auto res = isFillInterfaceImpl(op);
+  MatchFillResult res = isFillInterfaceImpl(op);
   if (res == MatchFillResult::NotLinalgOp)
     return op->emitError("expected a LinalgOp");
   if (res == MatchFillResult::WrongNumOperands)
     return op->emitError("expected op with 1 input and 1 output");
   if (res == MatchFillResult::NotScalarInput)
     return op->emitError("expected op with scalar input");
+  if (res == MatchFillResult::TypeMismatch) {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+    Type outputElementType =
+        getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+    return op->emitOpError("expected fill value type (")
+           << scalarType << ") to match output element type ("
+           << outputElementType << ")";
+  }
 
   return success();
 }

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fd4a5a848f1e3..9c24f94fcf612 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(
 
 
 @linalg_structured_op
-def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
     """Fills the output tensor with the given value.
 
     Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    accesses only and is thus rank polymorphic. The value type must match the
+    element type of the output tensor or memref.
     """
     implements(FillOpInterface)
     defines(Canonicalizer)
-    O[None] = TypeFn.cast_signed(U, value)
+    O[None] = value
 
 
 @linalg_structured_op

diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 817614be50533..2e801028057a1 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
 //       CHECK:   "test.some_use"(%[[c5]])
 //       CHECK:   %[[c5:.*]] = arith.constant 5 : index
 //       CHECK:   "test.some_use"(%[[c5]])
-func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
+func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
   scf.for %iv = %c0 to %ub step %c4 {
     %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
     %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
-    %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+    %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
 
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..6f1a422324e08 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @fold_fill_generic_different_dtype
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
-//   CHECK-NOT: linalg.fill
-//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
-//  CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
-  %1 = tensor.empty(%0) : tensor<?xf16>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
-  %3 = tensor.empty(%0) : tensor<?xf16>
-  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
-  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
-    %5 = arith.addf  %arg1, %arg2 : f16
-        linalg.yield %5 : f16
-  } -> tensor<?xf16>
-  return %4 : tensor<?xf16>
-}
-
-// -----
-
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 //   CHECK-NOT: linalg.fill
 //       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
 // CHECK-NOT:     linalg.generic
 // CHECK:         tensor.expand_shape
 // CHECK:         linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)

diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 290c6c7c36f76..4526dc90fad2e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
 
 // -----
 
-func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
-  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
   return %0: tensor<f32>
 }
 
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
-  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
+func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
+  linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
   return
 }
 

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fabc8e610612d..1f554e6c45da7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
 
 // -----
 
+func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
+  %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
+  %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
 {
   // expected-error @+1 {{expected op with scalar input}}

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 8fa32d7aeb586..bbda8d4e99d04 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -27,8 +27,8 @@ func.func @main() {
   %A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
   %B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
 
-  %c0_i32 = arith.constant  0 : i32
-  %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %c0_f32 = arith.constant 0.0 : f32
+  %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   %res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
             outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>

diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index a374d9a611258..e3fee917cdeaa 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
 }
 
 func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
-  %cst = arith.constant 0.0 : f64
+  %cst = arith.constant 0.0 : f32
   %empty = tensor.empty() : tensor<10x15xf32>
 
   // expected-remark @below {{fill}}
-  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
 
   %real_lhs = linalg.mul
     ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>

diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 8f202318146ee..8eff573f98ad3 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,13 +25,13 @@ def log(*args):
   %O1 = memref.alloc() : memref<16xi32>
   %O2 = memref.alloc() : memref<4x16xi32>
 
-  %val0 = arith.constant 1.0 : f32
-  %val1 = arith.constant 2.0 : f32
-  %val2 = arith.constant 3.0 : f32
+  %val0 = arith.constant 1 : i32
+  %val1 = arith.constant 2 : i32
+  %val2 = arith.constant 3 : i32
 
-  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
-  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
-  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+  call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
+  call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
+  call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
 def test_fill_builtin():
     with Context() as ctx, Location.unknown():
         module = Module.create()
-        f32 = F32Type.get()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
             def fill_0d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
             def fill_1d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
             def fill_2d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
@@ -184,19 +183,18 @@ def fill_2d_on_buffers(value, out):
 def test_fill_generic():
     with Context() as ctx, Location.unknown():
         module = Module.create()
-        f32 = F32Type.get()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
             def fill_0d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
             def fill_1d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
             def fill_2d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 


More information about the Mlir-commits mailing list