[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Rolf Morel llvmlistbot at llvm.org
Wed Jan 29 09:20:01 PST 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/123618

>From 09263b480862dcb0fc0d54824bb18734b4bf9d53 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 19 Jan 2025 13:42:59 -0800
Subject: [PATCH 01/11] [MLIR][Linalg] Introduce linalg.contract

A new op that allows for representing arbitrary contractions on operands
of arbitrary rank, with arbitrary transposes and arbitrary broadcasts
specified through its indexing_maps attribute.

Supports the expected lowerings to linalg.generic and to vector.contract.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 118 +++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 224 ++++++++++++++++--
 .../Dialect/Linalg/generalize-named-ops.mlir  | 168 ++++++++++++-
 .../generalize-named-polymorphic-ops.mlir     |  50 ++++
 mlir/test/Dialect/Linalg/invalid.mlir         |  97 ++++++++
 mlir/test/Dialect/Linalg/loops.mlir           |  47 ++++
 mlir/test/Dialect/Linalg/named-ops.mlir       |  23 +-
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  17 +-
 mlir/test/Dialect/Linalg/tile-tensors.mlir    |  46 ++++
 .../Linalg/transform-op-vectorize.mlir        |  33 +++
 10 files changed, 789 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating on top of a third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+    dimension identifiers (meant to range over valid indices), corresponding to
+    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+    the dimensions in the set `dims`.
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+      These dims are contracted, i.e. reduced over; at least one is required.
+
+    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+      deriving from matmul terminology - is either an "M-like" dim (if in `A`
+      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+      `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` are of parallel iteration-type) and gets represented as:
+
+    ```
+    %0 = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting the dims in the co-domains of the `affine_map`s, we
+    can apply arbitrary transposes to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved through leaving out dims on either
+    input operand.
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting them to the same data type as the accumulator/output.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, regionBuilder);
+      }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// IndexingMaps always depends on attr associated to current Op instance.
+    bool hasDynamicIndexingMaps() { return true; }
+    bool hasUserDefinedMaps() { return true; }
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 71bae4666d6619..7b0a14d8f15b07 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3531,44 +3531,45 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
 }
 
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<Attribute, 3> indexingMapsAttr;
-  Attribute mapAttr;
-  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
-    if (parser.parseEqual())
-      return failure();
+/// Parses optional `indexing_maps = [...]`; a null ArrayAttr if it is absent.
+static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+  if (parser.parseOptionalKeyword("indexing_maps"))
+    return ArrayAttr(); // Success in case indexing_maps was not provided.

-    if (parser.parseLSquare())
+  SmallVector<Attribute> indexingMaps;
+  auto parseIndexingMap = [&]() -> ParseResult {
+    AffineMapAttr affineMapAttr;
+    if (parser.parseAttribute(affineMapAttr))
       return failure();
+    indexingMaps.push_back(affineMapAttr);
+    return success();
+  };

-    do {
-      if (parser.parseAttribute(mapAttr))
-        return failure();
-      if (!isa<AffineMapAttr>(mapAttr)) {
-        return parser.emitError(parser.getCurrentLocation(),
-                                "expected affine map attribute");
-      }
-      indexingMapsAttr.push_back(mapAttr);
+  if (parser.parseEqual() ||
+      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+                                     parseIndexingMap))
+    return failure();

-      if (parser.parseOptionalComma())
-        break;
-    } while (true);
+  return parser.getBuilder().getArrayAttr(indexingMaps);
+}

-    if (parser.parseRSquare())
-      return failure();
-  }
-  // Initialize indexingMaps, if not supplied explicitly.
-  if (indexingMapsAttr.empty()) {
-    indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr))
+    return failure();
+  // Without explicit indexing_maps, fall back to getDefaultIndexingMaps().
+  if (*indexingMapsAttr == nullptr) {
+    auto indexingMapAttrs = llvm::map_to_vector(
         MatmulOp::getDefaultIndexingMaps(parser.getContext()),
         [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
   }
-  result.addAttribute("indexing_maps",
-                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

+  result.addAttribute("indexing_maps", *indexingMapsAttr);
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
+
 void MatmulOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 3> elidedAttrs = {
       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3602,6 +3603,7 @@ LogicalResult MatmulOp::verify() {
 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
 }
+
 void MatmulOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -3614,5 +3616,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  // On well-formed IR (see ContractOp::verify), indexing_maps is non-empty,
+  // all affine_maps share one domain, and each is a projected permutation.
+  // Every dim in the domain occurs for at least one operand and is batch-,
+  // M-, N- or K-like. Only K-like dims are reductions, and they are the only
+  // kind absent from the output's (i.e. the last) map, so a dim is parallel
+  // iff it occurs in the output map and is a reduction otherwise.
+  // NB: Were dims allowed to occur solely for one input, this would still
+  //     hold: per einsum semantics, such dims are reductions as well.
+  auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+  for (auto result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  for (auto dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+
+  return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; } // lhs, rhs and out.
+
+/// Builds `yield out + cast(lhs) * cast(rhs)` for 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castSignedness = TypeFn::cast_signed; // Overridden by a "cast" attr.
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = attr.getValue();
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Type outType = block.getArgument(2).getType();
+  Value lhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+  Value rhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      productAtOutType);
+  helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr)) // The helper has already emitted an error.
+    return failure();
+  if (*indexingMapsAttr == nullptr) // Unlike matmul, no default maps exist.
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_map' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         {"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  // zip() below silently truncates to its shortest range, so reject operand or
+  // map counts other than 2 inputs, 1 output and 3 maps (e.g. from builders).
+  if (getInputs().size() != 2 || getOutputs().size() != 1 ||
+      getIndexingMaps().size() != 3)
+    return emitError("expected 2 inputs, 1 output and 3 indexing_maps");
+
+  int iterationSpaceDims = -1;
+  // Maps iter space dim (as index) to num of occurrences in inputs and output.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+
+  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+                                   bool isInput) -> LogicalResult {
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
+
+    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+      if (affineMap.getNumResults() != shapedType.getRank())
+        return emitError("ranks of shaped operand and co-domain of "
+                         "corresponding affine_map differ");
+    } else if (affineMap.getNumResults() != 0) {
+      return emitError("affine_map specifies shaped access while operand has "
+                       "non-shaped type");
+    }
+
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
+
+    // Projected permutation => each result is a plain dim, so cast<> is safe.
+    auto &occurrences = isInput ? inOccurrences : outOccurrences;
+    for (AffineExpr affineExpr : affineMap.getResults())
+      occurrences[cast<AffineDimExpr>(affineExpr).getPosition()] += 1;
+
+    return success();
+  };
+
+  for (auto &&[affineMap, operandType, isInput] :
+       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+                 SmallVector<bool>{true, true, false}))
+    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+      return failure(); // NOTE: checking lambda will emit error.
+
+  bool hasContractingDim = false;
+  for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+    if (inOccCount == 0)
+      return emitError("iteration space dim not used by either input");
+
+    // NB: A dim which occurs for only one input operand and not for the output.
+    //     In terms of einsum semantics, such dims have a sensible meaning -
+    //     namely an additional reduction per such dim - though this can also
+    //     always be expressed through an additional op. Additionally, at time
+    //     of writing, vector.contract's verifier accepts these dims but many of
+    //     its lowerings do not handle these kinds of dims. Hence...
+    // TODO: Remove following once we have comprehensive support for input-only
+    //       reduction dims, at both the linalg- and vector-dialect levels.
+    if (inOccCount == 1 && outOccCount != 1)
+      return emitError("iter type of dim is not one of M, N, K or batch");
+  }
+
+  if (!hasContractingDim)
+    return emitError("'indexing_maps' do not specify a contracting dimension");
+
+  return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this); // Same folding as MatmulOp::fold.
+}
+
+void ContractOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return; // Pure tensor semantics: no memory effects to report.
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..f7e570d5ce38f0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -996,9 +994,173 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Plain matmul: d2 is in both inputs but not the output, so it is reduced.
+func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Both inputs transposed: A is indexed as (k, m) and B as (n, k).
+func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>,
+                    affine_map<(d0, d1, d2) -> (d1, d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @contract_batch_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<9x3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Batch dim d0 occurs in all three operands, so it stays parallel.
+func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+                  outs(%arg2: memref<9x3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @contract_batch_reduce_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Batch dim d0 is absent from the output, so it is reduced alongside d3.
+func.func @contract_batch_reduce_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Operand dims permuted: A is indexed as (b, k, m) and B as (b, n, k).
+func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(%arg0: memref<9x5x3xf32>, %arg1: memref<9x7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<9x5x3xf32>, memref<9x7x5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL:   func.func @contract_dot
+// CHECK-SAME:                           (%[[VAL_0:.*]]: memref<9xf32>, %[[VAL_1:.*]]: memref<9xf32>, %[[VAL_2:.*]]: memref<f32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Rank-0 output: the single dim d0 is contracted, i.e. a dot product.
+func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memref<f32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0) -> (d0)>,
+                    affine_map<(d0) -> (d0)>,
+                    affine_map<(d0) -> ()>
+                  ]
+                  ins(%arg0, %arg1 : memref<9xf32>, memref<9xf32>)
+                  outs(%arg2: memref<f32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index c170c5be4abff9..9acb7562f96ee0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -120,6 +120,56 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
 
 // -----
 
+func.func @generalize_matmul_as_contraction_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(d0, d1, d2) -> (d0, d2)>,
+                         affine_map<(d0, d1, d2) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>
+                       ]
+                       ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+                       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// Verify the f16 operand is extended and the f64 operand truncated to f32.
+// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func.func @generalize_matmul_as_contract_with_ext_and_trunc(%arg0: tensor<24x12xf16>,
+                            %arg1: tensor<12x25xf16>,
+                            %arg2: tensor<24x25xf32>) -> tensor<24x25xf16> {
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(m, n, k) -> (m, k)>,
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>
+                       ]
+                       ins(%arg0, %arg1 : tensor<24x12xf16>, tensor<12x25xf16>)
+                       outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
+  func.return %1 : tensor<24x25xf16>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// Verify operand extension; the truncation is applied to the result below.
+// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<24x25xf32>
+// CHECK-NEXT:   %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
+
+// -----
+
 func.func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 0853856d933035..d6c8f38586b923 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -547,6 +547,103 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
 
 // -----
 
+func.func @invalid_indexing_maps_placement_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{custom op 'linalg.contract' expected 'indexing_map' attribute}}
+  linalg.contract ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>)
+                  indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+  return
+}
+
+// -----
+
+func.func @invalid_affine_map_in_indexing_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{provided affine_map is not a projected permutation}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @differing_iteration_space_of_affine_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration spaces of provided affine_maps differ}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @mismatched_ranks_affine_map_and_operand_contraction(%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+// -----
+
+func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{affine_map specifies shaped access while operand has non-shaped type}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1) -> (d0)>,
+                    affine_map<(d0, d1) -> (d0, d1)>,
+                    affine_map<(d0, d1) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : f32, tensor<4x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration space dim not used by either input}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iter type of dim is not one of M, N, K or batch}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
 func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
   // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6286a11c11a21f..0a83750b81dea4 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -627,6 +627,53 @@ func.func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f
 //----------------------------------------------------------------------------//
 // Named ops to loops.
 //----------------------------------------------------------------------------//
+func.func @batch_reduce_matmul_as_contract(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                  ]
+                  ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                  outs(%C : memref<?x?xf32>)
+  return
+}
+// CHECK-LABEL: @batch_reduce_matmul_as_contract
+//  CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECK: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+//       CHECK: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+//       CHECK: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+//       CHECK: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+//       CHECK: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+//       CHECK:   scf.for %[[m:.*]] = %{{.*}} to %[[M]]
+//       CHECK:     scf.for %[[n:.*]] = %{{.*}} to %[[N]]
+//       CHECK:       scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+//       CHECK:       %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+//       CHECK:       %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+//       CHECK:       %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//       CHECK:       store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @batch_reduce_matmul_as_contract
+//  CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECKPARALLEL: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+//       CHECKPARALLEL:   scf.parallel (%[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[M]], %[[N]]) step ({{.*}}) {
+//       CHECKPARALLEL:     scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+//       CHECKPARALLEL:         %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+//       CHECKPARALLEL:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:         store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
 func.func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
   linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%C : memref<?x?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 68aa5a85b5e0e6..3c68e7d394642e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1245,7 +1245,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -1259,7 +1258,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -1509,6 +1507,27 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func @contract
+//       CHECK:   linalg.contract
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
+                  outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @mmt4d
 func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.mmt4d
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..99cbb6647effbe 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -277,22 +277,33 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
 
 // -----
 
-
 func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
-  -> (tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 {
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%c3: memref<?x?x?xf32>)
+  linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                                   affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                                   affine_map<(batch, m, n, k) -> (batch, m, n)>]
+                  ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                  outs(%c3: memref<?x?x?xf32>)
   %res1 = linalg.batch_matmul
                       ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                      outs(%tc3: tensor<?x?x?xf32>)
                   -> tensor<?x?x?xf32>
-  return %res1 : tensor<?x?x?xf32>
+  %res2 = linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                                           affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                                           affine_map<(batch, m, n, k) -> (batch, m, n)>]
+                          ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+                          outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @named_ops
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.contract
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.contract
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 8f13c690704572..1de1863d6deb13 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -37,6 +37,52 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK:       #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-NEXT:  #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-NEXT:  #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul_as_contract_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @matmul_as_contract_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.contract indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:                                  outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(d0, d1, d2) -> (d0, d2)>,
+                         affine_map<(d0, d1, d2) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>
+                       ]
+                       ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                       outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+//      CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 3, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors_with_size_zeros(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d4..2d30d62039642a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -54,6 +54,39 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @vectorize_matmul_as_contract
+// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
+func.func @vectorize_matmul_as_contract(%arg0: tensor<24x12xf32>,
+                            %arg1: tensor<12x25xf32>,
+                            %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
+  // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
+  // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(m, n, k) -> (m, k)>,
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>
+                       ]
+                       ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>)
+                       outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @vectorize_copy_memref
 // CHECK-SAME: %[[A:.*]]: memref<100x100xf32>,
 // CHECK-SAME: %[[B:.*]]: memref<100x100xf32>

>From 08766651df532fa185b1b2eedce845687b365715 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 20 Jan 2025 15:46:48 -0800
Subject: [PATCH 02/11] Remove extraneous line that allowed adding a region

---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4277bd34f3946..8c9a4e56ce00a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -742,7 +742,6 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     AffineMapArrayAttr:$indexing_maps
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
-  let regions = (region SizedRegion<1>:$combiner);
 
   let skipDefaultBuilders = 1;
   let builders = [

>From b497f548cf6be70b12cddafa56316568e30da39a Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 20 Jan 2025 16:14:07 -0800
Subject: [PATCH 03/11] Revert "Remove extraneous line that allowed adding a
 region"

This reverts commit 3d0d5b307d7326929e5a8da0942c51f4f133543e.

Actually, that line is needed...
---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8c9a4e56ce00a0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -742,6 +742,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     AffineMapArrayAttr:$indexing_maps
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
+  let regions = (region SizedRegion<1>:$combiner);
 
   let skipDefaultBuilders = 1;
   let builders = [

>From e4df55d482930cf278d943d53b0b4a40266d049d Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 21 Jan 2025 16:12:37 -0800
Subject: [PATCH 04/11] Address @adam-smnk's comments

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 27 ++++++++++++-------
 mlir/test/Dialect/Linalg/invalid.mlir         |  4 +--
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4277bd34f3946..d4b3cd9172b6c3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -769,6 +769,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare/implement functions necessary for LinalgStructuredInterface.
+
     /// Infer iterator types for each dim in the domain of IndexingMaps.
     SmallVector<utils::IteratorType> getIteratorTypesArray();
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7b0a14d8f15b07..f035841f47ce80 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3746,22 +3746,29 @@ LogicalResult ContractOp::verify() {
       return failure(); // NOTE: checking lambda will emit error.
 
   bool hasContractingDim = false;
-  for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
+    size_t inOccCount = inOccurrences[dimIndex];
+    size_t outOccCount = outOccurrences[dimIndex];
+
     hasContractingDim |= inOccCount == 2 && outOccCount == 0;
 
     if (inOccCount == 0)
-      return emitError("iteration space dim not used by either input");
-
-    // NB: A dim which occurs for only one input operand and not for the output.
-    //     In terms of einsum semantics, such dims have a sensible meaning -
-    //     namely an additional reduction per such dim - though this can also
-    //     always be expressed through an additional op. Additionally, at time
-    //     of writing, vector.contract's verifier accepts these dims but many of
-    //     its lowerings do not handle these kinds of dims. Hence...
+      return emitError() << "iteration space dim at index " << dimIndex
+                         << " not used by either input";
+
+    // NB: We disallow a dim which occurs for only one input operand and not
+    //     for the output. In terms of einsum semantics such dims have a
+    //     sensible meaning - namely an additional reduction per each such dim.
+    //     By contrast, the ContractionOpInterface does not know about this
+    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
+    //     while vector.contract's verifier accepts dims of this kind many of
+    //     its lowerings give up on encountering these dims.
     // TODO: Remove following once we have comprehensive support for input-only
     //       reduction dims, at both the linalg- and vector-dialect levels.
     if (inOccCount == 1 && outOccCount != 1)
-      return emitError("iter type of dim is not one of M, N, K or batch");
+      return emitError()
+             << "iteration space dim at index " << dimIndex
+             << " is neither a contracting dim nor of parallel iteration type";
   }
 
   if (!hasContractingDim)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d6c8f38586b923..6f4b1c0cf045da 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,7 +617,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
 // -----
 
 func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{iteration space dim not used by either input}}
+  // expected-error @+1 {{iteration space dim at index 3 not used by either input}}
   linalg.contract indexing_maps = [
                     affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
                     affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
@@ -631,7 +631,7 @@ func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: t
 // -----
 
 func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{iter type of dim is not one of M, N, K or batch}}
+  // expected-error @+1 {{iteration space dim at index 3 is neither a contracting dim nor of parallel iteration type}}
   linalg.contract indexing_maps = [
                     affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d2, d1)>,

>From d41bbe7c4857c14cd476a6e10a26ab6cee6c3d58 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 23 Jan 2025 12:29:25 -0800
Subject: [PATCH 05/11] Address Adam's comments, round 2

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 21 ++--
 .../Dialect/Linalg/generalize-named-ops.mlir  | 25 +++++
 mlir/test/Dialect/Linalg/invalid.mlir         |  2 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       | 99 +++++++++++++++++++
 4 files changed, 136 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f035841f47ce80..87c5c0d4df8998 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3622,15 +3622,15 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
 
 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
   AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
-  /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
-  /// domains are all the same, and each implements a projected permutation.
-  /// Each dim in the domain must occur for at least one operand and is
-  /// classified as either batch, N-like, M-like, or K-like. Only the latter
-  /// corresponds to a reduction _and_ it is the only dim-kind which does not
-  /// occur for the output operand. We use this fact for fast inference:
+  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  // domains are all the same, and each implements a projected permutation.
+  // Each iteration space dim must occur for at least one operand and either
+  // takes part in a contraction/reduction or else has parallel iteration type.
+  // A dim is a contraction/reduction dim if and only if it does not occur for
+  // the output operand. We use this fact for fast inference:
   // NB: In case we allow dims to occur solely for one input, the above still
   //     holds: per the einsum semantics, these are reduction dims as well.
-  auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
   for (auto result : outAffineMap.getResults()) {
     auto dimExpr = dyn_cast<AffineDimExpr>(result);
     assert(dimExpr && "affine_map is a projected permutation");
@@ -3741,9 +3741,10 @@ LogicalResult ContractOp::verify() {
 
   for (auto &&[affineMap, operandType, isInput] :
        llvm::zip(getIndexingMapsArray(), getOperandTypes(),
-                 SmallVector<bool>{true, true, false}))
+                 SmallVector<bool>{true, true, false})) {
     if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
       return failure(); // NOTE: checking lambda will emit error.
+  }
 
   bool hasContractingDim = false;
   for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
@@ -3752,9 +3753,9 @@ LogicalResult ContractOp::verify() {
 
     hasContractingDim |= inOccCount == 2 && outOccCount == 0;
 
-    if (inOccCount == 0)
+    if (inOccCount == 0 && outOccCount == 0)
       return emitError() << "iteration space dim at index " << dimIndex
-                         << " not used by either input";
+                         << " not used to access any operand";
 
     // NB: We disallow a dim which occurs for only one input operand and not
     //     for the output. In terms of einsum semantics such dims have a
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index f7e570d5ce38f0..3b21467ca45fef 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1164,3 +1164,28 @@ func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memre
                   outs(%arg2: memref<f32>)
   return
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Both inputs index only d2 (reduction); d0/d1 are output-only broadcast dims.
+// CHECK-LABEL:   func.func @contract_matmul_bcast_a_b
+// CHECK-SAME:                           (%[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2)>,
+                    affine_map<(d0, d1, d2) -> (d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 6f4b1c0cf045da..eeefb0af9aecef 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,7 +617,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
 // -----
 
 func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{iteration space dim at index 3 not used by either input}}
+  // expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
   linalg.contract indexing_maps = [
                     affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
                     affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 3c68e7d394642e..6defa827b77c48 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1528,6 +1528,105 @@ func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: m
 
 // -----
 
+func.func @contract_matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_a
+//       CHECK:   linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_b
+//       CHECK:   linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2)>,
+                    affine_map<(d0, d1, d2) -> (d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @contract_matmul_bcast_a_b
+//       CHECK:     linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+//       CHECK:       ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
+//       CHECK:       outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                  affine_map<(d0, d1, d2) -> (d2)>,
+                  affine_map<(d0, d1, d2) -> (d1, d2)>,
+                  affine_map<(d0, d1, d2) -> (d0, d1)>
+                ]
+                ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @contract_matmul_bcast_a_transpose_b
+//       CHECK:     linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//       CHECK:       ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
+//       CHECK:       outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @contract_matmul_bcast_b_transpose_a
+//       CHECK:     linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//       CHECK:       ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
+//       CHECK:       outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
 // CHECK-LABEL: func @mmt4d
 func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.mmt4d

>From a676cf8bb43f35d52fbe5d6b799eb86fdb253525 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 24 Jan 2025 00:34:43 -0800
Subject: [PATCH 06/11] Fix up error message in test case

---
 mlir/test/Dialect/Linalg/invalid.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index eeefb0af9aecef..0f81f4f723b938 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,7 +617,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
 // -----
 
 func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{iteration space dim at index 3 not used for any operand}}
+  // expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
   linalg.contract indexing_maps = [
                     affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
                     affine_map<(d0, d1, d2, d3) -> (d2, d1)>,

>From 5f37e769870ca5f5e47cc05568da7aefd21f4705 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 27 Jan 2025 14:33:42 -0800
Subject: [PATCH 07/11] Address @banach-space's comments

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |   2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  24 +-
 .../Dialect/Linalg/generalize-named-ops.mlir  | 241 +++++++++---------
 .../generalize-named-polymorphic-ops.mlir     |  74 +++---
 mlir/test/Dialect/Linalg/invalid.mlir         | 114 ++++-----
 mlir/test/Dialect/Linalg/loops.mlir           |  16 +-
 mlir/test/Dialect/Linalg/named-ops.mlir       | 163 ++++++------
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  21 +-
 mlir/test/Dialect/Linalg/tile-tensors.mlir    |  21 +-
 .../Linalg/transform-op-vectorize.mlir        |  33 ---
 .../Linalg/vectorization-with-patterns.mlir   |  32 +++
 11 files changed, 378 insertions(+), 363 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4b3cd9172b6c3..5bc04ac4ceb3eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -688,7 +688,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
                AttrSizedOperandSegments,
                LinalgContractionOpInterface]> {
   let summary = [{
-    Perform a contraction on two inputs, accumulating on top of a third.
+    Perform a contraction on two inputs, accumulating into the third.
   }];
   let description = [{
     The semantics of contracting inputs `A` and `B` on top of `C` to produce
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 87c5c0d4df8998..c4edea2d8c56a4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3535,22 +3535,16 @@ FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
   if (parser.parseOptionalKeyword("indexing_maps"))
     return {nullptr}; // Success in case indexing_maps was not provided.
 
-  SmallVector<Attribute> indexingMaps;
-
-  auto parseIndexingMap = [&]() -> ParseResult {
-    AffineMapAttr affineMapAttr;
-    if (parser.parseAttribute(affineMapAttr))
-      return failure();
-    indexingMaps.push_back(affineMapAttr);
-    return success();
-  };
-
-  if (parser.parseEqual() ||
-      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
-                                     parseIndexingMap))
+  ArrayAttr arrayAttr;
+  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
     return failure();
 
-  return parser.getBuilder().getArrayAttr(indexingMaps);
+  if (llvm::any_of(arrayAttr,
+                   [](auto elt) { return !isa<AffineMapAttr>(elt); }))
+    return parser.emitError(parser.getCurrentLocation())
+           << "element of indexing_maps array is not an affine_map";
+
+  return arrayAttr;
 }
 
 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -3680,7 +3674,7 @@ ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
   FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
   if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
     return parser.emitError(parser.getCurrentLocation(),
-                            "expected 'indexing_map' attribute");
+                            "expected 'indexing_maps' attribute");
   result.addAttribute("indexing_maps", *indexingMapsAttr);
 
   return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 3b21467ca45fef..7ee25ac25ac6c9 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -999,193 +999,206 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-LABEL:   func.func @contract_matmul(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<3x5xf32>,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<5x7xf32>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
-
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
 func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d0, d2)>,
-                    affine_map<(d0, d1, d2) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
-                  outs(%arg2: memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+      outs(%arg2: memref<3x7xf32>)
 
   return
 }
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK-LABEL:   func.func @contract_matmul_transpose_a_b(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME:      %[[A:.*]]: memref<5x3xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<7x5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
 
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
 func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d2, d0)>,
-                    affine_map<(d0, d1, d2) -> (d1, d2)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
-                  outs(%arg2: memref<3x7xf32>)
-
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+      outs(%arg2: memref<3x7xf32>)
   return
 }
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
-// CHECK-LABEL:   func.func @contract_batch_matmul(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x3x5xf32>,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x5x7xf32>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<9x3x7xf32>) {
+// CHECK-LABEL: func.func @contract_batch_matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<9x3x7xf32>) {
 
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
 func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                  ]
-                  ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
-                  outs(%arg2: memref<9x3x7xf32>)
-
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+      ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+      outs(%arg2: memref<9x3x7xf32>)
   return
 }
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @contract_batch_reduce_matmul(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x3x5xf32>,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x5x7xf32>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
 
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["reduction", "parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_batch_reduce_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                  ]
-                  ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
-                  outs(%arg2: memref<3x7xf32>)
+#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @contract_batch_reduce_matmul(
+    %A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [#accessA, #accessB, #accessC]
+      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
-// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x5x3xf32>,
-// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x7x5xf32>,
-// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+// CHECK-SAME:      %[[A:.*]]: memref<9x5x3xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<9x7x5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
 
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["reduction", "parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(%arg0: memref<9x5x3xf32>, %arg1: memref<9x7x5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
-                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                  ]
-                  ins(%arg0, %arg1 : memref<9x5x3xf32>, memref<9x7x5xf32>)
-                  outs(%arg2: memref<3x7xf32>)
+#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+    %A: memref<9x5x3xf32>, %B: memref<9x7x5xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [#accessA, #accessB, #accessC]
+      ins(%A, %B : memref<9x5x3xf32>, memref<9x7x5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> ()>
+// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0) -> ()>
 
-// CHECK-LABEL:   func.func @contract_dot
-// CHECK-SAME:                           (%[[VAL_0:.*]]: memref<9xf32>, %[[VAL_1:.*]]: memref<9xf32>, %[[VAL_2:.*]]: memref<f32>) {
+// CHECK-LABEL: func.func @contract_dot(
+// CHECK-SAME:      %[[A:.*]]: memref<9xf32>, %[[B:.*]]: memref<9xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<f32>) {
 
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memref<f32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0) -> (d0)>,
-                    affine_map<(d0) -> (d0)>,
-                    affine_map<(d0) -> ()>
-                  ]
-                  ins(%arg0, %arg1 : memref<9xf32>, memref<9xf32>)
-                  outs(%arg2: memref<f32>)
+#accessAB = affine_map<(d0) -> (d0)>
+#accessC = affine_map<(d0) -> ()>
+func.func @contract_dot(
+    %A: memref<9xf32>, %B: memref<9xf32>, %C: memref<f32>) {
+  linalg.contract
+      indexing_maps = [#accessAB, #accessAB, #accessC]
+      ins(%A, %B : memref<9xf32>, memref<9xf32>)
+      outs(%C: memref<f32>)
   return
 }
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK-LABEL:   func.func @contract_matmul_bcast_a_b
-// CHECK-SAME:                           (%[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b(
+// CHECK-SAME:      %[[A:.*]]: memref<5xf32>, %[[B:.*]]: memref<5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
 
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-NEXT:     ^{{.+}}(
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
 // CHECK-NEXT:      arith.mulf
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d2)>,
-                    affine_map<(d0, d1, d2) -> (d2)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>)
-                  outs(%arg2: memref<3x7xf32>)
+#accessAB = affine_map<(d0, d1, d2) -> (d2)>
+#accessC = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @contract_matmul_bcast_a_b(
+    %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [#accessAB, #accessAB, #accessC]
+      ins(%A, %B : memref<5xf32>, memref<5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 9acb7562f96ee0..a192a2f72b286a 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -120,53 +120,55 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
 
 // -----
 
-func.func @generalize_matmul_as_contraction_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.contract indexing_maps = [
-                         affine_map<(d0, d1, d2) -> (d0, d2)>,
-                         affine_map<(d0, d1, d2) -> (d2, d1)>,
-                         affine_map<(d0, d1, d2) -> (d0, d1)>
-                       ]
-                       ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
-                       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
+    %A: tensor<16x8xf16>,
+    %B: tensor<8x32xf64>,
+    %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+      outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
 
 // CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// CHECK:         ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
 // Verify floating point extension and truncation.
-// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
-// CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
-// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
-
-// -----
-
-func.func @generalize_matmul_as_contract_with_ext_and_trunc(%arg0: tensor<24x12xf16>,
-                            %arg1: tensor<12x25xf16>,
-                            %arg2: tensor<24x25xf32>) -> tensor<24x25xf16> {
-  %0 = linalg.contract indexing_maps = [
-                         affine_map<(m, n, k) -> (m, k)>,
-                         affine_map<(m, n, k) -> (k, n)>,
-                         affine_map<(m, n, k) -> (m, n)>
-                       ]
-                       ins(%arg0, %arg1 : tensor<24x12xf16>, tensor<12x25xf16>)
-                       outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+// CHECK-NEXT:      %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:      %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    -> tensor<16x32xf32>
+
+// -----
+
+func.func @generalize_matmul_as_contract_with_ext_and_trunc(
+    %A: tensor<24x12xf16>,
+    %B: tensor<12x25xf16>,
+    %C: tensor<24x25xf32>) -> tensor<24x25xf16> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xf16>, tensor<12x25xf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
   %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
   func.return %1 : tensor<24x25xf16>
 }
 
 // CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
-// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// CHECK:         ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
 // Verify floating point extension and truncation.
-// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
-// CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
-// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<24x25xf32>
-// CHECK-NEXT:   %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
+// CHECK-NEXT:      %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:      %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    -> tensor<24x25xf32>
+// CHECK-NEXT:    %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 0f81f4f723b938..633fa2cf236823 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -547,98 +547,98 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
 
 // -----
 
-func.func @invalid_indexing_maps_placement_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{custom op 'linalg.contract' expected 'indexing_map' attribute}}
-  linalg.contract ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
-                  outs(%init : tensor<4x64xf32>)
-                  indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d0, d2)>,
-                    affine_map<(d0, d1, d2) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
+func.func @invalid_indexing_maps_placement_contraction(
+    %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+2 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+  linalg.contract
+      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>)
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
   return
 }
 
 // -----
 
-func.func @invalid_affine_map_in_indexing_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @invalid_affine_map_in_indexing_maps_contraction(
+    %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
   // expected-error @+1 {{provided affine_map is not a projected permutation}}
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
-                    affine_map<(d0, d1, d2) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
-                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
   return
 }
 
 // -----
 
-func.func @differing_iteration_space_of_affine_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @differing_iteration_space_of_affine_maps_contraction(
+    %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
   // expected-error @+1 {{iteration spaces of provided affine_maps differ}}
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d0, d2)>,
-                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
-                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
   return
 }
 
 // -----
 
-func.func @mismatched_ranks_affine_map_and_operand_contraction(%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @mismatched_ranks_affine_map_and_operand_contraction(
+    %lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
   // expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d0, d2)>,
-                    affine_map<(d0, d1, d2) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
-                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
   return
 }
 // -----
 
-func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
+func.func @mismatch_type_affine_map_and_operand_contraction(
+    %lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
   // expected-error @+1 {{affine_map specifies shaped access while operand has non-shaped type}}
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1) -> (d0)>,
-                    affine_map<(d0, d1) -> (d0, d1)>,
-                    affine_map<(d0, d1) -> (d0, d1)>
-                  ]
-                  ins(%lhs, %rhs : f32, tensor<4x64xf32>)
-                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>]
+      ins(%lhs, %rhs : f32, tensor<4x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
   return
 }
 
 // -----
 
-func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @unused_iteration_space_dim_contraction(
+    %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
   // expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
-                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-                  ]
-                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
-                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
   return
 }
 
 // -----
 
-func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @unused_iteration_space_dim_contraction(
+    %lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
   // expected-error @+1 {{iteration space dim at index 3 is neither a contracting dim nor of parallel iteration type}}
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
-                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-                  ]
-                  ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
-                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 0a83750b81dea4..efe8010cffc916 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -627,14 +627,14 @@ func.func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f
 //----------------------------------------------------------------------------//
 // Named ops to loops.
 //----------------------------------------------------------------------------//
-func.func @batch_reduce_matmul_as_contract(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                  ]
-                  ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
-                  outs(%C : memref<?x?xf32>)
+func.func @batch_reduce_matmul_as_contract(
+    %A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+      ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C : memref<?x?xf32>)
   return
 }
 // CHECK-LABEL: @batch_reduce_matmul_as_contract
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 6defa827b77c48..06e0bb801b5444 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1515,115 +1515,120 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 //  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
 //  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                  ]
-                  ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
-                  outs(%arg2: memref<2x3x7xf32>)
+func.func @contract(
+    %A: memref<2x3x5xf32>, %B: memref<2x5x7xf32>, %C: memref<2x3x7xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+      ins(%A, %B : memref<2x3x5xf32>, memref<2x5x7xf32>)
+      outs(%C: memref<2x3x7xf32>)
   return
 }
 
 // -----
 
-func.func @contract_matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d2)>,
-                    affine_map<(d0, d1, d2) -> (d2, d1)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_a
+func.func @contract_matmul_bcast_a(%A: memref<5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<5xf32>, memref<5x7xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func @contract_matmul_bcast_a
-//       CHECK:   linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
-
 // -----
 
-func.func @contract_matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d0, d2)>,
-                    affine_map<(d0, d1, d2) -> (d2)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_b
+func.func @contract_matmul_bcast_b(%A: memref<3x5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<3x5xf32>, memref<5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func @contract_matmul_bcast_b
-//       CHECK:   linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
-//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
-
 // -----
 
-func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                    affine_map<(d0, d1, d2) -> (d2)>,
-                    affine_map<(d0, d1, d2) -> (d2)>,
-                    affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+func.func @contract_matmul_bcast_a_b(
+    %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_A]], #[[$ACCESS_B]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<5xf32>, memref<5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL:   func.func @contract_matmul_bcast_a_b
-//       CHECK:     linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
-//       CHECK:       ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
-//       CHECK:       outs(%{{.+}} : memref<3x7xf32>)
-
 // -----
 
-func.func @contract_matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                  affine_map<(d0, d1, d2) -> (d2)>,
-                  affine_map<(d0, d1, d2) -> (d1, d2)>,
-                  affine_map<(d0, d1, d2) -> (d0, d1)>
-                ]
-                ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
+func.func @contract_matmul_bcast_a_transpose_b(
+    %A: memref<5xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<5xf32>, memref<7x5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL:   func.func @contract_matmul_bcast_a_transpose_b
-//       CHECK:     linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-//       CHECK:       ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
-//       CHECK:       outs(%{{.+}} : memref<3x7xf32>)
 
 // -----
 
-func.func @contract_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
-  linalg.contract indexing_maps = [
-                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @contract_matmul_bcast_b_transpose_a
+func.func @contract_matmul_bcast_b_transpose_a(%A: memref<5x3xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK:      linalg.contract
+// CHECK-SAME:     indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
+// CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                        affine_map<(d0, d1, d2) -> (d2)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1)>
-                     ]
-                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<5x3xf32>, memref<5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL:   func.func @contract_matmul_bcast_b_transpose_a
-//       CHECK:     linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-//       CHECK:       ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
-//       CHECK:       outs(%{{.+}} : memref<3x7xf32>)
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 99cbb6647effbe..dc556761b09e56 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -277,26 +277,27 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
 
 // -----
 
+#accessA = affine_map<(batch, m, n, k) -> (batch, m, k)>
+#accessB = affine_map<(batch, m, n, k) -> (batch, k, n)>
+#accessC = affine_map<(batch, m, n, k) -> (batch, m, n)>
 func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
   -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 {
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%c3: memref<?x?x?xf32>)
-  linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
-                                   affine_map<(batch, m, n, k) -> (batch, k, n)>,
-                                   affine_map<(batch, m, n, k) -> (batch, m, n)>]
-                  ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
-                  outs(%c3: memref<?x?x?xf32>)
+  linalg.contract
+      indexing_maps = [#accessA, #accessB, #accessC]
+      ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%c3: memref<?x?x?xf32>)
   %res1 = linalg.batch_matmul
                       ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                      outs(%tc3: tensor<?x?x?xf32>)
                   -> tensor<?x?x?xf32>
-  %res2 = linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
-                                           affine_map<(batch, m, n, k) -> (batch, k, n)>,
-                                           affine_map<(batch, m, n, k) -> (batch, m, n)>]
-                          ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-                          outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %res2 = linalg.contract
+      indexing_maps = [#accessA, #accessB, #accessC]
+      ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @named_ops
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 1de1863d6deb13..557233d8aa3ec4 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -40,13 +40,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-NEXT:  #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK-NEXT:  #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+#access_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                affine_map<(d0, d1, d2) -> (d2, d1)>,
+                affine_map<(d0, d1, d2) -> (d0, d1)>]
 
 // CHECK-LABEL: func @matmul_as_contract_tensors(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func.func @matmul_as_contract_tensors(
-  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+  %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
     -> tensor<?x?xf32> {
 //      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
 //      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
@@ -54,19 +57,17 @@ func.func @matmul_as_contract_tensors(
 //      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
 //      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
 //      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-//      CHECK:       %[[sTD:.*]] = linalg.contract indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME:                                  outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.contract
+// CHECK-SAME:          indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:          ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:          outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
 //      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
 //      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
 //      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
 //      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
-  %0 = linalg.contract indexing_maps = [
-                         affine_map<(d0, d1, d2) -> (d0, d2)>,
-                         affine_map<(d0, d1, d2) -> (d2, d1)>,
-                         affine_map<(d0, d1, d2) -> (d0, d1)>
-                       ]
-                       ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
-                       outs(%arg2: tensor<?x?xf32>)
+  %0 = linalg.contract indexing_maps = #access_maps
+                       ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                       outs(%C: tensor<?x?xf32>)
     -> tensor<?x?xf32>
 
 //      CHECK: return %[[TD0]] : tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 2d30d62039642a..0d59dbba8940d4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -54,39 +54,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: @vectorize_matmul_as_contract
-// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
-// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
-// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
-func.func @vectorize_matmul_as_contract(%arg0: tensor<24x12xf32>,
-                            %arg1: tensor<12x25xf32>,
-                            %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
-  // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
-  // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
-  // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vR]], %[[C]]
-  %0 = linalg.contract indexing_maps = [
-                         affine_map<(m, n, k) -> (m, k)>,
-                         affine_map<(m, n, k) -> (k, n)>,
-                         affine_map<(m, n, k) -> (m, n)>
-                       ]
-                       ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>)
-                       outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
-  func.return %0 : tensor<24x25xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
 // CHECK-LABEL: @vectorize_copy_memref
 // CHECK-SAME: %[[A:.*]]: memref<100x100xf32>,
 // CHECK-SAME: %[[B:.*]]: memref<100x100xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index b688a677500c22..3e10b8402af4e0 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -82,6 +82,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @matmul_as_contract
+// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
+func.func @matmul_as_contract(%A: tensor<24x12xf32>,
+                              %B: tensor<12x25xf32>,
+                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
+  // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
+  // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xf32>, tensor<12x25xf32>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 #matmul_trait = {
   indexing_maps = [
     affine_map<(m, n, k) -> (m, k)>,

>From 49557e88dd54e598c7c8936313b8a755a099efbf Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 27 Jan 2025 14:41:19 -0800
Subject: [PATCH 08/11] Add todo as pointed out by @banach-space

---
 mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 3e10b8402af4e0..5ae3f893c2e739 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -108,6 +108,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    // TODO: also test the other available vectorization strategies
     transform.yield
   }
 }

>From cd045685d20aa75e4cea00cd66a83e9f3f6512d9 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 28 Jan 2025 05:47:46 -0800
Subject: [PATCH 09/11] Further fixes per @banach-space's comments

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td   | 17 ++++++++++-------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp       | 13 +++++++++++--
 .../Dialect/Linalg/generalize-named-ops.mlir   | 18 +++++++++---------
 .../generalize-named-polymorphic-ops.mlir      |  2 +-
 mlir/test/Dialect/Linalg/invalid.mlir          |  3 ++-
 mlir/test/Dialect/Linalg/named-ops.mlir        |  3 ---
 6 files changed, 33 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5bc04ac4ceb3eb..0f1dde270fc7f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -698,9 +698,9 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
 
     where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
     dimension identifiers (meant to range over valid indices), corresponding to
-    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
-    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
-    the dimensions in the set `dims`.
+    the co-domains of the mandatory (projected permutation) `indexing_maps` of
+    `A`, `B` and `C`, respectively. `SUM_{dims}` means reduce over all valid
+    indices for the dimensions in the set `dims`.
 
     The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
     domain of each of the `affine_map`s. Like for einsums, the iteration type of
@@ -719,21 +719,24 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     `n` and `b` are of parallel iteration-type) and gets represented as:
 
     ```
-    %0 = linalg.contract
+    %D = linalg.contract
         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                          affine_map<(batch, m, n, k) -> (batch, k, n)>,
                          affine_map<(batch, m, n, k) -> (batch, m, n)>]
-        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+        ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
     ```
 
     Note that by permuting the dims in the co-domains of the `affine_map`s, we
     can apply arbitrary transposes to the inputs and output. Similarly,
     arbitrary broadcasts can be achieved through leaving out dims on either
-    input operand.
+    input operand - these dims' inferred iter type will be parallel.
 
     Numeric casting is performed on the operands to the inner multiplication,
     promoting them to the same data type as the accumulator/output.
+
+    TODO: Allow control over the combining/accumulating op and possibly the
+          multiplication op.
   }];
 
   let arguments = (ins
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c4edea2d8c56a4..7c710d955535f6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3693,10 +3693,18 @@ void ContractOp::print(OpAsmPrinter &p) {
 
 LogicalResult ContractOp::verify() {
   int iterationSpaceDims = -1;
-  // Maps iter space dim (as index) to num of occurrences in inputs and output.
+  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
+  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
+  // access an input operand (so occurrence count can be at most 2) and
+  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
   SmallVector<size_t> inOccurrences;
   SmallVector<size_t> outOccurrences;
 
+  // For each operand's affine_map and type, check that the rank of the
+  // affine_map's domain is the same as those seen prior, check that the
+  // affine_map's co-domain rank is the same as that of the corresponding type,
+  // check that the affine_map is a projected permutation, and, finally, update
+  // inputs and output occurrence counts for dims in the co-domains.
   auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                    bool isInput) -> LogicalResult {
     if (iterationSpaceDims == -1) {
@@ -3737,7 +3745,7 @@ LogicalResult ContractOp::verify() {
        llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                  SmallVector<bool>{true, true, false})) {
     if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
-      return failure(); // NOTE: checking lambda will emit error.
+      return failure(); // NB: checkAffineMapAndType will emit relevant error.
   }
 
   bool hasContractingDim = false;
@@ -3745,6 +3753,7 @@ LogicalResult ContractOp::verify() {
     size_t inOccCount = inOccurrences[dimIndex];
     size_t outOccCount = outOccurrences[dimIndex];
 
+    // We have a contracting dim if and only if ...
     hasContractingDim |= inOccCount == 2 && outOccCount == 0;
 
     if (inOccCount == 0 && outOccCount == 0)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7ee25ac25ac6c9..a3611b8e4ec62e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1016,13 +1016,13 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @contract_matmul(%A: memref<3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
   linalg.contract
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d2, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1)>]
-      ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
-      outs(%arg2: memref<3x7xf32>)
+      ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
+      outs(%C: memref<3x7xf32>)
 
   return
 }
@@ -1046,13 +1046,13 @@ func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+func.func @contract_matmul_transpose_a_b(%A: memref<5x3xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
   linalg.contract
       indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                        affine_map<(d0, d1, d2) -> (d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1)>]
-      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
-      outs(%arg2: memref<3x7xf32>)
+      ins(%A, %B : memref<5x3xf32>, memref<7x5xf32>)
+      outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1075,13 +1075,13 @@ func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      linalg.yield
 
-func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
+func.func @contract_batch_matmul(%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<9x3x7xf32>) {
   linalg.contract
       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
-      ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
-      outs(%arg2: memref<9x3x7xf32>)
+      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
+      outs(%C: memref<9x3x7xf32>)
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index a192a2f72b286a..bbd6e0fc8e2ccb 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -121,7 +121,7 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
 // -----
 
 func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
-    %A : tensor<16x8xf16>,
+    %A: tensor<16x8xf16>,
     %B: tensor<8x32xf64>,
     %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.contract
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 633fa2cf236823..b69306fe79b252 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -549,7 +549,8 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
 
 func.func @invalid_indexing_maps_placement_contraction(
     %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+2 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+  // expected-error @+3 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+  // NB: indexing_maps should be provided before ins and outs
   linalg.contract
       ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
       outs(%init : tensor<4x64xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 06e0bb801b5444..578d24a550b08c 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1568,7 +1568,6 @@ func.func @contract_matmul_bcast_b(%A: memref<3x5xf32>, %B: memref<5xf32>, %C: m
 
 // -----
 
-
 // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
 // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func.func @contract_matmul_bcast_a_b
@@ -1608,7 +1607,6 @@ func.func @contract_matmul_bcast_a_transpose_b(
   return
 }
 
-
 // -----
 
 // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
@@ -1629,7 +1627,6 @@ func.func @contract_matmul_bcast_b_transpose_a(%A: memref<5x3xf32>, %B: memref<5
   return
 }
 
-
 // -----
 
 // CHECK-LABEL: func @mmt4d

>From ce14c956859a517d77953b2af3df22102b5a9b26 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 28 Jan 2025 14:35:41 -0800
Subject: [PATCH 10/11] Expand docs on transpose and broadcast

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 0f1dde270fc7f7..6cb6138721e176 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -727,10 +727,21 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
         outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
     ```
 
-    Note that by permuting the dims in the co-domains of the `affine_map`s, we
-    can apply arbitrary transposes to the inputs and output. Similarly,
-    arbitrary broadcasts can be achieved through leaving out dims on either
-    input operand - these dims' inferred iter type will be parallel.
+    Note that by permuting dims in the co-domains of the `affine_map`s arbitrary
+    transposes can be applied to the inputs and output. Similarly, arbitrary
+    broadcasts can be achieved through leaving out dims on either input operand
+    (these dims' inferred iter type will be parallel). For example, the
+    following is a variant of batch-matmul where a transposition is applied to
+    `A` while matrix `B` gets broadcasted along the batch dimension:
+
+    ```
+    linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
+                         affine_map<(batch, m, n, k) -> (k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
+        outs(%C: memref<?x?x?xf32>)
+    ```
 
     Numeric casting is performed on the operands to the inner multiplication,
     promoting them to the same data type as the accumulator/output.

>From 072da4b7cacfa0130a18a613698efee5fdef34ac Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 29 Jan 2025 07:21:34 -0800
Subject: [PATCH 11/11] Further doc updates per discussion with @banach-space

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 41 ++++++++++---------
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 32 +++++++--------
 mlir/test/Dialect/Linalg/invalid.mlir         |  2 +-
 3 files changed, 39 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6cb6138721e176..e3d122189f8b77 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -696,27 +696,28 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
 
       `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
 
-    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
-    dimension identifiers (meant to range over valid indices), corresponding to
-    the co-domains of the mandatory (projected permutation) `indexing_maps` of
-    `A`, `B` and `C`, respectively. `SUM_{dims}` means reduce over all valid
-    indices for the dimensions in the set `dims`.
+    where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
+    identifiers - meant to range over valid indices - corresponding to the
+    results of the mandatory (projected permutation) `indexing_maps` for `A`,
+    `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
+    dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
+    dim identifiers).
 
     The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
     domain of each of the `affine_map`s. Like for einsums, the iteration type of
     each dim is inferred and is either:
 
-    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
-      Per the above semantics, these dims will be contracted, i.e. reduced over.
+    - reduction: the dim is used to index into `A` and `B` but not `C`. Per the
+      above semantics, these dims will be contracted, i.e. reduced over.
 
-    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
-      deriving from matmul terminology - is either an "M-like" dim (if in `A`
-      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
-      `B`, and `C`).
+    - parallel: the dim is used to index into `C` and at least one of `A` and
+      `B`, and - deriving from matmul terminology - is either an "M-like" dim
+      (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
+      "batch"-dim (if used to index into `A`, `B`, and `C`).
 
     For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
     `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
-    `n` and `b` are of parallel iteration-type) and gets represented as:
+    `n` and `b` have parallel iteration-type) and gets represented as:
 
     ```
     %D = linalg.contract
@@ -727,12 +728,11 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
         outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
     ```
 
-    Note that by permuting dims in the co-domains of the `affine_map`s arbitrary
-    transposes can be applied to the inputs and output. Similarly, arbitrary
-    broadcasts can be achieved through leaving out dims on either input operand
-    (these dims' inferred iter type will be parallel). For example, the
-    following is a variant of batch-matmul where a transposition is applied to
-    `A` while matrix `B` gets broadcasted along the batch dimension:
+    Note that by permuting dims in the `affine_map`s' results, accesses
+    to the inputs and output can be arbitrarily transposed. Similarly, arbitrary
+    broadcasts can be achieved through leaving out dims on either input operand.
+    For example, the following is a variant of batch-matmul with a transposition
+    applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:
 
     ```
     linalg.contract
@@ -744,7 +744,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     ```
 
     Numeric casting is performed on the operands to the inner multiplication,
-    promoting them to the same data type as the accumulator/output.
+    promoting/truncating them to the same data type as the accumulator/output.
 
     TODO: Allow control over the combining/accumulating op and possibly the
           multiplication op.
@@ -756,6 +756,9 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     AffineMapArrayAttr:$indexing_maps
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
+  // NB: The only reason this op has a region - and it gets populated at op
+  //     build time - is that currently the LinalgOp interface exposes methods
+  //     that assume a relevant region is available to be queried at any time.
   let regions = (region SizedRegion<1>:$combiner);
 
   let skipDefaultBuilders = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7c710d955535f6..b33ba1cfb87dc4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3700,33 +3700,33 @@ LogicalResult ContractOp::verify() {
   SmallVector<size_t> inOccurrences;
   SmallVector<size_t> outOccurrences;
 
-  // For each operand's affine_map and type, check that the rank of the
-  // affine_map's domain is the same as those seen prior, check that the
-  // affine_map's co-domain rank is the same as that of the corresponding type,
-  // check that the affine_map is a projected permutation, and, finally, update
-  // inputs and output occurrence counts for dims in the co-domains.
+  // A helper so that for each operand's affine_map and type we check that ...
   auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                    bool isInput) -> LogicalResult {
-    if (iterationSpaceDims == -1) {
-      iterationSpaceDims = affineMap.getNumDims();
-      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
-      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
-    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
-      return emitError("iteration spaces of provided affine_maps differ");
-    }
+    // ... the affine_map is a projected permutation;
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
 
+    // ... the rank of the affine_map's results and corresponding type match;
     if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
       if (affineMap.getNumResults() != shapedType.getRank())
-        return emitError("ranks of shaped operand and co-domain of "
-                         "corresponding affine_map differ");
+        return emitError("ranks of shaped operand and results of corresponding "
+                         "affine_map differ");
     } else if (affineMap.getNumResults() != 0) {
       return emitError("affine_map specifies shaped access while operand has "
                        "non-shaped type");
     }
 
-    if (!affineMap.isProjectedPermutation())
-      return emitError("provided affine_map is not a projected permutation");
+    // ... the rank of the affine_map's domain is the same as those seen prior;
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
 
+    // ... update counts of dims used to access either an input or the output.
     for (AffineExpr affineExpr : affineMap.getResults()) {
       auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
       if (!affineDimExpr)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index b69306fe79b252..0ea805fef5361f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -592,7 +592,7 @@ func.func @differing_iteration_space_of_affine_maps_contraction(
 
 func.func @mismatched_ranks_affine_map_and_operand_contraction(
     %lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
-  // expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
+  // expected-error @+1 {{ranks of shaped operand and results of corresponding affine_map differ}}
   linalg.contract
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d2, d1)>,



More information about the Mlir-commits mailing list