[Mlir-commits] [mlir] 7294be2 - [mlir][linalg] Replace linalg.fill by OpDSL variant.

llvmlistbot at llvm.org
Mon Mar 14 03:52:11 PDT 2022


Author: gysit
Date: 2022-03-14T10:51:08Z
New Revision: 7294be2b8e9ada57e051bac6a20ca8b4dbf46475

URL: https://github.com/llvm/llvm-project/commit/7294be2b8e9ada57e051bac6a20ca8b4dbf46475
DIFF: https://github.com/llvm/llvm-project/commit/7294be2b8e9ada57e051bac6a20ca8b4dbf46475.diff

LOG: [mlir][linalg] Replace linalg.fill by OpDSL variant.

The revision removes the handwritten linalg.fill operation and renames the OpDSL-generated linalg.fill_tensor operation to linalg.fill as its replacement. After the change, all named structured operations are defined via OpDSL and no handwritten operations are left.

A side effect of the change is that the pretty-printed form changes from:
```
%1 = linalg.fill(%cst, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
```
to:
```
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
```
Additionally, the builder signature now takes input and output value ranges, as is the case for all other OpDSL operations:
```
rewriter.create<linalg::FillOp>(loc, val, output)
```
changes to
```
rewriter.create<linalg::FillOp>(loc, ValueRange{val}, ValueRange{output})
```
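For downstream code the migration is mechanical: wrap the scalar and the output in `ValueRange`s and read the tensor result via `result()`. A minimal sketch under those assumptions (the helper name `createZeroFill` is hypothetical and not part of this revision; `PatternRewriter`, `arith::ConstantOp`, and `getElementTypeOrSelf` are the existing MLIR APIs used throughout the diff below):
```
// Hypothetical helper sketching the new builder signature; assumes the
// usual rewrite-pattern context with a rewriter and a location in scope.
static Value createZeroFill(PatternRewriter &rewriter, Location loc,
                            Value init) {
  // Materialize the scalar value to fill with.
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(getElementTypeOrSelf(init.getType())));
  // OpDSL-style builder: inputs and outputs are passed as ValueRanges.
  auto fill = rewriter.create<linalg::FillOp>(loc, ValueRange{zero},
                                              ValueRange{init});
  // When filling a tensor, the filled tensor is returned via result().
  return fill.result();
}
```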
All other changes are minimal. In particular, the canonicalization patterns are unchanged, and the `value()`, `output()`, and `result()` methods are now implemented by the FillOpInterface.
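As a sketch of what stays source-compatible (the value `v` is an assumed local, not from this revision):
```
// Patterns that inspect a fill keep compiling: value(), output(), and
// result() are now provided by the FillOpInterface.
if (auto fillOp = v.getDefiningOp<linalg::FillOp>()) {
  Value scalar = fillOp.value();   // scalar operand being splatted
  Value dest   = fillOp.output();  // output tensor or memref
  Value filled = fillOp.result();  // tensor result (when filling tensors)
}
```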

Depends On D120726

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120728

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
    mlir/test/Dialect/Linalg/bufferize.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/codegen-strategy.mlir
    mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/Dialect/Linalg/fusion-sequence.mlir
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
    mlir/test/Dialect/Linalg/hoist-padding.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/library-calls.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/Dialect/Linalg/pad.mlir
    mlir/test/Dialect/Linalg/pad_fusion.mlir
    mlir/test/Dialect/Linalg/promotion_options.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
    mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
    mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir
    mlir/test/Dialect/Linalg/tile.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
    mlir/test/Dialect/SparseTensor/sparse_1d.mlir
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
    mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
    mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
    mlir/test/mlir-cpu-runner/async.mlir
    mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
    mlir/test/mlir-cpu-runner/unranked-memref.mlir
    mlir/test/mlir-cpu-runner/utils.mlir
    mlir/test/mlir-opt/async.mlir
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/integration/dialects/linalg/opsrun.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index d249a8f3b9d32..5e2c6eab356cf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2908,8 +2908,8 @@ structured_op: !LinalgStructuredOpConfig
               scalar_arg: I
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: fill_tensor
-  cpp_class_name: FillTensorOp
+  name: fill
+  cpp_class_name: FillOp
   doc: |-
     Fills the output tensor with the given value.
 

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c6a074f190069..d9800365629cd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -43,71 +43,6 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
   }];
 }
 
-class LinalgStructured_Op<string mnemonic, list<Trait> props>
-  : LinalgStructuredBase_Op<mnemonic, !listconcat(props, [])> {
-  code structuredOpsDecls = structuredOpsBaseDecls # [{
-    std::string getLibraryCallName() {
-      return generateLibraryCallName(getOperation());
-    }
-  }];
-  let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
-}
-
-//===----------------------------------------------------------------------===//
-// Named Linalg ops, implemented as special configurations of generic ops.
-//===----------------------------------------------------------------------===//
-
-def FillOp : LinalgStructured_Op<"fill", []> {
-  let arguments = (ins
-    AnyTypeOf<[AnyComplex, AnyFloat, AnySignlessInteger, AnyVector]>:$value,
-    AnyShaped:$output);
-  let results = (outs Optional<AnyRankedTensor>:$result);
-  let regions = (region AnyRegion:$region);
-  let extraClassDeclaration = structuredOpsDecls # [{
-    ValueRange inputs() { return getOperands().take_front(); }
-    ValueRange outputs() { return getOperands().take_back(); }
-
-    // Rank-polymorphic.
-    //   filling_value -> O(ivs) with parallel iterators.
-    ArrayAttr iterator_types() {
-      int64_t nPar = getRank(getOutputOperand(0));
-      return Builder(getContext()).getStrArrayAttr(
-        SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
-    }
-
-    ArrayAttr indexing_maps() {
-      MLIRContext *context = getContext();
-      // filling_value -> O(ivs)
-      return Builder(getContext()).getAffineMapArrayAttr({
-          AffineMap::get(getNumParallelLoops(), 0, {}, getContext()),
-          extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
-    }
-
-    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                              ArrayRef<NamedAttribute> attrs);
-    static std::function<void(ImplicitLocOpBuilder&,
-                              Block&, ArrayRef<NamedAttribute>)>
-    getRegionBuilder() {
-      return &regionBuilder;
-    }
-    static unsigned getNumRegionArgs() { return 2; }
-  }];
-
-  let assemblyFormat = [{
-    `(` $value `,` $output `)` attr-dict `:`
-        type($value) `,` type($output) (`->` type($result)^)?
-      custom<FillOpRegion>($region, ref(type($value)), ref(type($output)))
-  }];
-
-  let builders = [
-    OpBuilder<(ins "Value":$value, "Value":$output)>
-  ];
-
-  let hasFolder = 1;
-  let hasCanonicalizer = 1;
-  let hasVerifier = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // Generic Linalg ops.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d4d38bdbd9d30..3fd69c5729f1c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -856,8 +856,10 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
         op, "No initial value found for reduction operation");
 
   auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
-  auto filledTensor =
-      rewriter.create<linalg::FillOp>(loc, fillValue, initTensor).result();
+  auto filledTensor = rewriter
+                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
+                                                  ValueRange{initTensor})
+                          .result();
 
   SmallVector<AffineExpr, 2> srcExprs;
   SmallVector<AffineExpr, 2> dstExprs;
@@ -1717,7 +1719,9 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
     Value result =
-        rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{zeroVal}, ValueRange{init})
+            .result();
 
     auto toOpFoldResult = [](Value v) -> OpFoldResult {
       auto op = v.getDefiningOp<arith::ConstantIndexOp>();
@@ -1989,7 +1993,9 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
     auto fillValueIdx = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getIntegerAttr(outElementTy, 0));
     auto filledTensorIdx =
-        rewriter.create<linalg::FillOp>(loc, fillValueIdx, initTensorIdx)
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
+                                    ValueRange{initTensorIdx})
             .result();
 
     // Second fill the output buffer for the running max.
@@ -2007,7 +2013,9 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
     auto fillValueMax =
         rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
     auto filledTensorMax =
-        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
+                                    ValueRange{initTensorMax})
             .result();
 
     // We need to reduce along the arg-max axis, with parallel operations along

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index f2a5ffaf3e082..f5fe8f8389308 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -214,8 +214,10 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, filteredDims, resultTy.getShape(), resultETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor =
-        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
+    Value zeroTensor = rewriter
+                           .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                   ValueRange{initTensor})
+                           .result();
 
     // Extract the attributes for convolution.
     llvm::SmallVector<int64_t> stride, dilation;
@@ -401,8 +403,10 @@ class DepthwiseConvConverter
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, dynamicDims, linalgConvTy.getShape(), resultETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor =
-        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
+    Value zeroTensor = rewriter
+                           .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                   ValueRange{initTensor})
+                           .result();
 
     Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
         loc, dynamicDims, resultTy.getShape(), resultETy);
@@ -493,8 +497,10 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
     Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
-    Value zeroTensor =
-        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
+    Value zeroTensor = rewriter
+                           .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                   ValueRange{initTensor})
+                           .result();
     if (!op.quantization_info()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
@@ -567,8 +573,10 @@ class FullyConnectedConverter
     // When quantized, the input element type is not the same as the output
     Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor =
-        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
+    Value zeroTensor = rewriter
+                           .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                   ValueRange{initTensor})
+                           .result();
 
     SmallVector<int64_t> permutation{1, 0};
     auto permutationAttr = DenseIntElementsAttr::get(
@@ -700,7 +708,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
         loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());
 
     Value filledInitTensor =
-        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{initialValue},
+                                    ValueRange{initTensor})
+            .result();
 
     Value fakeWindowDims =
         rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);
@@ -759,7 +770,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
         loc, dynamicDims, accTy.getShape(), accETy);
 
     Value filledInitTensor =
-        rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{initialValue},
+                                    ValueRange{poolInitTensor})
             .result();
 
     Value fakeWindowDims =

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index becf321456248..331a8b91bd330 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -303,47 +303,6 @@ class RegionBuilderHelper {
 //===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
-void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                           ArrayRef<NamedAttribute> attrs) {
-  assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
-  b.create<linalg::YieldOp>(block.getArgument(0));
-}
-
-void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
-                   Value output) {
-  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
-        output);
-  fillStructuredOpRegion<FillOp>(
-      builder, *result.regions.front(), TypeRange{value.getType()},
-      TypeRange{output.getType()}, result.attributes.getAttrs(), {});
-}
-
-ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
-                              Type outputType) {
-  OpBuilder opBuilder(parser.getContext());
-  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
-                                 TypeRange{outputType}, {});
-  return success();
-}
-
-/// FillOp region is elided when printing.
-void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
-
-LogicalResult FillOp::verify() {
-  OpOperand *output = getOutputOperand(0);
-  Type fillType = value().getType();
-  if (getElementTypeOrSelf(output->get()) != fillType)
-    return emitOpError("expects fill type to match view elemental type");
-  return success();
-}
-
-void FillOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  if (output().getType().isa<MemRefType>())
-    effects.emplace_back(MemoryEffects::Write::get(), output(),
-                         SideEffects::DefaultResource::get());
-}
 
 namespace {
 
@@ -364,7 +323,8 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
     auto newInit = rewriter.create<TensorReshapeOp>(
         loc, reshapeOp.getResultType(), oldFill.output(),
         reshapeOp.reassociation());
-    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit);
+    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
+                                        ValueRange{newInit});
 
     return success();
   }
@@ -400,8 +360,8 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
     auto newInitOp = rewriter.create<InitTensorOp>(
         padOp.getLoc(), reifiedShape.front(), staticShape,
         oldResultType.getElementType());
-    auto newFillOp =
-        rewriter.create<FillOp>(fillOp.getLoc(), padValue, newInitOp);
+    auto newFillOp = rewriter.create<FillOp>(
+        fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp});
     rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
                                                 newFillOp.result());
 
@@ -517,10 +477,6 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
            FoldInsertPadIntoFill>(context);
 }
 
-// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
-void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                               MLIRContext *context) {}
-
 //===----------------------------------------------------------------------===//
 // GenericOps
 //===----------------------------------------------------------------------===//
@@ -877,6 +833,11 @@ void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
 }
 
+LogicalResult GenericOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
@@ -1812,15 +1773,6 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
 
 } // namespace
 
-#define LINALGOP_FOLDERS(XXX)                                                  \
-  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
-                          SmallVectorImpl<OpFoldResult> &) {                   \
-    return foldMemRefCast(*this);                                              \
-  }
-
-LINALGOP_FOLDERS(FillOp)
-LINALGOP_FOLDERS(GenericOp)
-
 // All named ops canonicalizers and folders are auto-generated in the
 // .cpp.inc.
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c031c0f57b988..ffdd142716fa9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -783,8 +783,10 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
       loc, resultShapedType.getShape(), resultShapedType.getElementType());
 
   // Initialize tensor with the pad value
-  Value tmpTensor =
-      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
+  Value tmpTensor = rewriter
+                        .create<linalg::FillOp>(loc, ValueRange{padValue},
+                                                ValueRange{initTensor})
+                        .result();
 
   // Copy original contents into new tensor
   // Uses linalg.generic, but could be done with tensor.insert_slice

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index ce76ef289348e..f494af5d14c16 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -334,7 +334,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
   }
   Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
   Value zero = constantZero(rewriter, loc, elemTp);
-  rewriter.create<linalg::FillOp>(loc, zero, mem);
+  rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
   return mem;
 }
 
@@ -749,10 +749,12 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
     // introduces an O(N) operation into the computation, but this reset
     // operation is amortized over the innermost loops for the access
     // pattern expansion.
-    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, eltType),
-                                    values);
-    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, boolType),
-                                    filled);
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, eltType)},
+        ValueRange{values});
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, boolType)},
+        ValueRange{filled});
     // Replace expansion op with these buffers and initial index.
     assert(op.getNumResults() == 4);
     rewriter.replaceOp(op, {values, filled, indices, zero});

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e4faddd9b7e8a..da9a2f53934d7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -466,7 +466,7 @@ static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
   if (isMaterializing(tensor)) {
     Value zero = constantZero(rewriter, loc, denseTp.getElementType());
-    rewriter.create<linalg::FillOp>(loc, zero, alloc);
+    rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc});
   } else {
     Value init =
         rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 48470f7b059d5..c457621ec61bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -262,7 +262,8 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
         b.create<scf::YieldOp>(loc, viewAndIndices);
       },
       [&](OpBuilder &b, Location loc) {
-        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
+        b.create<linalg::FillOp>(loc, ValueRange{xferOp.padding()},
+                                 ValueRange{alloc});
         // Take partial subview of memref which guarantees no dimension
         // overflows.
         IRRewriter rewriter(b);

diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index 167a9232d1361..e3fb460552b0e 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -20,22 +20,6 @@ def isa(cls: Type, ty: Type):
     return False
 
 
-class FillOp:
-  """Extends the linalg.fill op."""
-
-  def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
-    results = []
-    if isa(RankedTensorType, output.type):
-      results = [output.type]
-    op = self.build_generic(
-        results=results,
-        operands=[_get_op_result_or_value(o) for o in [value, output]],
-        attributes=None,
-        loc=loc,
-        ip=ip)
-    OpView.__init__(self, op)
-    fill_builtin_region(self.operation)
-
 class InitTensorOp:
   """Extends the linalg.init_tensor op."""
 

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5774cbc6c7677..2c6291bad85f4 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -715,7 +715,7 @@ def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
 
 
 @linalg_structured_op
-def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   """Fills the output tensor with the given value.
 
   Works for arbitrary ranked output tensors since the operation performs scalar

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index ce7d15596b737..6d568183494b6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -4,7 +4,7 @@
 func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
   // CHECK: [[C0:%.+]] = arith.constant 0
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
-  // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32>
+  // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>)  -> (tensor<1x5x6xf32>)
   return %0 : tensor<1x5x6xf32>
@@ -17,7 +17,7 @@ func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x
 func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) {
   // CHECK: [[C0:%.+]] = arith.constant 0
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
-  // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : i32, tensor<1x5x6xi32> -> tensor<1x5x6xi32>
+  // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i32) outs([[INIT]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
   // CHECK: [[ONE:%.+]] = arith.constant 1
   // CHECK: [[TWO:%.+]] = arith.constant 2
   // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
@@ -33,7 +33,7 @@ func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>) -> (t
   // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
   // CHECK: %[[C0_0:.+]] = arith.constant 0
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM]], 5, 6]
-  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0_0]], %[[INIT]]) : f32, tensor<?x5x6xf32> -> tensor<?x5x6xf32>
+  // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
   %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<?x5x3xf32>, tensor<?x3x6xf32>)  -> (tensor<?x5x6xf32>)
   return %0 : tensor<?x5x6xf32>
@@ -47,7 +47,7 @@ func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x?xf
   // CHECK: %[[DIM:.+]] = tensor.dim %arg1, %[[C2]]
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, %[[DIM]]]
-  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x5x?xf32> -> tensor<1x5x?xf32>
+  // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
   %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x?xf32>)  -> (tensor<1x5x?xf32>)
   return %0 : tensor<1x5x?xf32>
@@ -59,7 +59,7 @@ func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x?xf
 func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 5, 6]
-  // CHECK: %[[FILLED:.+]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x5x6xf32> -> tensor<1x5x6xf32>
+  // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x?xf32>, tensor<1x?x6xf32>)  -> (tensor<1x5x6xf32>)
   return %0 : tensor<1x5x6xf32>
@@ -74,7 +74,7 @@ func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf
 func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
   // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6]
   // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
   // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
   // CHECK: [[TRANSPOSE:%.+]] = "tosa.transpose"(%arg1, [[PERM]])
   // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
@@ -97,7 +97,7 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: ten
 func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
   // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6]
   // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
   // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
   // CHECK: [[TRANSPOSE:%.+]] = "tosa.transpose"(%arg1, [[PERM]])
   // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
@@ -123,7 +123,7 @@ func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2:
   // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
   // CHECK: %[[INITT:.+]] = linalg.init_tensor [%[[DIM]], 6]
   // CHECK: %[[ZERO:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[ZERO]], %[[INITT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
   // CHECK: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERM]])
   // CHECK: %[[INITB:.+]] = linalg.init_tensor [%[[DIM]], 6]
@@ -143,7 +143,7 @@ func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2:
 func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
   // CHECK-DAG: [[CONST:%.+]] = arith.constant -3.40282347E+38
   // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62]
-  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[CONST]], [[INIT]])
+  // CHECK-DAG: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[INIT]]
   // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
   // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>)
   %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x4x32x62xf32>)
@@ -157,7 +157,7 @@ func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
   // CHECK-DAG:   tensor.yield [[CONST]]
   // CHECK-DAG: [[INITVAL:%.+]] = arith.constant -3.40282347E+38 : f32
   // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62]
-  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INITVAL]], [[INIT]])
+  // CHECK-DAG: [[FILL:%.+]] = linalg.fill ins([[INITVAL]]{{.*}}outs([[INIT]]
   // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
   // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>)
   %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x4x33x62xf32>)
@@ -170,7 +170,7 @@ func @max_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> () {
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
   // CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62]
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CONST]], %[[INIT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CONST]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3]
   // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x6x34x62xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x4x32x62xf32>)
   %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<?x6x34x62xf32>)  -> (tensor<?x4x32x62xf32>)
@@ -209,7 +209,7 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
   // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: [[CONST:%.+]] = arith.constant 0
   // CHECK: [[POOLINIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[POOLINIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]]
   // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
   // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
@@ -474,7 +474,7 @@ func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>,
 func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11]
   // CHECK: [[CST0:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
   // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
@@ -520,7 +520,7 @@ func @depthwise_conv_dyn(%arg0 : tensor<?x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf3
 func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11]
   // CHECK: [[CST0:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
   // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
@@ -546,7 +546,7 @@ func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x12
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 12, 12, 4, 128]
   // CHECK: [[CST0:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512]
   // CHECK: [[C128:%.+]] = arith.constant -128
   // CHECK: [[C42:%.+]] = arith.constant 42
@@ -570,7 +570,7 @@ func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x12
 func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 4, 128]
   // CHECK: [[CST0:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512]
   // CHECK: [[C128:%.+]] = arith.constant -128
   // CHECK: [[C42:%.+]] = arith.constant 42

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f375a930604ba..fa2f7b82af243 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -694,7 +694,7 @@ func @test_transpose_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
 func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
   // CHECK: [[CST0:%.+]] = arith.constant 0.0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   [[RES:%.+]] = arith.addf %arg1, %arg2 : f32
@@ -704,7 +704,7 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
   // CHECK: [[CST0:%.+]] = arith.constant 0.0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   [[RES:%.+]] = arith.addf %arg1, %arg2 : f32
@@ -745,7 +745,7 @@ func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]]
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 4]
   // CHECK: %[[CST0:.+]] = arith.constant 0.0
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST0]], %[[INIT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x5x4xf32>) outs(%[[FILL]] : tensor<?x4xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
@@ -767,7 +767,7 @@ func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C1]]
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, %[[DYN]]]
   // CHECK: %[[CST1:.+]] = arith.constant 1.0
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST1]], %[[INIT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<5x?x4xf32>) outs(%[[FILL]] : tensor<5x?xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
@@ -789,7 +789,7 @@ func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]]
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]]
   // CHECK: %[[CMIN:.+]] = arith.constant -3.40282347E+38
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CMIN]], %[[INIT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32
@@ -811,7 +811,7 @@ func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
 func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
   // CHECK: [[CST0:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>)
   // CHECK: ^bb0(%arg1: i32, %arg2: i32)
   // CHECK:   [[RES:%.+]] = arith.addi %arg1, %arg2 : i32
@@ -821,7 +821,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
   // CHECK: [[CST0:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>)
   // CHECK: ^bb0(%arg1: i32, %arg2: i32)
   // CHECK:   [[RES:%.+]] = arith.addi %arg1, %arg2 : i32
@@ -861,7 +861,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
 func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
   // CHECK: [[CST0:%.+]] = arith.constant true
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<4xi1>)
   // CHECK: ^bb0(%arg1: i1, %arg2: i1)
   // CHECK:   [[RES:%.+]] = arith.andi %arg1, %arg2 : i1
@@ -889,7 +889,7 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
   // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
   // CHECK: [[CST:%.+]] = arith.constant 0.0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]]
   // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
   // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
@@ -901,7 +901,7 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
   // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
   // CHECK: [[CST:%.+]] = arith.constant 0.0
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST]]{{.*}}outs([[INIT]]
   // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
   // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1]
   %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
@@ -922,7 +922,7 @@ func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> ()
   // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]]
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]]
   // CHECK: %[[CST:.+]] = arith.constant 0.0
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1]
   // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
@@ -943,7 +943,7 @@ func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
   // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3]
   // CHECK: %[[CST:.+]] = arith.constant 0.0
-  // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]]
   // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1]
   // CHECK: %[[SUM:.+]]  = arith.addi %[[OFFSET]], %[[DYN1]]
@@ -1330,10 +1330,10 @@ func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
 func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
   // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [2]
   // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32
-  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_MIN]], [[IDX_INIT]])
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]]
   // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [2]
   // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648
-  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_MIN]], [[VAL_INIT]])
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]]
   // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<2xi32>, tensor<2xi32>)
   // CHECK:   [[IDX:%.+]] = linalg.index 0
   // CHECK:   [[CAST:%.+]] = arith.index_cast [[IDX]]
@@ -1345,10 +1345,10 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
 
   // CHECK: [[IDX_INIT:%.+]] = linalg.init_tensor [3]
   // CHECK: [[IDX_MIN:%.+]] = arith.constant 0 : i32
-  // CHECK: [[IDX_FILL:%.+]] = linalg.fill([[IDX_MIN]], [[IDX_INIT]])
+  // CHECK: [[IDX_FILL:%.+]] = linalg.fill ins([[IDX_MIN]]{{.*}}outs([[IDX_INIT]]
   // CHECK: [[VAL_INIT:%.+]] = linalg.init_tensor [3]
   // CHECK: [[VAL_MIN:%.+]] = arith.constant -2147483648
-  // CHECK: [[VAL_FILL:%.+]] = linalg.fill([[VAL_MIN]], [[VAL_INIT]])
+  // CHECK: [[VAL_FILL:%.+]] = linalg.fill ins([[VAL_MIN]]{{.*}}outs([[VAL_INIT]]
   // CHECK: linalg.generic {indexing_maps = [#map0, #map2, #map2], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x2xi32>) outs([[IDX_FILL]], [[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
   // CHECK:   [[IDX:%.+]] = linalg.index 1
   // CHECK:   [[CAST:%.+]] = arith.index_cast [[IDX]]
@@ -1380,10 +1380,10 @@ func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
   // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]]
   // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
   // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
-  // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+  // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]]
   // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
   // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
-  // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+  // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]]
   // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<?xi32>, tensor<?xi32>)
   // CHECK:   %[[IDX:.+]] = linalg.index 0
   // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]]
@@ -1403,10 +1403,10 @@ func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
 func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
   // CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3]
   // CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
-  // CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
+  // CHECK: %[[IDX_FILL:.+]] = linalg.fill ins(%[[IDX_MIN]]{{.*}}outs(%[[IDX_INIT]]
   // CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3]
   // CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
-  // CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
+  // CHECK: %[[VAL_FILL:.+]] = linalg.fill ins(%[[VAL_MIN]]{{.*}}outs(%[[VAL_INIT]]
   // CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
   // CHECK:   %[[IDX:.+]] = linalg.index 1
   // CHECK:   %[[CAST:.+]] = arith.index_cast %[[IDX]]

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
index a806835c2848d..fc0ac2d384626 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -148,9 +148,9 @@ func @unknown_op_may_read(%v: vector<5xf32>)
   // CHECK: %[[m1_casted:.*]] = memref.cast %[[m1]] : memref<10xf32> to memref<10xf32, #[[$MAP3]]>
   %t1 = linalg.init_tensor [10] : tensor<10xf32>
 
-  // CHECK: linalg.fill(%{{.*}}, %[[m1]])
+  // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[m1]]
   // CHECK: %[[filled_tensor:.*]] = bufferization.to_tensor %[[m1_casted]]
-  %filled = linalg.fill(%cst, %t1) : f32, tensor<10xf32> -> tensor<10xf32>
+  %filled = linalg.fill ins(%cst : f32) outs(%t1 : tensor<10xf32>) -> tensor<10xf32>
 
   // The transfer_write is out-of-place because "dummy_op" may read.
   // CHECK: memref.copy %[[m1]], %[[alloc]]

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e6d4f92b0a0f5..a027112ee05c0 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -167,10 +167,10 @@ func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>,
 func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0.0 : f32
   // CHECK: %[[ALLOC:.*]] = memref.alloc
-  // CHECK: linalg.fill(%cst, %[[ALLOC]]) : f32, memref<?xf32>
+  // CHECK: linalg.fill ins(%cst : f32) outs(%[[ALLOC]] : memref<?xf32>)
   // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<?xf32>
   // CHECK: return %[[TENSOR]]
-  %0 = linalg.fill(%c0, %arg0) : f32, tensor<?xf32> -> tensor<?xf32>
+  %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 0dd435c01d818..28655b96e5df6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -275,7 +275,7 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
   %c0_i32 = arith.constant 0 : i32
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%c0_i32, %arg0) : i32, tensor<7x7xi32> -> tensor<7x7xi32>
+  %0 = linalg.fill ins(%c0_i32 : i32) outs(%arg0 : tensor<7x7xi32>) -> tensor<7x7xi32>
   %1 = linalg.matmul ins(%arg1, %arg1: tensor<7x7xf32>, tensor<7x7xf32>)
                      outs(%arg1: tensor<7x7xf32>) -> tensor<7x7xf32>
   %2 = linalg.generic #trait outs(%arg0 : tensor<7x7xi32>) {
@@ -298,7 +298,7 @@ func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
   %c21 = arith.constant 21 : index
   %c42 = arith.constant 42 : index
   %0 = linalg.init_tensor [%c21, %c42] : tensor<?x?xf32>
-  %1 = linalg.fill(%arg1, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = linalg.fill ins(%arg1 : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
   %2 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   %3 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
   %4 = tensor.insert_slice %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
@@ -306,7 +306,7 @@ func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
 }
 // CHECK-LABEL: func @propogate_casts
 //       CHECK:   %[[INIT:.+]] = linalg.init_tensor [21, 42]
-//       CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
 //       CHECK:   %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]]
 //       CHECK:   %[[RESULT:.+]] = tensor.cast %[[INSERTED]]
 //       CHECK:   return %[[RESULT]]
@@ -330,8 +330,8 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32>
   %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
-  // CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<6x4xf32> -> tensor<6x4xf32>
-  %fill = linalg.fill(%zero, %init) : f32, tensor<1x2x3x4xf32> -> tensor<1x2x3x4xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
   %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
       : tensor<1x2x3x4xf32> into tensor<6x4xf32>
   // CHECK: return %[[FILL]] : tensor<6x4xf32>
@@ -345,8 +345,8 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
 func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
   %zero = arith.constant 0.0 : f32
   // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-  %0 = linalg.fill(%zero, %arg0) : f32, tensor<?x?x?x?x?xf32> -> tensor<?x?x?x?x?xf32>
-  // CHECK: %[[RESULT:.+]] = linalg.fill(%{{.+}}, %[[RESHAPE]])
+  %0 = linalg.fill ins(%zero : f32) outs(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  // CHECK: %[[RESULT:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[RESHAPE]]
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
   // CHECK: return %[[RESULT]]
@@ -395,15 +395,15 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
 // CHECK: func @fold_self_copy
 func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
-  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, 
-                                   affine_map<(d0, d1) -> (d0, d1)>], 
-                  iterator_types = ["parallel", "parallel"]} 
-    ins(%0 : memref<4x16xf32>) 
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<4x16xf32>)
     outs(%0 : memref<4x16xf32>) {
       ^bb0(%arg4: f32, %arg5: f32):
         linalg.yield %arg4 : f32
     }
-  return 
+  return
 }
 
 // -----
@@ -411,12 +411,12 @@ func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-LABEL: func @fold_static_pad_fill
 //       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
-//       CHECK:   %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]]
 //       CHECK:   return %[[FILL]]
 func @fold_static_pad_fill() -> tensor<412x276xf32> {
   %f0 = arith.constant 0.0 : f32
   %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<400x273xf32>) -> tensor<400x273xf32>
   %pad = tensor.pad %fill low[4, 1] high[8, 2] {
   ^bb0(%arg1: index, %arg2: index):
     tensor.yield %f0 : f32
@@ -436,18 +436,18 @@ func @fold_static_pad_fill() -> tensor<412x276xf32> {
 
 //  CHECK-DAG:   %[[I1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
-//      CHECK:   %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32>
+//      CHECK:   %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>)
 //      CHECK:   %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
 //      CHECK:   %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
 //      CHECK:   %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
 //      CHECK:   %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
 //      CHECK:   %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]]
 //      CHECK:   return %[[FILL]]
 func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
   %f0 = arith.constant 0.0 : f32
-  %fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x?x16x32xf32>) -> tensor<8x?x16x32xf32>
   %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
   ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
     tensor.yield %f0 : f32
@@ -462,7 +462,7 @@ func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
   %f0 = arith.constant 0.0 : f32
   %f1 = arith.constant 1.0 : f32
   %init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<400x273xf32>) -> tensor<400x273xf32>
   // CHECK: tensor.pad
   %pad = tensor.pad %fill low[4, 1] high[8, 2] {
   ^bb0(%arg1: index, %arg2: index):
@@ -635,7 +635,7 @@ func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index
 //   CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
 //   CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 //       CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 384, 384]
-//       CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
+//       CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]]
 //       CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]]
 //       CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?xf32>
 //       CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x?xf32>
@@ -649,7 +649,7 @@ func @insert_pad_into_fill(%input: tensor<?x?x?xf32>, %low0: index, %low1: index
     tensor.yield %f0 : f32
   } : tensor<?x?x?xf32> to tensor<8x128x128xf32>
   %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
   %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   return %0: tensor<8x384x384xf32>
 }
@@ -670,7 +670,7 @@ func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128
     tensor.yield %f0 : f32
   } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
   %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
   %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
@@ -689,7 +689,7 @@ func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tens
     tensor.yield %f0 : f32
   } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
   %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
   %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   %1 = tensor.insert_slice %a   into %0   [0, 0, 129]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   // Range overlap with %1 at dim#3
@@ -709,7 +709,7 @@ func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tens
     tensor.yield %f0 : f32
   } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
   %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
   %0 = tensor.insert_slice %a   into %fill[0, 0, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   %1 = tensor.insert_slice %a   into %0   [0, 128, 255]    [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   // Range overlap with %0 at dim#3
@@ -729,7 +729,7 @@ func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128
     tensor.yield %f0 : f32
   } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
   %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
-  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %fill = linalg.fill ins(%f0 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
   // Overlap between %0 and %1 is fine, but overlap with %2 is not.
   // CHECK-COUNT-3: tensor.insert_slice
   %0 = tensor.insert_slice %a   into %fill[0, 0, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
@@ -752,7 +752,7 @@ func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: ten
   } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
   %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
   // Different filling value than padding value.
-  %fill = linalg.fill(%f1, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %fill = linalg.fill ins(%f1 : f32) outs(%init : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
   %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>

diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index d698d63758e4c..7937755d6a3fe 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -68,7 +68,7 @@ func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<7
   //        CHECK-FUSE:  %[[CST:.*]] = arith.constant dense<0.000000e+00>
   //        CHECK-FUSE:  vector.transfer_write %[[CST]]
   %cst = arith.constant 0.0 : f32
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<72x72xf32> -> tensor<72x72xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<72x72xf32>) -> tensor<72x72xf32>
 
   // Check the matmul is padded and vectorized despite the empty anchor op string.
   //        CHECK-FUSE:  vector.outerproduct
@@ -81,7 +81,7 @@ func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<7
 //         CHECK-DECOMP: func @conv(
 func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%cst, %arg2) : f32, tensor<8x16x15x64xf32> -> tensor<8x16x15x64xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32>
 
   // Check the conv is padded by a rank-reducing vector transfer op pair.
   //        CHECK-DECOMP:  vector.transfer_read {{.*}}: tensor<1x1x?x8xf32>, vector<1x8x8xf32>

diff --git a/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir b/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir
index e8fc0af57e560..f9b7cd0c0fe41 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir
@@ -18,9 +18,9 @@ func @fill_extract_matmul_1234(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
@@ -45,9 +45,9 @@ func @fill_extract_matmul_1243(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
@@ -71,11 +71,11 @@ func @fill_extract_matmul_1324(%arg0: tensor<518x518xf32> {linalg.buffer_layout
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -97,13 +97,13 @@ func @fill_extract_matmul_1342(%arg0: tensor<518x518xf32> {linalg.buffer_layout
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%3, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -123,11 +123,11 @@ func @fill_extract_matmul_1423(%arg0: tensor<518x518xf32> {linalg.buffer_layout
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -149,13 +149,13 @@ func @fill_extract_matmul_1432(%arg0: tensor<518x518xf32> {linalg.buffer_layout
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%3, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -176,9 +176,9 @@ func @fill_extract_matmul_2134(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
@@ -203,9 +203,9 @@ func @fill_extract_matmul_2143(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
@@ -230,11 +230,11 @@ func @fill_extract_matmul_2314(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -257,13 +257,13 @@ func @fill_extract_matmul_2341(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %4 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -284,11 +284,11 @@ func @fill_extract_matmul_2413(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -311,13 +311,13 @@ func @fill_extract_matmul_2431(
   %0 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
 
   // CHECK: {__inplace_operands_attr__ = ["none", "false"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %4 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -340,9 +340,9 @@ func @fill_extract_matmul_3124(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -367,11 +367,11 @@ func @fill_extract_matmul_3142(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -394,9 +394,9 @@ func @fill_extract_matmul_3214(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -421,11 +421,11 @@ func @fill_extract_matmul_3241(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %2[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %4 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -450,9 +450,9 @@ func @fill_extract_matmul_3412(
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -477,9 +477,9 @@ func @fill_extract_matmul_3421(
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -502,9 +502,9 @@ func @fill_extract_matmul_4123(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -529,11 +529,11 @@ func @fill_extract_matmul_4132(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%3, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -556,9 +556,9 @@ func @fill_extract_matmul_4213(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %1[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
@@ -583,11 +583,11 @@ func @fill_extract_matmul_4231(
   // CHECK: {__inplace_operands_attr__ = ["false"]}
   %4 = tensor.extract_slice %0[0, 0] [16, 256] [1, 1] : tensor<256x256xf32> to tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -612,9 +612,9 @@ func @fill_extract_matmul_4312(
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>
@@ -639,9 +639,9 @@ func @fill_extract_matmul_4321(
   // CHECK: {__inplace_operands_attr__ = ["true"]}
   %3 = tensor.extract_slice %0[0, 0] [256, 16] [1, 1] : tensor<256x256xf32> to tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %2 = linalg.fill(%cst_0, %4) : f32, tensor<16x256xf32> -> tensor<16x256xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%4 : tensor<16x256xf32>) -> tensor<16x256xf32>
   // CHECK: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %3) : f32, tensor<256x16xf32> -> tensor<256x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%3 : tensor<256x16xf32>) -> tensor<256x16xf32>
   // CHECK: {__inplace_operands_attr__ = ["true", "true", "true"]}
   %5 = linalg.matmul ins(%1, %2 : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %5 : tensor<256x256xf32>

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
index 8b4db05f8251e..ec20cd89c52e6 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
@@ -11,8 +11,8 @@ func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> t
   // CHECK-NEXT:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
   %v0 = arith.constant 0.0 : f32
 
-  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
-  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+  // CHECK-NEXT:   linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %d = linalg.fill ins(%v0 : f32) outs(%c : tensor<f32>) -> tensor<f32>
 
   // CHECK-NEXT:   linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
   %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
@@ -41,12 +41,12 @@ func @main() {
   %B = linalg.init_tensor [64] : tensor<64xf32>
   %C = linalg.init_tensor [] : tensor<f32>
 
-  // CHECK-NEXT:   linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
-  // CHECK-NEXT:   linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
-  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
-  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
-  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
-  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+  // CHECK-NEXT:   linalg.fill ins(%[[C1]] : f32) outs(%[[A]] : memref<64xf32>)
+  // CHECK-NEXT:   linalg.fill ins(%[[C2]] : f32) outs(%[[B]] : memref<64xf32>)
+  // CHECK-NEXT:   linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32>)
+  %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32>
+  %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32>
+  %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor<f32>) -> tensor<f32>
 
   // CHECK-NEXT:   call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
   %res = call @init_and_dot(%AA, %BB, %CC) :

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
index b2116e5d6d322..580e9b34b6eb8 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
@@ -16,7 +16,7 @@ func @buffer_forwarding_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = true
 
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]
-  %1 = linalg.fill(%cst, %0) : f32, tensor<?xf32> -> tensor<?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
 
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "false", "none"]
@@ -43,7 +43,7 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
 
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]
-  %1 = linalg.fill(%cst, %0) : f32, tensor<?xf32> -> tensor<?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
 
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none"]

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 2d15bffbb0d64..60ec4482e5429 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -257,7 +257,7 @@ func @read_of_matching_insert_slice_source(
 
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<?xf32> -> tensor<?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
 
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
@@ -289,7 +289,7 @@ func @read_of_matching_insert_slice_source_interleaved(
 
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %1 = linalg.fill(%cst, %0) : f32, tensor<?xf32> -> tensor<?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
 
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
@@ -301,7 +301,7 @@ func @read_of_matching_insert_slice_source_interleaved(
 
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %5 = linalg.fill(%cst, %4) : f32, tensor<?xf32> -> tensor<?xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<?xf32>) -> tensor<?xf32>
 
   %3 = vector.transfer_read %1[%idx2], %cst2 : tensor<?xf32>, vector<5xf32>
 
@@ -501,7 +501,7 @@ func @nested_extract_slice_and_insert(
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "false", "none", "none"]}
   %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
-  %FA = linalg.fill(%f0, %ssA) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
+  %FA = linalg.fill ins(%f0 : f32) outs(%ssA : tensor<4x4xf32>) -> tensor<4x4xf32>
   %rsA = tensor.insert_slice %FA into %sA[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
   %rA = tensor.insert_slice %rsA into %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
@@ -524,7 +524,7 @@ func @nested_extract_slice_and_insert(
   %sB = tensor.extract_slice %B[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   %ssB = tensor.extract_slice %sB[0, 0][4, %idx][1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
   %sssB = tensor.extract_slice %ssB[0, 0][4, 4][1, 1] : tensor<4x?xf32> to tensor<4x4xf32>
-  %FB = linalg.fill(%f0, %sssB) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
+  %FB = linalg.fill ins(%f0 : f32) outs(%sssB : tensor<4x4xf32>) -> tensor<4x4xf32>
   %rssB = tensor.insert_slice %FB into %ssB[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<4x?xf32>
   %rsB = tensor.insert_slice %rssB into %sB[0, 0][4, %idx][1, 1] : tensor<4x?xf32> into tensor<?x?xf32>
   %rB = tensor.insert_slice %rsB into %B[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
@@ -547,7 +547,7 @@ func @nested_extract_slice_and_insert(
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
   %sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   %ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
-  %FC = linalg.fill(%f0, %ssC) : f32, tensor<?x4xf32> -> tensor<?x4xf32>
+  %FC = linalg.fill ins(%f0 : f32) outs(%ssC : tensor<?x4xf32>) -> tensor<?x4xf32>
   %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
   %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
@@ -689,12 +689,12 @@ func @dependence_through_call(%I : tensor<64xf32> {linalg.inplaceable = true}) {
   // cannot bufferize inplace.
   //     CHECK: fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]}
-  %A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32>
+  %A = linalg.fill ins(%f1 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32>
 
   // 1. Bufferizes inplace: no alias to %A is yet possible.
   //     CHECK: fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32>
+  %B = linalg.fill ins(%f2 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32>
 
   call @foo(%A) : (tensor<64xf32>) -> ()
   call @foo(%B) : (tensor<64xf32>) -> ()
@@ -725,12 +725,12 @@ func @read_dependence_through_scf_and_call(
   // bufferize inplace.
   //     CHECK: fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]}
-  %A = linalg.fill(%f1, %I) : f32, tensor<64xf32> -> tensor<64xf32>
+  %A = linalg.fill ins(%f1 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32>
 
   // 4. Bufferizes inplace: no alias to %A is yet possible.
   //     CHECK: fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %B = linalg.fill(%f2, %I) : f32, tensor<64xf32> -> tensor<64xf32>
+  %B = linalg.fill ins(%f2 : f32) outs(%I : tensor<64xf32>) -> tensor<64xf32>
 
   // 3. Does not read or write, bufferizes inplace.
   //      CHECK: scf.for
@@ -750,12 +750,12 @@ func @read_dependence_through_scf_and_call(
   // cannot bufferize inplace.
   //     CHECK: fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]}
-  %A2 = linalg.fill(%f1, %I2) : f32, tensor<64xf32> -> tensor<64xf32>
+  %A2 = linalg.fill ins(%f1 : f32) outs(%I2 : tensor<64xf32>) -> tensor<64xf32>
 
   // 1. Bufferizes inplace: no alias to %A2 is yet possible.
   //     CHECK: fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %B2 = linalg.fill(%f2, %I2) : f32, tensor<64xf32> -> tensor<64xf32>
+  %B2 = linalg.fill ins(%f2 : f32) outs(%I2 : tensor<64xf32>) -> tensor<64xf32>
 
   call @bar(%A2) : (tensor<64xf32>) -> ()
   call @bar(%B2) : (tensor<64xf32>) -> ()
@@ -800,8 +800,8 @@ builtin.func @matmul_on_tensors(
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]}
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
-  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
-  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32>
+  %11 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32>
 
   //      CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
@@ -838,7 +838,7 @@ builtin.func @matmul_on_tensors(
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false"]}
   //      CHECK: vector.transfer_write
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none", "none"]
-  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %8 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32>
   %9 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32>
   %10 = vector.transfer_write %9, %8[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32>
 
@@ -846,7 +846,7 @@ builtin.func @matmul_on_tensors(
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
   //      CHECK: vector.transfer_write
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none", "none"]
-  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %11 = linalg.fill ins(%cst_1 : f32) outs(%7 : tensor<256x256xf32>) -> tensor<256x256xf32>
   %12 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32>
   %13 = vector.transfer_write %12, %11[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32>
 
@@ -891,7 +891,7 @@ func @insert_slice_chain(
 
   //      CHECK: linalg.fill
   // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]
-  %0 = linalg.fill(%cst, %arg2) : f32, tensor<62x90xf32> -> tensor<62x90xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<62x90xf32>) -> tensor<62x90xf32>
 
   //      CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
index a96deea161cdf..7b510b48c9cfd 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
@@ -23,8 +23,8 @@ func @buffer_forwarding_conflict(
   //     CHECK: %[[T_SUBVIEW:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
   %a = linalg.init_tensor[%sz] : tensor<?xf32>
 
-  //     CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref<?xf32>
-  %f = linalg.fill(%f0, %a) : f32, tensor<?xf32> -> tensor<?xf32>
+  //     CHECK: linalg.fill ins({{.*}} : f32) outs(%[[EXTRACT_SLICE_ALLOC]] : memref<?xf32>)
+  %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
 
   //     CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
   //     CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
@@ -54,8 +54,8 @@ func @buffer_forwarding_no_conflict(
   // CHECK: %[[T_SUBVIEW:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
   %a = linalg.init_tensor[%sz] : tensor<?xf32>
 
-  // CHECK: linalg.fill({{.*}}, %[[T_SUBVIEW]]) : f32, memref<?xf32
-  %f = linalg.fill(%f0, %a) : f32, tensor<?xf32> -> tensor<?xf32>
+  // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW]] : memref<?xf32
+  %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
 
   // Self-copy canonicalizes away later.
   %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
@@ -81,8 +81,8 @@ func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?x
     %iv_i32 = arith.index_cast %iv : index to i32
     %f = arith.sitofp %iv_i32 : i32 to f32
 
-    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
-    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+    // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[subview]]
+    %filled = linalg.fill ins(%f : f32) outs(%blank : tensor<5xf32>) -> tensor<5xf32>
 
     // CHECK-NOT: memref.copy
     %inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
@@ -111,8 +111,8 @@ func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
     %iv_i32 = arith.index_cast %iv : index to i32
     %f = arith.sitofp %iv_i32 : i32 to f32
 
-    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
-    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+    // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[subview]]
+    %filled = linalg.fill ins(%f : f32) outs(%blank : tensor<5xf32>) -> tensor<5xf32>
 
     // CHECK-NOT: memref.copy
     %inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor<?xf32>

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 2adf2aadc2d93..02410c811e933 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -173,7 +173,7 @@ func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32>
 func @mini_test_case1() -> tensor<10x20xf32> {
   %f0 = arith.constant 0.0 : f32
   %t = linalg.init_tensor [10, 20] : tensor<10x20xf32>
-  %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32>
+  %r = linalg.fill ins(%f0 : f32) outs(%t : tensor<10x20xf32>) -> tensor<10x20xf32>
   // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}}
   return %r : tensor<10x20xf32>
 }

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b1e086c9e7035..0baf61ff49773 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -40,8 +40,8 @@ func @fill_inplace(
 
   /// Inplaceable, no alloc
   // CHECK-NOT: alloc
-  //     CHECK: linalg.fill(%[[F0]], %[[A]]) : f32, memref<?xf32, #[[$map_1d_dyn]]>
-  %r = linalg.fill(%f0, %A) : f32, tensor<?xf32> -> tensor<?xf32>
+  //     CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[A]] : memref<?xf32, #[[$map_1d_dyn]]>)
+  %r = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>
 
   //     CHECK: return
   // CHECK-NOT: tensor
@@ -78,8 +78,8 @@ func @not_inplace(
 
   //     CHECK: %[[D0:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$map_1d_dyn]]>
   //     CHECK: %[[ALLOC:.*]] = memref.alloc(%[[D0]]) {alignment = 128 : i64} : memref<?xf32>
-  //     CHECK: linalg.fill(%[[F0]], %[[ALLOC]]) : f32, memref<?xf32>
-  %r = linalg.fill(%f0, %A) : f32, tensor<?xf32> -> tensor<?xf32>
+  //     CHECK: linalg.fill ins(%[[F0]] : f32) outs(%[[ALLOC]] : memref<?xf32>)
+  %r = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>
 
   //     CHECK:  dealloc %[[ALLOC]] : memref<?xf32>
   //     CHECK:  return %[[ALLOC]] : memref<?xf32>
@@ -101,8 +101,8 @@ func @not_inplace(
 
   /// Cross-op multiple uses of %A: the first op, which has interfering reads, must alloc.
   //       CHECK: %[[ALLOC:.*]] = memref.alloc
-  //       CHECK: linalg.fill({{.*}}, %[[ALLOC]]
-  %f = linalg.fill(%f0, %A) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  //       CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[ALLOC]]
+  %f = linalg.fill ins(%f0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   /// The second op has no interfering reads and can reuse.
   //   CHECK-NOT: alloc
@@ -240,8 +240,8 @@ func @insert_slice_fun(
   %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   /// Overwrite A inplace.
-  //      CHECK: linalg.fill({{.*}}, %[[A]]
-  %r1 = linalg.fill(%f0, %r0) : f32, tensor<?xf32> -> tensor<?xf32>
+  //      CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[A]]
+  %r1 = linalg.fill ins(%f0 : f32) outs(%r0 : tensor<?xf32>) -> tensor<?xf32>
 
   //     CHECK: return
   // CHECK-NOT: tensor
@@ -262,8 +262,8 @@ func @insert_slice_fun(
 {
   %f0 = arith.constant 0.0 : f32
 
-  //      CHECK: linalg.fill({{.*}}, %[[A]]
-  %r0 = linalg.fill(%f0, %A) : f32, tensor<?xf32> -> tensor<?xf32>
+  //      CHECK: linalg.fill ins({{.*}}{{.*}}outs(%[[A]]
+  %r0 = linalg.fill ins(%f0 : f32) outs(%A : tensor<?xf32>) -> tensor<?xf32>
 
   //  CHECK-NOT: alloc
   //      CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
@@ -581,8 +581,8 @@ func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> t
   // CHECK-NEXT:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
   %v0 = arith.constant 0.0 : f32
 
-  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
-  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+  // CHECK-NEXT:   linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %d = linalg.fill ins(%v0 : f32) outs(%c : tensor<f32>) -> tensor<f32>
 
   // CHECK-NEXT:   linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
   %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
@@ -611,12 +611,12 @@ func @main() {
   %B = linalg.init_tensor [64] : tensor<64xf32>
   %C = linalg.init_tensor [] : tensor<f32>
 
-  // CHECK-NEXT:   linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
-  // CHECK-NEXT:   linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
-  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
-  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
-  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
-  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+  // CHECK-NEXT:   linalg.fill ins(%[[C1]] : f32) outs(%[[A]] : memref<64xf32>)
+  // CHECK-NEXT:   linalg.fill ins(%[[C2]] : f32) outs(%[[B]] : memref<64xf32>)
+  // CHECK-NEXT:   linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32>)
+  %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32>
+  %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32>
+  %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor<f32>) -> tensor<f32>
 
   // CHECK-NEXT:   call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
   %res = call @init_and_dot(%AA, %BB, %CC) :
@@ -730,8 +730,8 @@ func @matmul(
         tensor<128x192xf32> to tensor<8x16xf32>
 
       // linalg.fill is inplace.
-      // CHECK: linalg.fill(%{{.*}}, %[[ALLOC]]) : f32, memref<8x16xf32>
-      %5 = linalg.fill(%cst, %4) : f32, tensor<8x16xf32> -> tensor<8x16xf32>
+      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%[[ALLOC]] : memref<8x16xf32>)
+      %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<8x16xf32>) -> tensor<8x16xf32>
 
       // CHECK: scf.for %[[K:.*]] =
       %6 = scf.for %arg7 = %c0 to %c256 step %c32 iter_args(%arg8 = %5) -> (tensor<8x16xf32>) {
@@ -800,7 +800,7 @@ func @dominance_violation_bug_1(
 
   %sA = tensor.extract_slice %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   %ssA = tensor.extract_slice %sA[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
-  %FA = linalg.fill(%f0, %ssA) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
+  %FA = linalg.fill ins(%f0 : f32) outs(%ssA : tensor<4x4xf32>) -> tensor<4x4xf32>
   %rsA = tensor.insert_slice %FA into %sA[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
   %rA = tensor.insert_slice %rsA into %A[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index ed49c6b212d40..ef1ed2618012e 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -156,7 +156,7 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf3
   %0 = linalg.generic #trait
      ins(%arg0 : tensor<1x5xf32>)
     outs(%shape : tensor<5xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32):     
+  ^bb0(%arg2: f32, %arg3: f32):
     linalg.yield %arg2 : f32
   } -> tensor<5xf32>
   return %0 : tensor<5xf32>
@@ -250,7 +250,7 @@ func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
   %2 = linalg.generic {i64, indexing_maps = [#map1, #map0],
     iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) {
-    ^bb0(%arg1: f32, %arg2: f32):  
+    ^bb0(%arg1: f32, %arg2: f32):
       linalg.yield %arg1 : f32
     } -> tensor<1x2x5xf32>
   %3 = tensor.collapse_shape %2 [[0, 1], [2]]
@@ -266,7 +266,7 @@ func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
 func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> {
   %cst = arith.constant 0.0 : f32
   %init = linalg.init_tensor [1] : tensor<1xf32>
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1xf32> -> tensor<1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1xf32>) -> tensor<1xf32>
   %add = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
       iterator_types = ["parallel", "reduction"]}
@@ -287,7 +287,7 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 
 //       CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
 //       CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor<f32>
-//       CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<f32> -> tensor<f32>
+//       CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<f32>) -> tensor<f32>
 //       CHECK: %[[GENERIC:.+]] = linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
 //  CHECK-SAME:     iterator_types = ["reduction"]
@@ -331,14 +331,14 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
   %c3 = arith.constant 3 : index
   %0 = tensor.dim %arg0, %c3 : tensor<1x?x1x?xf32>
   %1 = linalg.init_tensor [1, %0] : tensor<1x?xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<1x?xf32> -> tensor<1x?xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x?xf32>) -> tensor<1x?xf32>
   %3 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
     ins(%arg0 : tensor<1x?x1x?xf32>)
     outs(%2 : tensor<1x?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):  
+  ^bb0(%arg1: f32, %arg2: f32):
     %4 = arith.addf %arg1, %arg2 : f32
     linalg.yield %4 : f32
   } -> tensor<1x?xf32>
@@ -350,7 +350,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x?xf32>
 //  CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
@@ -365,14 +365,14 @@ func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32>
   %cst = arith.constant 1.000000e+00 : f32
   %c3 = arith.constant 3 : index
   %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<1x1xf32> -> tensor<1x1xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
   %3 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
     ins(%arg0 : tensor<1x?x1x1xf32>)
     outs(%2 : tensor<1x1xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):  
+  ^bb0(%arg1: f32, %arg2: f32):
     %4 = arith.addf %arg1, %arg2 : f32
     linalg.yield %4 : f32
   } -> tensor<1x1xf32>
@@ -383,7 +383,7 @@ func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32>
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x1xf32>
 //  CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     iterator_types = ["parallel"]
@@ -399,14 +399,14 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
   %c2 = arith.constant 2 : index
   %0 = tensor.dim %arg0, %c2 : tensor<?x1x?x1xf32>
   %1 = linalg.init_tensor [%0, 1] : tensor<?x1xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<?x1xf32> -> tensor<?x1xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?x1xf32>) -> tensor<?x1xf32>
   %3 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                      affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
     ins(%arg0 : tensor<?x1x?x1xf32>)
     outs(%2 : tensor<?x1xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):  
+  ^bb0(%arg1: f32, %arg2: f32):
     %4 = arith.addf %arg1, %arg2 : f32
     linalg.yield %4 : f32
   } -> tensor<?x1xf32>
@@ -418,7 +418,7 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<?x1x?x1xf32>
 //  CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
@@ -608,7 +608,7 @@ func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf3
   linalg.generic #trait
      ins(%arg0 : memref<1x5xf32>)
     outs(%shape : memref<5xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32):     
+  ^bb0(%arg2: f32, %arg3: f32):
     linalg.yield %arg2 : f32
   }
   return %shape : memref<5xf32>
@@ -702,7 +702,7 @@ func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
   linalg.generic {i64, indexing_maps = [#map1, #map0],
     iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%arg0 : memref<5xf32>) outs(%1 : memref<1x2x5xf32>) {
-    ^bb0(%arg1: f32, %arg2: f32):  
+    ^bb0(%arg1: f32, %arg2: f32):
       linalg.yield %arg1 : f32
     }
   %3 = memref.collapse_shape %1 [[0, 1], [2]]
@@ -792,7 +792,7 @@ func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: me
 // CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]}
 // CHECK-SAME:   ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
 // CHECK-SAME:   outs(%[[OUT]] : memref<?x?x?xf32>) {
-// CHECK:      ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): 
+// CHECK:      ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32):
 // CHECK:       linalg.yield %[[ARG]] : f32
 // CHECK:      }
 // CHECK:      return %[[ARG2]] : memref<?x1x?x1x?xf32>

diff --git a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
index b8e39196d5cce..1afb5e13facaa 100644
--- a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
+++ b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
@@ -29,7 +29,7 @@ func @testAllocFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
   %c0 = arith.constant 0: index
   %f0 = arith.constant 0.0: f32
   %alloc = memref.alloc() : memref<32 x f32>
-  linalg.fill(%f0, %alloc) : f32, memref<32 x f32>
+  linalg.fill ins(%f0 : f32) outs(%alloc : memref<32 x f32>)
   %subview = memref.subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
   memref.copy %in, %subview : memref<? x f32> to memref<16 x f32>
   %0 = vector.transfer_read %alloc[%c0], %f0 {in_bounds = [true]} : memref<32 x f32>, vector<32 x f32>
@@ -69,7 +69,7 @@ func @testViewFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
   %alloc = memref.alloc() : memref<128 x i8>
   %view = memref.view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
   %subview = memref.subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
-  linalg.fill(%f0, %view) : f32, memref<32 x f32>
+  linalg.fill ins(%f0 : f32) outs(%view : memref<32 x f32>)
   memref.copy %in, %subview : memref<? x f32> to memref<16 x f32>
   %0 = vector.transfer_read %view[%c0], %f0 {in_bounds = [true]} : memref<32 x f32>, vector<32 x f32>
   memref.dealloc %alloc : memref<128 x i8>
@@ -129,7 +129,7 @@ func @failAllocFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
   %f0 = arith.constant 0.0: f32
   %f1 = arith.constant 1.0: f32
   %alloc = memref.alloc() : memref<32 x f32>
-  linalg.fill(%f0, %alloc) : f32, memref<32 x f32>
+  linalg.fill ins(%f0 : f32) outs(%alloc : memref<32 x f32>)
   %subview = memref.subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
   memref.copy %in, %subview : memref<? x f32> to memref<16 x f32>
   "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()

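Note on the forward-vector-transfers hunk above: these fills seed the staging buffer with the same scalar that the `vector.transfer_read` uses as padding, which is what lets the forwarding pattern fold the alloc/fill/copy away. A rough sketch of the intended rewrite for `@testAllocFillRead`, assuming the pattern fires (an approximation, not copied from the test's CHECK lines):

```mlir
// The read goes straight to the original memref; %f0 pads the
// out-of-bounds tail, and the alloc/fill/copy become dead.
%0 = vector.transfer_read %in[%c0], %f0 : memref<?xf32>, vector<32xf32>
```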
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index d36186c396f87..0572ecd98ffc7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -934,7 +934,7 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> ten
     linalg.yield %arg2 : f32
   } -> tensor<?x?xf32>
   %6 = linalg.init_tensor [%arg1] : tensor<?xf32>
-  %7 = linalg.fill(%cst, %6) : f32, tensor<?xf32> -> tensor<?xf32>
+  %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
   %8 = linalg.generic {
     indexing_maps = [#map2, #map3],
     iterator_types = ["parallel", "reduction"]

diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index 38378ff615b88..a107f46f6b250 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -4,7 +4,7 @@ module {
   func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                      %arg2: memref<?x?xf32>) {
     %cst = arith.constant 0.000000e+00 : f32
-    linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
     linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
       ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%arg2 : memref<?x?xf32>)
@@ -28,8 +28,9 @@ module {
 //  CHECK-DAG:   %[[C64:.+]] = arith.constant 64 : index
 //  CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
 //  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.0{{.*}} : f32
-//  CHECK-DAG:   linalg.fill(%[[CST]], %[[ARG2]])
+//  CHECK-DAG:   linalg.fill
 // CHECK-SAME:   __internal_linalg_transform__ = "after_basic_fusion_original"
+// CHECK-SAME:   ins(%[[CST]]{{.*}}outs(%[[ARG2]]
 //  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
 //      CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) =
@@ -51,8 +52,9 @@ module {
 //      CHECK:     %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]]
 //      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
 // CHECK-SAME:       [%[[TILE_M_3]], %[[TILE_N_3]]]
-//      CHECK:     linalg.fill(%[[CST]], %[[SV3_2]])
+//      CHECK:     linalg.fill
 // CHECK-SAME:       __internal_linalg_transform__ = "after_basic_fusion_producer"
+// CHECK-SAME:       ins(%[[CST]]{{.*}}outs(%[[SV3_2]]
 //      CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
 //      CHECK:       %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
 //      CHECK:       %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
@@ -250,7 +252,7 @@ module {
     %c64 = arith.constant 64 : index
     %c16 = arith.constant 16 : index
     %cst = arith.constant 0.000000e+00 : f32
-    linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
     %0 = memref.dim %arg0, %c0 : memref<?x?xf32>
     %1 = memref.dim %arg1, %c1 : memref<?x?xf32>
     %2 = memref.dim %arg0, %c1 : memref<?x?xf32>
@@ -285,7 +287,7 @@ module {
   func @basic_conv_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                           %arg2: memref<?x?xf32>) {
     %cst = arith.constant 0.000000e+00 : f32
-    linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
     linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"}
       ins(%arg1, %arg0 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
     return

diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
index 4a3f898882ee9..10dc8d3fcb7d2 100644
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -9,7 +9,7 @@ module {
     %d0 = memref.dim %arg0, %c0 : memref<?x?xf32>
     %d1 = memref.dim %arg1, %c1 : memref<?x?xf32>
     %0 = memref.alloc(%d0, %d1) : memref<?x?xf32>
-    linalg.fill(%cst, %0) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%0 : memref<?x?xf32>)
     linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%0 : memref<?x?xf32>)
     linalg.generic
@@ -43,7 +43,7 @@ module {
 //   CHECK-DAG:     %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
 //       CHECK:     %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
 //       CHECK:     %[[SV_TEMP_3:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-//       CHECK:     linalg.fill(%{{.+}}, %[[SV_TEMP_3]])
+//       CHECK:     linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_TEMP_3]]
 //       CHECK:     linalg.matmul
 //  CHECK-SAME:       ins(%[[SV_ARG0]], %[[SV_ARG1]]
 //  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
@@ -70,13 +70,13 @@ module {
     %n3 = memref.dim %arg3, %c1 : memref<?x?xf32>
     %0 = memref.alloc(%m, %n1) : memref<?x?xf32>
     %1 = memref.alloc(%m, %n2) : memref<?x?xf32>
-    linalg.fill(%cst, %0) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%0 : memref<?x?xf32>)
     linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%0 : memref<?x?xf32>)
-    linalg.fill(%cst, %1) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%1 : memref<?x?xf32>)
     linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%1 : memref<?x?xf32>)
-    linalg.fill(%cst, %arg4) : f32, memref<?x?xf32>
+    linalg.fill ins(%cst : f32) outs(%arg4 : memref<?x?xf32>)
     linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%arg4 : memref<?x?xf32>)
     return
@@ -126,15 +126,15 @@ module {
 //  CHECK-SAME:       [%[[TILE_M_5]], %[[N0]]]
 //       CHECK:     %[[SV_ALLOC4:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
 //  CHECK-SAME:       [%[[TILE_M_5]], %[[N1]]]
-//       CHECK:     linalg.fill(%{{.+}}, %[[SV_ALLOC1]])
+//       CHECK:     linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC1]]
 //       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]]
 //  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 //  CHECK-SAME:        outs(%[[SV_ALLOC4]] : memref<?x?xf32, #[[MAP1]]>)
-//       CHECK:     linalg.fill(%{{.+}}, %[[SV_ALLOC2]])
+//       CHECK:     linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC2]]
 //       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]]
 //  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 //  CHECK-SAME:        outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
-//       CHECK:     linalg.fill(%{{.+}}, %[[SV_ARG4_2]])
+//       CHECK:     linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ARG4_2]]
 //       CHECK:     linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]]
 //  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 //  CHECK-SAME:        outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)

diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index ef315c96f16b4..fa934ee8a9f0e 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -142,7 +142,7 @@ module {
   func @matmul_out_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                       %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
     %c0 = arith.constant 0.0 : f32
-    %0 = linalg.fill(%c0, %arg0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+    %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
     %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
       ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -159,7 +159,9 @@ module {
 //       CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor<?x?xf32>) {
 //       CHECK:   scf.for %[[J:.*]]
 //       CHECK:     %[[ST:.*]] = tensor.extract_slice %[[ARG0]]
-//       CHECK:     %[[ST_FILL:.*]] = linalg.fill(%[[C0]], %[[ST]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+//       CHECK:     %[[ST_FILL:.*]] = linalg.fill
+//  CHECK-SAME:       {__internal_linalg_transform__ = "after_out_fusion_producer"}
+//  CHECK-SAME:       ins(%[[C0]] : f32) outs(%[[ST]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 //       CHECK:     %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
 //   CHECK-NOT:       fill
 //       CHECK:       %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0]

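Note on the fusion-tensor-pattern hunk above: with the OpDSL-generated printer, discretionary attributes such as the `__internal_linalg_transform__` marker now print between the op name and the `ins`/`outs` clauses, which is why the CHECK line is split across three lines. A minimal sketch of the printed form (value names hypothetical):

```mlir
%fill = linalg.fill {__internal_linalg_transform__ = "after_out_fusion_producer"}
          ins(%c0 : f32) outs(%st : tensor<?x?xf32>) -> tensor<?x?xf32>
```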
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index ccdb043ec1e40..9f3dd44e7266f 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -685,7 +685,7 @@ func @fill_and_conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memre
   %c3 = arith.constant 3 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  linalg.fill(%cst, %arg0) : f32, memref<?x?xf32>
+  linalg.fill ins(%cst : f32) outs(%arg0 : memref<?x?xf32>)
   %2 = memref.dim %arg1, %c0 : memref<?x?xf32>
   %3 = memref.dim %arg1, %c1 : memref<?x?xf32>
   %4 = memref.dim %arg2, %c0 : memref<?x?xf32>

diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 961a31d293b9a..5602d18f9abfb 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -156,7 +156,7 @@ func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %ou
 // -----
 
 func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
-  linalg.fill(%value, %output) : f32, memref<?x?xf32>
+  linalg.fill ins(%value : f32) outs(%output : memref<?x?xf32>)
   return
 }
 

diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 3ac2d752bdd93..d455d84c228d6 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -222,7 +222,7 @@ func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 // -----
 
 func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
-  %0 = linalg.fill_tensor ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
   return %0: tensor<f32>
 }
 
@@ -236,7 +236,7 @@ func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
 // -----
 
 func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
-  linalg.fill_tensor ins(%value: f64) outs(%O : memref<16x32xf32>)
+  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
   return
 }
 

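Note on the generalize-named-polymorphic-ops hunk above: after the rename, `linalg.fill` generalizes like any other OpDSL operation, with the scalar input mapped through a constant affine map. A rough sketch of the `linalg.generic` that `@generalize_fill_2d` is expected to lower to, assuming OpDSL's implicit cast semantics for the `f64` scalar (an approximation, not copied from the test's CHECK lines):

```mlir
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                 affine_map<(d0, d1) -> (d0, d1)>],
                iterator_types = ["parallel", "parallel"]}
    ins(%value : f64) outs(%O : memref<16x32xf32>) {
^bb0(%in: f64, %out: f32):
  // OpDSL inserts the implicit f64 -> f32 cast before yielding.
  %0 = arith.truncf %in : f64 to f32
  linalg.yield %0 : f32
}
```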
diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
index 8b5e2a313d5a5..2ebec15b840d1 100644
--- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
@@ -4,7 +4,7 @@
 // CHECK-SAME:                                             %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
 // CHECK:           %[[C0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
-// CHECK:           %[[FILL:.*]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x32x32x1xf32> -> tensor<1x32x32x1xf32>
+// CHECK:           %[[FILL:.*]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x32x32x1xf32>) -> tensor<1x32x32x1xf32>
 // CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
 // CHECK:           return %[[PADDED]] : tensor<1x32x32x1xf32>
 func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
@@ -29,7 +29,7 @@ func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor
 // CHECK:           %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
 // CHECK:           %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
 // CHECK:           %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32>
-// CHECK:           %[[FILL:.*]] = linalg.fill(%[[CST]], %[[INIT]]) : f32, tensor<4x?x?x?xf32> -> tensor<4x?x?x?xf32>
+// CHECK:           %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
 // CHECK:           %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
 // CHECK:           %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
 // CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>

diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
index 534e37a1f9ee0..1e3482bd1cf0c 100644
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -377,7 +377,7 @@ func @tile_and_fuse(%arg0: tensor<12x6xf32>,
     ^bb0(%arg5: index, %arg6: index):
       tensor.yield %cst : f32
     } : tensor<?x24xf32> to tensor<5x24xf32>
-    %5 = linalg.fill(%cst, %4) : f32, tensor<5x24xf32> -> tensor<5x24xf32>
+    %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<5x24xf32>) -> tensor<5x24xf32>
     %6 = tensor.extract_slice %5[0, 0] [%1, 24] [1, 1] : tensor<5x24xf32> to tensor<?x24xf32>
 
     // Check the first input operand is hoisted by one loop nest.

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 132d14d6c4156..f5f4ce70ce3ac 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -369,16 +369,7 @@ func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
 {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
   // expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}}
-  linalg.fill(%arg2, %0) : f32, tensor<?x?xf32>
-}
-
-// -----
-
-func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
-{
-  // expected-error @+1 {{op expected the number of results (1) to be equal to the number of output tensors (0)}}
-  %0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+  linalg.fill ins(%arg2 : f32) outs(%0 : tensor<?x?xf32>)
 }
 
 // -----
@@ -387,7 +378,7 @@ func @illegal_fill_memref_with_tensor_return
   (%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
 {
   // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}}
-  %0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> tensor<?x?xf32>
+  %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
@@ -396,8 +387,8 @@ func @illegal_fill_memref_with_tensor_return
 func @illegal_fill_tensor_with_memref_return
   (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
 {
-  // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'memref<?x?xf32>'}}
-  %0 = linalg.fill(%arg1, %arg0) : f32, tensor<?x?xf32> -> memref<?x?xf32>
+  // expected-error @+1 {{result #0 must be ranked tensor of any type values, but got 'memref<?x?xf32>'}}
+  %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : tensor<?x?xf32>) -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
 

diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/library-calls.mlir
index 89b8e520441c9..6c16acad4629e 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/library-calls.mlir
@@ -14,7 +14,7 @@ func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
   %C = memref.alloc(%x, %y) : memref<?x?xf32>
 
   // CHECK: call @linalg_fill_f32_viewsxsxf32({{.*}}) : (f32, memref<?x?xf32, {{.*}}>)
-  linalg.fill(%f0, %C) : f32, memref<?x?xf32>
+  linalg.fill ins(%f0 : f32) outs(%C : memref<?x?xf32>)
 
   // CHECK:  call @linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32({{.*}}) : (memref<?x?xf32, {{.*}}>, memref<?x?xf32, {{.*}}>, memref<?x?xf32, {{.*}}>) -> ()
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)

diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 077d9d3421513..8b3e76b686c7b 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -166,7 +166,7 @@ func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf3
 //       CHECKPARALLEL:   store %[[res]], %{{.*}}[] : memref<f32>
 
 func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<?xf32, offset: ?, strides: [1]>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // CHECK-LABEL: func @fill_view(
@@ -180,7 +180,7 @@ func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
 //       CHECKPARALLEL:     store %{{.*}}, %{{.*}}[%{{.*}}] : memref<?xf32, #[[$strided1D]]>
 
 func @fill_view0(%arg0: memref<f32>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<f32>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<f32>)
   return
 }
 // CHECK-LABEL: func @fill_view0(%{{.*}}: memref<f32>, %{{.*}}: f32) {
@@ -190,7 +190,7 @@ func @fill_view0(%arg0: memref<f32>, %arg1: f32) {
 //       CHECKPARALLEL:   store %{{.*}}, %{{.*}}[] : memref<f32>
 
 func @fill_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
   return
 }
 // CHECK-LABEL: func @fill_view3(

diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a06ca5f801278..39a13683bfa92 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -4,7 +4,7 @@
 func @depthwise_conv_2d_nhwc_hwcm_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
   %zero = arith.constant 0.000000e+00 : f32
   %init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
-  %fill = linalg.fill(%zero, %init) : f32, tensor<2x3x4x2x3xf32> -> tensor<2x3x4x2x3xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
   // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm
   // CHECK-SAME:   {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
   // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
@@ -70,7 +70,7 @@ func @depthwise_conv_2d_nhwc_hwc_memref(%input: memref<1x113x113x96xf32>, %filte
 func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
   %zero = arith.constant 0.000000e+00 : f32
   %init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32>
-  %fill = linalg.fill(%zero, %init) : f32, tensor<2x6x7x2x3xf32> -> tensor<2x6x7x2x3xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
   // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm
   // CHECK-SAME:   {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
   // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
@@ -236,7 +236,7 @@ func @pooling_nhwc_sum_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x1xf32> -> tensor<1x2x2x1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
@@ -270,7 +270,7 @@ func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x1xf32> -> tensor<1x2x2x1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
@@ -289,7 +289,7 @@ func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x1x2x2xf32> -> tensor<1x1x2x2xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
   %res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x1x4x4xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
@@ -323,7 +323,7 @@ func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8>
   %cst = arith.constant 0 : i8
-  %fill = linalg.fill(%cst, %init) : i8, tensor<1x2x2x1xi8> -> tensor<1x2x2x1xi8>
+  %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
   %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x4x4x1xi8>, tensor<3x3xi8>)
     outs(%fill: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
@@ -357,7 +357,7 @@ func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16>
   %cst = arith.constant 0 : i16
-  %fill = linalg.fill(%cst, %init) : i16, tensor<1x2x2x1xi16> -> tensor<1x2x2x1xi16>
+  %fill = linalg.fill ins(%cst : i16) outs(%init : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
   %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x4x4x1xi16>, tensor<3x3xi16>)
     outs(%fill: tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
@@ -391,7 +391,7 @@ func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32>
   %cst = arith.constant 0 : i32
-  %fill = linalg.fill(%cst, %init) : i32, tensor<1x2x2x1xi32> -> tensor<1x2x2x1xi32>
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
   %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
     outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
@@ -426,7 +426,7 @@ func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x1xf32> -> tensor<1x2x2x1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
     ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
@@ -460,7 +460,7 @@ func @pooling_ndhwc_sum_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x
   %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
   %res = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
     ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
     outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
@@ -494,7 +494,7 @@ func @pooling_ndhwc_max_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x
   %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
   %res = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
     ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
     outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
@@ -528,7 +528,7 @@ func @pooling_ndhwc_min_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x
   %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
   %res = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
     ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
     outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>

diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
index 36879b7254a7e..99c9317ab9bb5 100644
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -173,11 +173,11 @@ func @pad_multiple(%arg0: tensor<64x64xf32>,
 
   // Check both fill operations are padded by the same pad tensor operation.
   //      FILL:  %[[T0:.*]] = tensor.pad
-  //      FILL:  %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
-  //      FILL:  %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      FILL:  %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
+  //      FILL:  %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
   //      FILL:  = tensor.extract_slice %[[T2]]
-  %1 = linalg.fill(%cst, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %2 : tensor<?x?xf32>
 }
 
@@ -198,15 +198,15 @@ func @compose_padding(%arg0: tensor<64x64xf32>,
   // MATMUL-SAME:                                     [0, 0]
   // MATMUL-SAME:                                     [%[[SIZE]], %[[SIZE]]]
   //      MATMUL:  %[[T1:.*]] = tensor.pad %[[T0]]
-  //      MATMUL:  %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]
-  //      MATMUL:  %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]
+  //      MATMUL:  %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
+  //      MATMUL:  %[[T3:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T2]]
   %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0]  {
-    ^bb0(%arg3: index, %arg4: index):  
+    ^bb0(%arg3: index, %arg4: index):
       tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<64x64xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
-  %3 = linalg.fill(%cst, %2) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<64x64xf32>) -> tensor<64x64xf32>
   %4 = tensor.extract_slice %3[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
 
   // Check there are no additional pad tensor operations.
@@ -234,10 +234,10 @@ func @different_padding_values(%arg0: tensor<64x64xf32>,
   %size = affine.min #map0()[%iv0]
   %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0]  {
-    ^bb0(%arg3: index, %arg4: index):  
+    ^bb0(%arg3: index, %arg4: index):
       tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<64x64xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
   %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
 
   // Different padding values prevent composing the paddings (42.0 vs. 0.0).
@@ -259,10 +259,10 @@ func @different_padding_dynamic_sizes(%arg0: tensor<64x64xf32>,
   %size = affine.min #map0()[%iv0]
   %0 = tensor.extract_slice %arg0[0, 0] [%iv0, %iv0] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0]  {
-    ^bb0(%arg3: index, %arg4: index):  
+    ^bb0(%arg3: index, %arg4: index):
       tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<64x64xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
   %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
 
   // Different dynamic sizes prevent composing the paddings (%iv0 vs %size).
@@ -284,10 +284,10 @@ func @different_padding_dynamic_rank(%arg0: tensor<64x64x1xf32>,
   %size = affine.min #map0()[%iv0]
   %0 = tensor.extract_slice %arg0[0, 0, 0] [%size, %size, 1] [1, 1, 1] : tensor<64x64x1xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0]  {
-    ^bb0(%arg3: index, %arg4: index):  
+    ^bb0(%arg3: index, %arg4: index):
       tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<64x64xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
   %3 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
 
   // Different dynamic ranks prevent composing the paddings ([%size, %size, 1] vs [%size, %size]).
@@ -309,10 +309,10 @@ func @different_padding_static_sizes(%arg0: tensor<62x62xf32>,
   %size = affine.min #map0()[%iv0]
   %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor<?x?xf32>
   %1 = tensor.pad %0 low[0, 0] high[%iv0, %iv0]  {
-    ^bb0(%arg3: index, %arg4: index):  
+    ^bb0(%arg3: index, %arg4: index):
       tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<62x62xf32>
-  %2 = linalg.fill(%cst, %1) : f32, tensor<62x62xf32> -> tensor<62x62xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<62x62xf32>) -> tensor<62x62xf32>
   %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor<?x?xf32>
 
   // Different static sizes prevent composing the paddings (62 vs 64 derived from #map0).
@@ -340,8 +340,8 @@ func @scalar_operand(%arg0: f32,
   %1 = tensor.extract_slice %arg1[0, 0] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
 
   // Check only the fill output operand is padded.
-  //      FILL:   %[[T6:.*]] = linalg.fill(%[[ARG0]], %[[T1]]
-  %2 = linalg.fill(%arg0, %1) : f32, tensor<4x?xf32> -> tensor<4x?xf32>
+  //      FILL:   %[[T6:.*]] = linalg.fill ins(%[[ARG0]]{{.*}}outs(%[[T1]]
+  %2 = linalg.fill ins(%arg0 : f32) outs(%1 : tensor<4x?xf32>) -> tensor<4x?xf32>
   %3 = tensor.insert_slice %2 into %arg1[0, 0] [4, %0] [1, 1] : tensor<4x?xf32> into tensor<24x12xf32>
   return %3 : tensor<24x12xf32>
 }
@@ -466,9 +466,9 @@ func @rank_reducing(%arg0: tensor<1x64x1x64xf32>,
 
   // Check the fill is padded despite the rank-reducing slice operation.
   //      FILL:  %[[T0:.*]] = tensor.pad
-  //      FILL:  %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      FILL:  %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   // FILL-SAME:    tensor<1x64x64xf32>
   //      FILL:  = tensor.extract_slice %[[T1]]
-  %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
   return %1 : tensor<1x?x?xf32>
 }

diff --git a/mlir/test/Dialect/Linalg/pad_fusion.mlir b/mlir/test/Dialect/Linalg/pad_fusion.mlir
index 90e6381f6f16a..78f42f9dff559 100644
--- a/mlir/test/Dialect/Linalg/pad_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/pad_fusion.mlir
@@ -38,7 +38,7 @@ func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
 //  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
 //  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]] 
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]])
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ARG5]]{{.*}}outs(%[[INIT]]
 //  CHECK-DAG:   %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
 //  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
 //      CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
@@ -82,7 +82,7 @@ func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : index,
 //  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
 //  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]] 
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]])
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ARG3]]{{.*}}outs(%[[INIT]]
 //  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
 //      CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
 // CHECK-SAME:       [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]

diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 9f0679f5d4336..9420ece55a140 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -23,9 +23,9 @@ func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 //      CHECK:       %[[T19:.+]] = memref.subview %[[T18]]
 //      CHECK:       %[[T20:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32, 3>
 //      CHECK:       %[[T21:.+]] = memref.subview %[[T20]]
-//      CHECK:       linalg.fill(%[[C42]], %[[T19]])
+//      CHECK:       linalg.fill ins(%[[C42]]{{.*}}outs(%[[T19]]
 //      CHECK:       memref.copy %[[T7]], %[[T19]]
-//      CHECK:       linalg.fill(%[[C42]], %[[T21]])
+//      CHECK:       linalg.fill ins(%[[C42]]{{.*}}outs(%[[T21]]
 //      CHECK:       memref.copy %[[T17]], %[[T21]]
 //      CHECK:       linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]]
 //  CHECK-NOT:       linalg.fill

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index f2957f73e6221..eeb9ab8c96669 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -63,12 +63,12 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // -----
 
 func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<?xf32, offset: ?, strides: [1]>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?xf32, offset: ?, strides: [1]>)
   return
 }
 // CHECK-LABEL: func @fill_view(
 //       CHECK:  %{{.*}}: memref<?xf32, #[[$strided1D]]>, %{{.*}}: f32) {
-//       CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : f32, memref<?xf32, #[[$strided1D]]>
+//       CHECK:   linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<?xf32, #[[$strided1D]]>)
 
 // -----
 
@@ -84,12 +84,12 @@ func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
 
 
 func @fill_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
   return
 }
 // CHECK-LABEL: func @fill_view3(
 //       CHECK:  %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: f32) {
-//       CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : f32, memref<?x?x?xf32, #[[$strided3D]]>
+//       CHECK:   linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
 
 // -----
 
@@ -208,9 +208,9 @@ func @generic_with_multiple_tensor_outputs(
     -> (tensor<i32>, tensor<i32>) {
   %c0 = arith.constant 0 : index
   %0 = linalg.init_tensor [] : tensor<i32>
-  %1 = linalg.fill(%arg2, %0) : i32, tensor<i32> -> tensor<i32>
+  %1 = linalg.fill ins(%arg2 : i32) outs(%0 : tensor<i32>) -> tensor<i32>
   %2 = linalg.init_tensor [] : tensor<i32>
-  %3 = linalg.fill(%arg2, %2) : i32, tensor<i32> -> tensor<i32>
+  %3 = linalg.fill ins(%arg2 : i32) outs(%2 : tensor<i32>) -> tensor<i32>
   %4:2 = linalg.generic {
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>],
     iterator_types = ["reduction"]}
@@ -346,7 +346,7 @@ func @init_tensor(%arg0 : index, %arg1 : index)
 
 func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
-  %1 = linalg.fill(%arg2, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = linalg.fill ins(%arg2 : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<?x?xf32>) -> tensor<?x?xf32>

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
index 516b3701f38f9..7ad921b582c96 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
@@ -7,7 +7,7 @@ builtin.func @no_fuse_gemm(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
   %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
-  %fill = linalg.fill(%cst, %init) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
   %result = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %result : tensor<?x?xf32>

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
index 1c5093fa9a716..d3f670d2b67b9 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -17,7 +17,7 @@ builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
   %c24 = arith.constant 24 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<24x12xf32>) -> tensor<24x12xf32>
 
   //      MATMUL:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
   //      MATMUL:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
@@ -31,7 +31,7 @@ builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
   //      MATMUL:        %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
   // MATMUL-SAME:                                          %[[IV1]], %[[IV2]]
   // MATMUL-SAME:                                          %[[UB1]], %[[UB2]]
-  //      MATMUL:        %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      MATMUL:        %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   //      MATMUL:        %{{.*}} = linalg.matmul ins(%[[T1]]
   %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
@@ -55,7 +55,7 @@ builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
   %c24 = arith.constant 24 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
 
   // Update the iteration argument of the outermost tile loop.
   //      MATMUL:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
@@ -67,7 +67,7 @@ builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
   //      MATMUL:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
   // MATMUL-SAME:                                        %[[IV1]], %[[IV0]]
   // MATMUL-SAME:                                        %[[TS1]], %[[TS0]]
-  //      MATMUL:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      MATMUL:      %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   //      MATMUL:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
 
   // Check there is an extract/insert slice pair for the output operand.
@@ -184,19 +184,19 @@ builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
   %c24 = arith.constant 24 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
-  %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<24x12xf32>) -> tensor<24x12xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
 
   // Fuse both producers to the appropriate tile loops.
   //      MATMUL:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
   //      MATMUL:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
   //      MATMUL:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
   // MATMUL-SAME:                                        %[[IV1]], %[[IV0]]
-  //      MATMUL:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      MATMUL:      %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   //      MATMUL:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
   //      MATMUL:          %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
   // MATMUL-SAME:                                            %[[IV1]], %[[IV2]]
-  //      MATMUL:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  //      MATMUL:          %[[T3:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T2]]
   //      MATMUL:          %[[T4:.*]] = tensor.extract_slice %[[ARG5]]
   //      MATMUL:          %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T4]]
   %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
@@ -255,11 +255,11 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
 func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
                                %arg1: tensor<10xf32>) -> tensor<10xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<10x17xf32>) -> tensor<10x17xf32>
 
   // Cannot fuse the output fill since the reduction loop is the outermost loop.
-  //      GENERIC:      %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
-  %1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32>
+  //      GENERIC:      %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG1]]
+  %1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<10xf32>) -> tensor<10xf32>
 
   //      GENERIC:  scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
   //      GENERIC:    scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
@@ -267,7 +267,7 @@ func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
   // MATMUL the input fill has been fused.
   //      GENERIC:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
   // GENERIC-SAME:                                        %[[IV1]], %[[IV0]]
-  //      GENERIC:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      GENERIC:      %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
   //      GENERIC:      %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
   // GENERIC-SAME:                                        %[[IV1]]
   //      GENERIC:  linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
@@ -298,7 +298,7 @@ func @fuse_non_rectangular(%arg0: tensor<10x17xf32>,
   //  GENERIC-DAG:  %[[C8:.*]] = arith.constant 8 : index
   //  GENERIC-DAG:  %[[C10:.*]] = arith.constant 10 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<10x17xf32>) -> tensor<10x17xf32>
 
   //      GENERIC:  scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
   //      GENERIC:    scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
@@ -313,7 +313,7 @@ func @fuse_non_rectangular(%arg0: tensor<10x17xf32>,
   //      GENERIC:      %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
   // GENERIC-SAME:                                        %[[IV1]], %[[SUM]]
   // GENERIC-SAME:                                                , %[[UB1]]
-  //      GENERIC:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      GENERIC:      %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x17xf32>) outs(%arg1 : tensor<10x8xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):  
     %2 = arith.addf %arg2, %arg3 : f32

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
index ef28641f4062c..5de6c1ad84e1a 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
@@ -15,8 +15,8 @@ builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
   %cst = arith.constant 1.0 : f32
 
   // Do not tile the filter fill since the filter dimensions are not tiled.
-  //      CONV:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32>
+  //      CONV:  %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG0]]
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
 
   // Fuse all other operations.
   //      CONV:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
@@ -26,24 +26,24 @@ builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
   // CONV-SAME:                                            %[[IV0]], %[[IV1]]
   //      CONV:          %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
   // CONV-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CONV:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  //      CONV:          %[[T3:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T2]]
   //      CONV:          %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
-  %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<10x10xf32>) -> tensor<10x10xf32>
   %2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32>
 
   //      CONV:          %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
   // CONV-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CONV:          %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
+  //      CONV:          %[[T6:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T5]]
   //      CONV:          %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
-  %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%arg3 : tensor<9x9xf32>) -> tensor<9x9xf32>
   %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32>
 
   // Use the argument passed in by iteration argument.
   //      CONV:          %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
   // CONV-SAME:                                            %[[IV0]], %[[IV1]]
-  //      CONV:          %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
+  //      CONV:          %[[T9:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T8]]
   //      CONV:          %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
-  %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%arg4 : tensor<8x8xf32>) -> tensor<8x8xf32>
   %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32>
   return %6 : tensor<8x8xf32>
 }
@@ -61,8 +61,8 @@ builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %cst = arith.constant 0.000000e+00 : f32
 
   // Do not tile rhs fill of the producer matmul since none of its loop dimension is tiled.
-  //      MATMUL:  %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
-  %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
+  //      MATMUL:  %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG0]]
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32>
 
   //      MATMUL:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
   //      MATMUL:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
@@ -70,14 +70,14 @@ builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   // Only the outermost loop of the producer matmul is tiled.
   //      MATMUL:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
   // MATMUL-SAME:                                        %[[IV0]], 0
-  //      MATMUL:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      MATMUL:      %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
   //      MATMUL:      %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
   %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
 
   // Use the argument passed in by iteration argument.
   //      MATMUL:      %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
   // MATMUL-SAME:                                        %[[IV0]], %[[IV1]]
-  //      MATMUL:      %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
+  //      MATMUL:      %[[T5:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T4]]
   //      MATMUL:      %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
   %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
   return %2 : tensor<8x8xf32>

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 23a991bf6b7a4..c5acd8ebf28dc 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -75,7 +75,7 @@ func @conv_tensors_static(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3
   %cst = arith.constant 0.0 : f32
 
   %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
-  %fill = linalg.fill(%cst, %init) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
 
   %conv = linalg.conv_2d_nhwc_hwcf
     {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
@@ -119,7 +119,7 @@ func @conv_tensors_static(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3
 // CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x3xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>)
 
 //      CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
-// CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<1x112x112x32xf32> -> tensor<1x112x112x32xf32>
+// CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
 
 // CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]])
 // CHECK-NEXT:   %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]])
@@ -157,7 +157,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
   %oc = tensor.dim %elementwise, %c3 : tensor<?x?x?x?xf32>
 
   %init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor<?x?x?x?xf32>
-  %fill = linalg.fill(%cst, %init) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 
   %conv = linalg.conv_2d_nhwc_hwcf
     {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
@@ -226,7 +226,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 //  CHECK-DAG:   %[[ELEM_OC:.+]] = tensor.dim %[[ELEM]], %[[C3]] : tensor<?x?x?x?xf32>
 
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor<?x?x?x?xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 
 //  CHECK-DAG:   %[[FILTER_H:.+]] = tensor.dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
 //  CHECK-DAG:   %[[FILTER_W:.+]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
@@ -310,7 +310,7 @@ func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tensor<64
     tensor.yield %zero : f32
   } : tensor<58x1xf32> to tensor<64x128xf32>
 
-  %fill = linalg.fill(%zero, %large_input) : f32, tensor<64x128xf32> -> tensor<64x128xf32>
+  %fill = linalg.fill ins(%zero : f32) outs(%large_input : tensor<64x128xf32>) -> tensor<64x128xf32>
 
   %for0 = scf.for %iv0 = %c0 to %d0 step %c16 iter_args(%arg0 = %fill) -> tensor<64x128xf32> {
     %for1 = scf.for %iv1 = %c0 to %d1 step %c32 iter_args(%arg1 = %arg0) -> tensor<64x128xf32> {

diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
index 58f82972e309f..a5a5f56495b09 100644
--- a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
@@ -24,7 +24,7 @@ func @fill_matmul_tensors(
 //      CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
 //      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
 //      CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
-//      CHECK:     %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]])
+//      CHECK:     %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[SLICE]]
 //      CHECK:     %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
 //      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
 //      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
@@ -42,7 +42,7 @@ func @fill_matmul_tensors(
   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
   %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
-  %3 = linalg.fill(%cst, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
        ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%3: tensor<?x?xf32>)

diff --git a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir
index df372159a3377..4826798d7a441 100644
--- a/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir
+++ b/mlir/test/Dialect/Linalg/tile-scalarize-dynamic-dims.mlir
@@ -43,7 +43,7 @@ func @tiled_and_peeled_matmul(%arg0: tensor<257x259xf32>, %arg1: tensor<259x258x
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
-  %0 = linalg.fill(%cst, %arg2) : f32, tensor<257x258xf32> -> tensor<257x258xf32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<257x258xf32>) -> tensor<257x258xf32>
   %1 = scf.for %arg3 = %c0 to %c257 step %c64 iter_args(%arg4 = %0) -> (tensor<257x258xf32>) {
     %2 = affine.min #map0(%arg3)
     %3 = tensor.extract_slice %arg0[%arg3, 0] [%2, 259] [1, 1] : tensor<257x259xf32> to tensor<?x259xf32>

diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index 3802b1a9365a4..1c3e8b95ad299 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -254,35 +254,35 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
 //       TILE-234:    linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
 
 func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<127x99xf32>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<127x99xf32>)
   return
 }
 // TILE-2-LABEL: func @fill_static
 //       TILE-2:   for
 //   TILE-2-NOT:   for
 //       TILE-2:       memref.subview{{.*}} : memref<127x99xf32>
-//       TILE-2:       linalg.fill{{.*}} : f32, memref<?x99xf32, #[[$stride_99_1_layout_map]]>
+//       TILE-2:       linalg.fill{{.*}} : memref<?x99xf32, #[[$stride_99_1_layout_map]]>
 
 // TILE-02-LABEL: func @fill_static
 //       TILE-02:   for
 //   TILE-02-NOT:   for
 //       TILE-02:       memref.subview{{.*}} : memref<127x99xf32>
-//       TILE-02:       linalg.fill{{.*}} : f32, memref<127x?xf32, #[[$stride_99_1_layout_map]]>
+//       TILE-02:       linalg.fill{{.*}} : memref<127x?xf32, #[[$stride_99_1_layout_map]]>
 
 // TILE-002-LABEL: func @fill_static
 //   TILE-002-NOT:   for
-//       TILE-002:     linalg.fill{{.*}} f32, memref<127x99xf32>
+//       TILE-002:     linalg.fill{{.*}} : memref<127x99xf32>
 
 // TILE-234-LABEL: func @fill_static
 //       TILE-234:   for
 //       TILE-234:     for
 //   TILE-234-NOT:   for
 //       TILE-234:       memref.subview{{.*}} : memref<127x99xf32>
-//       TILE-234:       linalg.fill{{.*}} : f32, memref<?x3xf32, #[[$stride_99_1_layout_map]]>
+//       TILE-234:       linalg.fill{{.*}} : memref<?x3xf32, #[[$stride_99_1_layout_map]]>
 
 
 func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
-  linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32, offset: ?, strides: [?, 1]>
+  linalg.fill ins(%arg1 : f32) outs(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
   return
 }
 // TILE-2-LABEL: func @fill
@@ -318,7 +318,7 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
   linalg.generic #pointwise_2d_trait
     ins(%arg0, %arg1 : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
     outs(%arg2 : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   
+  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
     %4 = arith.addf %arg4, %arg5 : f32
     linalg.yield %4 : f32
   }

diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index a16b4bd85f1c1..522422526906f 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -296,8 +296,8 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
   %cf = arith.constant 1.0 : f32
   %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
  	 memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-  linalg.fill(%cf, %3) { __internal_linalg_transform__ = "_promote_views_aligned_"}
-  	: f32, memref<?x?xf32, offset: ?, strides: [?, ?]>
+  linalg.fill { __internal_linalg_transform__ = "_promote_views_aligned_"}
+   ins(%cf : f32) outs(%3 : memref<?x?xf32, offset: ?, strides: [?, ?]>)
   return
 }
 // CHECK-LABEL: func @aligned_promote_fill
@@ -306,9 +306,9 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
 // CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8>
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
 // CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
-// CHECK:         linalg.fill({{.*}}, %[[v0]]) : f32, memref<?x?xf32>
+// CHECK:         linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
 // CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
-// CHECK:         linalg.fill(%[[cf]], %[[v0]]) : f32, memref<?x?xf32>
+// CHECK:         linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)
 
 func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, offset: ?, strides: [?, 1]>) {
   %c2000 = arith.constant 2000 : index
@@ -319,8 +319,8 @@ func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, offset: ?, st
   %cc = complex.create %cf, %cf : complex<f32>
   %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] :
  	 memref<?x?xcomplex<f32>, offset: ?, strides: [?, 1]> to memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>
-  linalg.fill(%cc, %3) { __internal_linalg_transform__ = "_promote_views_aligned_"}
-  	: complex<f32>, memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>
+  linalg.fill { __internal_linalg_transform__ = "_promote_views_aligned_"}
+   ins(%cc : complex<f32>) outs(%3 : memref<?x?xcomplex<f32>, offset: ?, strides: [?, ?]>)
   return
 }
 // CHECK-LABEL: func @aligned_promote_fill_complex
@@ -329,9 +329,9 @@ func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, offset: ?, st
 // CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8>
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
 // CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, #[[$STRIDED_2D_u_1]]>
-// CHECK:         linalg.fill({{.*}}, %[[v0]]) : complex<f32>, memref<?x?xcomplex<f32>>
+// CHECK:         linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
 // CHECK:         memref.copy %[[s0]], %[[l0]] : memref<?x?xcomplex<f32>, #map{{.*}}> to memref<?x?xcomplex<f32>, #map{{.*}}>
-// CHECK:         linalg.fill(%[[cc]], %[[v0]]) : complex<f32>, memref<?x?xcomplex<f32>>
+// CHECK:         linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
 
 func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
                                  %arg1: memref<?x?xf32>,

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 49ec6e7033281..e728013a073e7 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -191,7 +191,7 @@ func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
 func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
   //       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
-  linalg.fill(%arg0, %A) : f32, memref<8x16xf32>
+  linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>)
   return
 }
 
@@ -202,7 +202,7 @@ func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
   // CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
   //      CHECK:   %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
   //      CHECK:   vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
-  linalg.fill(%arg0, %A) : f32, memref<f32>
+  linalg.fill ins(%arg0 : f32) outs(%A : memref<f32>)
   return
 }
 
@@ -590,7 +590,7 @@ func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tensor<2x6
 //       CHECK:   %[[V4:.*]] = arith.addi %[[DIM3]], %[[C3]] : index
 //       CHECK:   %[[V5:.*]] = arith.addi %[[V4]], %[[C2]] : index
 //       CHECK:   %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32>
-//       CHECK:   %[[FILL:.*]] = linalg.fill(%{{.*}}, %[[INIT]]) : f32, tensor<6x?x?x?xf32> -> tensor<6x?x?x?xf32>
+//       CHECK:   %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[INIT]] : tensor<6x?x?x?xf32>) -> tensor<6x?x?x?xf32>
 //       CHECK:   %[[SRCDIM:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32>
 //       CHECK:   %[[RESULT:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32>
 //       CHECK:   return %[[RESULT]]
@@ -833,7 +833,7 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant -3.40282e+38 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
-  %fill = linalg.fill(%ident, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
                          iterator_types = ["parallel", "reduction"]}
@@ -858,7 +858,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %maxf32 = arith.constant 3.40282e+38 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
-  %fill = linalg.fill(%maxf32, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %fill = linalg.fill ins(%maxf32 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
                          iterator_types = ["parallel", "reduction"]}
@@ -881,7 +881,7 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant 1.0 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
-  %fill = linalg.fill(%ident, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %fill = linalg.fill ins(%ident : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
                          iterator_types = ["parallel", "reduction"]}
@@ -904,7 +904,7 @@ func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant false
   %init = linalg.init_tensor [4] : tensor<4xi1>
-  %fill = linalg.fill(%ident, %init) : i1, tensor<4xi1> -> tensor<4xi1>
+  %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
                          iterator_types = ["parallel", "reduction"]}
@@ -927,7 +927,7 @@ func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant true
   %init = linalg.init_tensor [4] : tensor<4xi1>
-  %fill = linalg.fill(%ident, %init) : i1, tensor<4xi1> -> tensor<4xi1>
+  %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
                          iterator_types = ["parallel", "reduction"]}
@@ -950,7 +950,7 @@ func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   %ident = arith.constant false
   %init = linalg.init_tensor [4] : tensor<4xi1>
-  %fill = linalg.fill(%ident, %init) : i1, tensor<4xi1> -> tensor<4xi1>
+  %fill = linalg.fill ins(%ident : i1) outs(%init : tensor<4xi1>) -> tensor<4xi1>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0)>],
                          iterator_types = ["parallel", "reduction"]}
@@ -974,7 +974,7 @@ func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tens
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
   %c0 = arith.constant 0.0 : f32
   %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
-  %fill = linalg.fill(%c0, %init) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
+  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0, 0)>,
                                           affine_map<(d0, d1) -> (d0, d1)>],
@@ -1003,7 +1003,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
   %c0 = arith.constant 0.0 : f32
   %init = linalg.init_tensor [4] : tensor<4xf32>
-  %fill = linalg.fill(%c0, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                           affine_map<(d0, d1) -> (d0, 0)>,
                                           affine_map<(d0, d1) -> (d0)>],
@@ -1034,7 +1034,7 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
 
   //      CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
   // CHECK-SAME:   : vector<f32>, tensor<f32>
-  %1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>
+  %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
   //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
   //      CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>

diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 04c8a51817813..db96c8ce07193 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -446,8 +446,8 @@ func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
 //    %[[V:.*]] = memref.alloca(%[[S]]) : memref<?xf64>
 //    %[[F:.*]] = memref.alloca(%[[S]]) : memref<?xi1>
 //    %[[A:.*]] = memref.alloca(%[[S]]) : memref<?xindex>
-//    linalg.fill(%{{.*}}, %[[V]]) : f64, memref<?xf64>
-//    linalg.fill(%{{.*}}, %[[F]]) : i1, memref<?xi1>
+//    linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
+//    linalg.fill ins(%{{.*}} : i1) outs(%[[F]] : memref<?xi1>)
 //       CHECK: return
 func @sparse_expansion() {
   %c = arith.constant 8 : index

diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
index 2917685064afc..4bb678d0fefbc 100644
--- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -35,7 +35,7 @@
 //   CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32>
-//   CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<13xi32>
+//   CHECK-DAG: linalg.fill ins(%[[zeroI32]] : i32) outs(%[[M]] : memref<13xi32>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -74,7 +74,7 @@ func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32>
 //   CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref<?xi32>
-//   CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<?xi32>
+//   CHECK-DAG: linalg.fill ins(%[[zeroI32]] : i32) outs(%[[M]] : memref<?xi32>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -117,7 +117,7 @@ func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<2x4xf64>
 //   CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x4xf64>
+//   CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<2x4xf64>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -161,7 +161,7 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref<?x4xf64>
 //   CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<?x4xf64>
+//   CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<?x4xf64>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -205,7 +205,7 @@ func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI1]]) : memref<2x?xf64>
 //   CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x?xf64>
+//   CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<2x?xf64>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -249,7 +249,7 @@ func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]], %[[SizeI1]]) : memref<?x?xf64>
 //   CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<?x?xf64>
+//   CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<?x?xf64>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -297,7 +297,7 @@ func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<f64>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<2x3x4xf64>
 //   CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : f64, memref<2x3x4xf64>
+//   CHECK-DAG: linalg.fill ins(%[[E0]] : f64) outs(%[[M]] : memref<2x3x4xf64>)
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextF64(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])

diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index cf9ff82929154..07d6cdbdafefa 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -52,7 +52,7 @@ func @add_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>) -> te
 // CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<32xf32>
-// CHECK:           linalg.fill(%[[VAL_3]], %[[VAL_7]]) : f32, memref<32xf32>
+// CHECK:           linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_7]] : memref<32xf32>)
 // CHECK:           scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_8]]] : memref<?xf32>
 // CHECK:             %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_1]] : f32

diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index ace977fb1e7a3..771d469c3125c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -75,7 +75,7 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
   //      LINALG:   scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
   //      LINALG: } else {
   //               slow path, fill tmp alloc and yield a memref_casted version of it
-  //      LINALG:   linalg.fill(%cst, %[[alloc]]) : f32, memref<4x8xf32>
+  //      LINALG:   linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>)
   //      LINALG:   %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x8xf32>
   //      LINALG:   %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
   //      LINALG:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
@@ -168,7 +168,7 @@ func @split_vector_transfer_read_strided_2d(
   // LINALG-SAME:     memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
   //      LINALG: } else {
   //               slow path, fill tmp alloc and yield a memref_casted version of it
-  //      LINALG:   linalg.fill(%cst, %[[alloc]]) : f32, memref<4x8xf32>
+  //      LINALG:   linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>)
   //      LINALG:   %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
   //      LINALG:   %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
   //      LINALG:   %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]

diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
index 1b62fefa7ec9d..0be51efee904e 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
@@ -68,8 +68,8 @@ func @entry() {
   %RHS10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
   %DST10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
 
-  linalg.fill(%f1, %LHS10) : f32, memref<1x10xf32>
-  linalg.fill(%f1, %RHS10) : f32, memref<1x10xf32>
+  linalg.fill ins(%f1 : f32) outs(%LHS10 : memref<1x10xf32>)
+  linalg.fill ins(%f1 : f32) outs(%RHS10 : memref<1x10xf32>)
 
   %LHS = memref.cast %LHS10 : memref<1x10xf32> to memref<?x?xf32>
   %RHS = memref.cast %RHS10 : memref<1x10xf32> to memref<?x?xf32>

diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
index 9dabde643cf6e..738e47ceb7ebd 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
@@ -90,8 +90,8 @@ func @entry() {
   %RHS10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
   %DST10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
 
-  linalg.fill(%f1, %LHS10) : f32, memref<1x10xf32>
-  linalg.fill(%f1, %RHS10) : f32, memref<1x10xf32>
+  linalg.fill ins(%f1 : f32) outs(%LHS10 : memref<1x10xf32>)
+  linalg.fill ins(%f1 : f32) outs(%RHS10 : memref<1x10xf32>)
 
   %LHS = memref.cast %LHS10 : memref<1x10xf32> to memref<?x?xf32>
   %RHS = memref.cast %RHS10 : memref<1x10xf32> to memref<?x?xf32>

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
index 07d4e49fc5b98..8ee06c8a0ac7e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -59,9 +59,9 @@ func @main() {
   %B = memref.alloc() : !row_major_B
   %C = memref.alloc() : !row_major_C
 
-  linalg.fill(%v1, %A) : !elem_type_a, !row_major_A
-  linalg.fill(%v1, %B) : !elem_type_b, !row_major_B
-  linalg.fill(%v0, %C) : !elem_type_c, !row_major_C
+  linalg.fill ins(%v1 : !elem_type_a) outs(%A : !row_major_A)
+  linalg.fill ins(%v1 : !elem_type_b) outs(%B : !row_major_B)
+  linalg.fill ins(%v0 : !elem_type_c) outs(%C : !row_major_C)
 
   %c0 = arith.constant 0: index
   %c1 = arith.constant 1: index
@@ -71,7 +71,7 @@ func @main() {
   /// Preheating run:
   scf.for %arg0 = %c0 to %iters step %c1 {
     %z = arith.constant 0.0 : !elem_type_c
-    linalg.fill(%z, %C) : !elem_type_c, !row_major_C
+    linalg.fill ins(%z : !elem_type_c) outs(%C : !row_major_C)
     call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
   }
   %t_start_matmul = call @rtclock() : () -> f64
@@ -81,7 +81,7 @@ func @main() {
     // Once linalg on tensors is ready, fusing fill at the register level will
     // be easy.
     %z = arith.constant 0.0 : !elem_type_c
-    linalg.fill(%z, %C) : !elem_type_c, !row_major_C
+    linalg.fill ins(%z : !elem_type_c) outs(%C : !row_major_C)
     call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
   }
   %t_end_matmul = call @rtclock() : () -> f64
@@ -90,7 +90,7 @@ func @main() {
 
   // CHECK: {{^0$}}
   %C_ref = memref.alloc() : !row_major_C
-  linalg.fill(%v0, %C_ref) : !elem_type_c, !row_major_C
+  linalg.fill ins(%v0 : !elem_type_c) outs(%C_ref : !row_major_C)
   linalg.matmul ins(%A, %B : !row_major_A, !row_major_B)
     outs(%C_ref: !row_major_C)
   %act = memref.cast %C : !row_major_C to memref<*xf32>

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
index 5be02fe1e17c3..a67d4730af87f 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
@@ -12,7 +12,7 @@ func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
   %x = memref.dim %A, %c0 : memref<?x?xf32>
   %y = memref.dim %B, %c1 : memref<?x?xf32>
   %C = memref.alloc(%x, %y) : memref<?x?xf32>
-  linalg.fill(%f0, %C) : f32, memref<?x?xf32>
+  linalg.fill ins(%f0 : f32) outs(%C : memref<?x?xf32>)
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                 outs(%C: memref<?x?xf32>)
   return %C : memref<?x?xf32>
@@ -26,7 +26,7 @@ func @matvec(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
   %x = memref.dim %A, %c1 : memref<?x?xf32>
   %n = memref.dim %B, %c1 : memref<?x?xf32>
   %C = memref.alloc(%m, %n) : memref<?x?xf32>
-  linalg.fill(%f0, %C) : f32, memref<?x?xf32>
+  linalg.fill ins(%f0 : f32) outs(%C : memref<?x?xf32>)
   scf.for %i = %c0 to %n step %c1 {
     %b = memref.subview %B[0, %i][%x, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
     %c = memref.subview %C[0, %i][%m, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
@@ -46,8 +46,8 @@ func @main() {
   %val2 = arith.constant 17.0 : f32
   %A = memref.alloc(%m, %x) : memref<?x?xf32>
   %B = memref.alloc(%x, %n) : memref<?x?xf32>
-  linalg.fill(%val1, %A) : f32, memref<?x?xf32>
-  linalg.fill(%val2, %B) : f32, memref<?x?xf32>
+  linalg.fill ins(%val1 : f32) outs(%A : memref<?x?xf32>)
+  linalg.fill ins(%val2 : f32) outs(%B : memref<?x?xf32>)
   memref.store %val1, %B[%c0, %c0] : memref<?x?xf32>
   %C1 = call @matmul(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
   %C2 = call @matvec(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
index 0508a2fab6653..4effdbcc8b063 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -14,7 +14,7 @@ func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<f
   %cst = arith.constant 0.000000e+00 : f32
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
-  %0 = linalg.fill(%cst, %arg2) : f32, tensor<f32> -> tensor<f32>
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<f32>) -> tensor<f32>
   %1 = affine.apply #map0(%c0, %c64)[%c2]
   %2 = linalg.init_tensor [%1, 2] : tensor<?x2xf32>
   %3 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %2) -> (tensor<?x2xf32>) {
@@ -83,9 +83,9 @@ func @main() {
   %A = linalg.init_tensor [64] : tensor<64xf32>
   %B = linalg.init_tensor [64] : tensor<64xf32>
   %C = linalg.init_tensor [] : tensor<f32>
-  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
-  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
-  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+  %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32>
+  %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32>
+  %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor<f32>) -> tensor<f32>
 
   %res = call @init_and_dot(%AA, %BB, %CC) :
     (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir
index bceb9b20c7ac0..d7d6bcaebfbf0 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir
@@ -14,7 +14,7 @@ func private @print_memref_f32(memref<*xf32>)
 // Creates and returns a 1-D buffer of size %s1 filled with the value %f
 func @alloc_1d_filled_f32(%s1 : index, %f : f32) -> memref<?xf32> {
   %buf = memref.alloc(%s1) : memref<?xf32>
-  linalg.fill(%f, %buf) : f32, memref<?xf32>
+  linalg.fill ins(%f : f32) outs(%buf : memref<?xf32>)
   return %buf : memref<?xf32>
 }
 

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir
index 6c1ff90701591..5ce790ba1b961 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir
@@ -14,7 +14,7 @@ func private @print_memref_f32(memref<*xf32>)
 // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
 func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> memref<?x?x?xf32> {
   %buf = memref.alloc(%s1, %s2, %s3) : memref<?x?x?xf32>
-  linalg.fill(%f, %buf) : f32, memref<?x?x?xf32>
+  linalg.fill ins(%f : f32) outs(%buf : memref<?x?x?xf32>)
   return %buf : memref<?x?x?xf32>
 }
 

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir
index 5f33579e692df..e257095b7707f 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir
@@ -14,7 +14,7 @@ func private @print_memref_f32(memref<*xf32>)
 // Creates and returns a 2-D buffer of size (%s1, %s2) filled with the value %f
 func @alloc_2d_filled_f32(%s1 : index, %s2 : index, %f : f32) -> memref<?x?xf32> {
   %buf = memref.alloc(%s1, %s2) : memref<?x?xf32>
-  linalg.fill(%f, %buf) : f32, memref<?x?xf32>
+  linalg.fill ins(%f : f32) outs(%buf : memref<?x?xf32>)
   return %buf : memref<?x?xf32>
 }
 

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir
index fb5732effc9f8..0b840979df83a 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir
@@ -14,7 +14,7 @@ func private @print_memref_f32(memref<*xf32>)
 // Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
 func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> memref<?x?x?x?xf32> {
   %buf = memref.alloc(%s1, %s2, %s3, %s4) : memref<?x?x?x?xf32>
-  linalg.fill(%f, %buf) : f32, memref<?x?x?x?xf32>
+  linalg.fill ins(%f : f32) outs(%buf : memref<?x?x?x?xf32>)
   return %buf : memref<?x?x?x?xf32>
 }
 

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir
index ce077654cb94b..6d3bfe4d87bc9 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir
@@ -14,7 +14,7 @@ func private @print_memref_f32(memref<*xf32>)
 // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
 func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> memref<?x?x?xf32> {
   %buf = memref.alloc(%s1, %s2, %s3) : memref<?x?x?xf32>
-  linalg.fill(%f, %buf) : f32, memref<?x?x?xf32>
+  linalg.fill ins(%f : f32) outs(%buf : memref<?x?x?xf32>)
   return %buf : memref<?x?x?xf32>
 }
 

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir
index 03b7bc2d71ead..521ab66c1927c 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir
@@ -14,7 +14,7 @@ func private @print_memref_f32(memref<*xf32>)
 // Creates and returns 5-D buffer of size (%s1, %s2, %s3, %s4, %s5) filled with the value %f
 func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index, %f : f32) -> memref<?x?x?x?x?xf32> {
   %buf = memref.alloc(%s1, %s2, %s3, %s4, %s5) : memref<?x?x?x?x?xf32>
-  linalg.fill(%f, %buf) : f32, memref<?x?x?x?x?xf32>
+  linalg.fill ins(%f : f32) outs(%buf : memref<?x?x?x?x?xf32>)
   return %buf : memref<?x?x?x?x?xf32>
 }
 

diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
index 46c7ba8a49287..2dce759bfca05 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
@@ -491,7 +491,7 @@ def emit_tensor_init(self) -> ir.RankedTensorType:
       ir_type = _mlir_type_from_taco_type(self.dst_dtype)
       tensor = linalg.InitTensorOp(self.dst_dims, ir_type).result
       zero = arith.ConstantOp(ir_type, 0.0)
-      return linalg.FillOp(output=tensor, value=zero).results[0]
+      return linalg.fill(zero, outs=[tensor])
 
     # Initialize the sparse tensor.
     mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,

diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir
index 24263021aba31..0e820ed47357f 100644
--- a/mlir/test/mlir-cpu-runner/async.mlir
+++ b/mlir/test/mlir-cpu-runner/async.mlir
@@ -19,7 +19,7 @@ func @main() {
   %c4 = arith.constant 4.0 : f32
 
   %A = memref.alloc() : memref<4xf32>
-  linalg.fill(%c0, %A) : f32, memref<4xf32>
+  linalg.fill ins(%c0 : f32) outs(%A : memref<4xf32>)
 
   // CHECK: [0, 0, 0, 0]
   %U = memref.cast %A :  memref<4xf32> to memref<*xf32>

diff --git a/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir b/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
index 81c6e204a684b..0fff711fe97e3 100644
--- a/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
+++ b/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
@@ -7,14 +7,14 @@ func @main() {
 
   %cf1 = arith.constant 1.00000e+00 : f32
 
-  linalg.fill(%cf1, %A) : f32, memref<16x16xf32>
-  linalg.fill(%cf1, %B) : f32, memref<16x16xf32>
+  linalg.fill ins(%cf1 : f32) outs(%A : memref<16x16xf32>)
+  linalg.fill ins(%cf1 : f32) outs(%B : memref<16x16xf32>)
 
   %reps = arith.constant 1 : index
 
   %t_start = call @rtclock() : () -> f64
   affine.for %arg0 = 0 to 5 {
-    linalg.fill(%cf1, %C) : f32, memref<16x16xf32>
+    linalg.fill ins(%cf1 : f32) outs(%C : memref<16x16xf32>)
     call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
   }
   %t_end = call @rtclock() : () -> f64

diff --git a/mlir/test/mlir-cpu-runner/unranked-memref.mlir b/mlir/test/mlir-cpu-runner/unranked-memref.mlir
index 2e2062cdde632..91163511ac9f2 100644
--- a/mlir/test/mlir-cpu-runner/unranked-memref.mlir
+++ b/mlir/test/mlir-cpu-runner/unranked-memref.mlir
@@ -42,18 +42,18 @@ func @main() -> () {
     %f10 = arith.constant 10.00000e+00 : f32
 
     %V = memref.cast %A : memref<10x3xf32, 0> to memref<?x?xf32>
-    linalg.fill(%f10, %V) : f32, memref<?x?xf32, 0>
+    linalg.fill ins(%f10 : f32) outs(%V : memref<?x?xf32, 0>)
     %U = memref.cast %A : memref<10x3xf32, 0> to memref<*xf32>
     call @print_memref_f32(%U) : (memref<*xf32>) -> ()
 
     %V2 = memref.cast %U : memref<*xf32> to memref<?x?xf32>
-    linalg.fill(%f5, %V2) : f32, memref<?x?xf32, 0>
+    linalg.fill ins(%f5 : f32) outs(%V2 : memref<?x?xf32, 0>)
     %U2 = memref.cast %V2 : memref<?x?xf32, 0> to memref<*xf32>
     call @print_memref_f32(%U2) : (memref<*xf32>) -> ()
 
     %V3 = memref.cast %V2 : memref<?x?xf32> to memref<*xf32>
     %V4 = memref.cast %V3 : memref<*xf32> to memref<?x?xf32>
-    linalg.fill(%f2, %V4) : f32, memref<?x?xf32, 0>
+    linalg.fill ins(%f2 : f32) outs(%V4 : memref<?x?xf32, 0>)
     %U3 = memref.cast %V2 : memref<?x?xf32> to memref<*xf32>
     call @print_memref_f32(%U3) : (memref<*xf32>) -> ()
 
@@ -79,7 +79,7 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface
 func @return_two_var_memref_caller() {
   %0 = memref.alloca() : memref<4x3xf32>
   %c0f32 = arith.constant 1.0 : f32
-  linalg.fill(%c0f32, %0) : f32, memref<4x3xf32>
+  linalg.fill ins(%c0f32 : f32) outs(%0 : memref<4x3xf32>)
   %1:2 = call @return_two_var_memref(%0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>)
   call @print_memref_f32(%1#0) : (memref<*xf32>) -> ()
   call @print_memref_f32(%1#1) : (memref<*xf32>) -> ()
@@ -94,7 +94,7 @@ func @return_two_var_memref_caller() {
 func @return_var_memref_caller() {
   %0 = memref.alloca() : memref<4x3xf32>
   %c0f32 = arith.constant 1.0 : f32
-  linalg.fill(%c0f32, %0) : f32, memref<4x3xf32>
+  linalg.fill ins(%c0f32 : f32) outs(%0 : memref<4x3xf32>)
   %1 = call @return_var_memref(%0) : (memref<4x3xf32>) -> memref<*xf32>
   call @print_memref_f32(%1) : (memref<*xf32>) -> ()
   return

diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
index 461d0578a90ad..a91abf1ddbf11 100644
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -19,7 +19,7 @@ func @print_1d() {
   %f = arith.constant 2.00000e+00 : f32
   %A = memref.alloc() : memref<16xf32>
   %B = memref.cast %A: memref<16xf32> to memref<?xf32>
-  linalg.fill(%f, %B) : f32, memref<?xf32>
+  linalg.fill ins(%f : f32) outs(%B : memref<?xf32>)
   %U = memref.cast %B :  memref<?xf32> to memref<*xf32>
   call @print_memref_f32(%U): (memref<*xf32>) -> ()
   memref.dealloc %A : memref<16xf32>
@@ -33,7 +33,7 @@ func @print_3d() {
   %f4 = arith.constant 4.00000e+00 : f32
   %A = memref.alloc() : memref<3x4x5xf32>
   %B = memref.cast %A: memref<3x4x5xf32> to memref<?x?x?xf32>
-  linalg.fill(%f, %B) : f32, memref<?x?x?xf32>
+  linalg.fill ins(%f : f32) outs(%B : memref<?x?x?xf32>)
 
   %c2 = arith.constant 2 : index
   memref.store %f4, %B[%c2, %c2, %c2]: memref<?x?x?xf32>

diff --git a/mlir/test/mlir-opt/async.mlir b/mlir/test/mlir-opt/async.mlir
index 5b0c11ca19138..af06c397a6381 100644
--- a/mlir/test/mlir-opt/async.mlir
+++ b/mlir/test/mlir-opt/async.mlir
@@ -20,7 +20,7 @@ func @main() {
   %c4 = arith.constant 4.0 : f32
 
   %A = memref.alloc() : memref<4xf32>
-  linalg.fill(%c0, %A) : f32, memref<4xf32>
+  linalg.fill ins(%c0 : f32) outs(%A : memref<4xf32>)
 
   %U = memref.cast %A :  memref<4xf32> to memref<*xf32>
   call @print_memref_f32(%U): (memref<*xf32>) -> ()

diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 6f912c8bb129a..494213ef2777f 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -65,22 +65,22 @@ def testFill():
       # CHECK-LABEL: func @fill_tensor
       #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
       #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
-      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[CST]], %[[OUT]]) : f32, tensor<12x?xf32> -> tensor<12x?xf32>
+      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
       #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
       @builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
       def fill_tensor(out):
         zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        return linalg.FillOp(output=out, value=zero).result
+        return linalg.fill(zero, outs=[out])
 
       # CHECK-LABEL: func @fill_buffer
       #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
       #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
-      #  CHECK-NEXT: linalg.fill(%[[CST]], %[[OUT]]) : f32, memref<12x?xf32>
+      #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
       #  CHECK-NEXT: return
       @builtin.FuncOp.from_py_func(MemRefType.get((12, -1), f32))
       def fill_buffer(out):
         zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        linalg.FillOp(output=out, value=zero)
+        linalg.fill(zero, outs=[out])
 
   print(module)
 
@@ -179,9 +179,9 @@ def testOpResultFromOtherOp():
       def pass_an_op_directly(arg0, arg1):
         one = arith.ConstantOp(F32Type.get(), 1.0)
         # CHECK: %[[LHS:.*]] = linalg.fill
-        lhs = linalg.FillOp(arg0, one)
+        lhs = linalg.fill(one, outs=[arg0])
         # CHECK: %[[RHS:.*]] = linalg.fill
-        rhs = linalg.FillOp(arg1, one)
+        rhs = linalg.fill(one, outs=[arg1])
         # CHECK: %[[INIT:.*]] = linalg.init_tensor
         init = linalg.InitTensorOp([4, 8], f32)
         # CHECK: linalg.matmul

diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 279af2bf77b34..f58292ee7241d 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -29,10 +29,10 @@ def log(*args):
   %rhs = memref.alloc() : memref<4x8xf32>
   %O0 = memref.alloc() : memref<4x8xf32>
   %O1 = memref.alloc() : memref<4x8xf32>
-  linalg.fill(%v1, %lhs) : f32, memref<f32>
-  linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
-  linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
-  linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
+  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
+  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
+  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
+  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)
 
   call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
     (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
@@ -60,10 +60,10 @@ def log(*args):
   %B = memref.alloc() : memref<16x8xf32>
   %C0 = memref.alloc() : memref<4x8xf32>
   %C1 = memref.alloc() : memref<4x8xf32>
-  linalg.fill(%v1, %A) : i8, memref<4x16xi8>
-  linalg.fill(%v2, %B) : f32, memref<16x8xf32>
-  linalg.fill(%v0, %C0) : f32, memref<4x8xf32>
-  linalg.fill(%v0, %C1) : f32, memref<4x8xf32>
+  linalg.fill ins(%v1 : i8) outs(%A : memref<4x16xi8>)
+  linalg.fill ins(%v2 : f32) outs(%B : memref<16x8xf32>)
+  linalg.fill ins(%v0 : f32) outs(%C0 : memref<4x8xf32>)
+  linalg.fill ins(%v0 : f32) outs(%C1 : memref<4x8xf32>)
 
   call @matmul_signed_on_buffers(%A, %B, %C0) :
     (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
@@ -137,9 +137,9 @@ def log(*args):
   %input = memref.alloc() : memref<1x4x16x1xf64>
   %filter = memref.alloc() : memref<2x2x1xf64>
   %output = memref.alloc() : memref<1x2x4x1xi32>
-  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
-  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
-  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
+  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
+  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)
 
   call @conv_on_buffers(%input, %filter, %output) :
     (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()
@@ -163,9 +163,9 @@ def log(*args):
   %input = memref.alloc() : memref<1x4x16x1xf64>
   %shape = memref.alloc() : memref<2x2xf64>
   %output = memref.alloc() : memref<1x2x4x1xi32>
-  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
-  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
-  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
+  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
+  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)
 
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -368,15 +368,15 @@ def test_fill_builtin():
 
       @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
       def fill_0d_on_buffers(value, out):
-        linalg.fill_tensor(value, outs=[out])
+        linalg.fill(value, outs=[out])
 
       @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
       def fill_1d_on_buffers(value, out):
-        linalg.fill_tensor(value, outs=[out])
+        linalg.fill(value, outs=[out])
 
       @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
       def fill_2d_on_buffers(value, out):
-        linalg.fill_tensor(value, outs=[out])
+        linalg.fill(value, outs=[out])
 
     execution_engine = ExecutionEngine(transform(module, fill_boiler))
 
@@ -403,15 +403,15 @@ def test_fill_generic():
 
       @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
       def fill_0d_on_buffers(value, out):
-        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+        linalg.fill(value, outs=[out], emit_generic=True)
 
       @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
       def fill_1d_on_buffers(value, out):
-        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+        linalg.fill(value, outs=[out], emit_generic=True)
 
       @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
       def fill_2d_on_buffers(value, out):
-        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+        linalg.fill(value, outs=[out], emit_generic=True)
 
     execution_engine = ExecutionEngine(transform(module, fill_boiler))
 

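For reference, here is a minimal standalone sketch of the updated Python builder API exercised by the tests above. This snippet is not part of the patch; it assumes the in-tree `mlir` Python bindings are importable, mirrors the pattern in `mlir/test/python/dialects/linalg/ops.py`, and uses `fill_tensor` as a purely illustrative function name:

```python
# Sketch (not part of this patch): build a func that fills a tensor with zero
# via the OpDSL-generated linalg.fill builder.
from mlir.ir import (Context, Location, Module, InsertionPoint, F32Type,
                     FloatAttr, RankedTensorType)
from mlir.dialects import arith, builtin, linalg

with Context() as ctx, Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 8), f32))
    def fill_tensor(out):
      zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32).result
      # linalg.FillOp(output=out, value=zero) is gone; the OpDSL entry point
      # takes the fill value first and the destinations via outs=[...].
      return linalg.fill(zero, outs=[out])

  print(module)
```

As in the updated `opsrun.py` tests, passing `emit_generic=True` to `linalg.fill` emits the equivalent `linalg.generic` form instead of the named op.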