[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Rolf Morel llvmlistbot at llvm.org
Mon Jan 20 15:48:12 PST 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/123618

>From e3716872e9b379553fa52c67cfec49392cacefad Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 19 Jan 2025 13:42:59 -0800
Subject: [PATCH 1/2] [MLIR][Linalg] Introduce linalg.contract

A new op that allows for representing arbitrary contractions on operands
of arbitrary rank, with arbitrary transposes and arbitrary broadcasts
specified through its indexing_maps attribute.

Supports the expected lowerings to linalg.generic and to vector.contract.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 118 +++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 224 ++++++++++++++++--
 .../Dialect/Linalg/generalize-named-ops.mlir  | 168 ++++++++++++-
 .../generalize-named-polymorphic-ops.mlir     |  50 ++++
 mlir/test/Dialect/Linalg/invalid.mlir         |  97 ++++++++
 mlir/test/Dialect/Linalg/loops.mlir           |  47 ++++
 mlir/test/Dialect/Linalg/named-ops.mlir       |  23 +-
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  17 +-
 mlir/test/Dialect/Linalg/tile-tensors.mlir    |  46 ++++
 .../Linalg/transform-op-vectorize.mlir        |  33 +++
 10 files changed, 789 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating on top of a third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+    dimension identifiers (meant to range over valid indices), corresponding to
+    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+    the dimensions in the set `dims`.
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+      Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+      deriving from matmul terminology - is either an "M-like" dim (if in `A`
+      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+      `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` are of parallel iteration-type) and gets represented as:
+
+    ```
+    %0 = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting the dims in the co-domains of the `affine_map`s, we
+    can apply arbitrary transposes to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved through leaving out dims on either
+    input operand.
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    converting (extending or truncating) them to the accumulator/output type.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, regionBuilder);
+      }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// The indexing maps always depend on the attribute set on this op instance.
+    bool hasDynamicIndexingMaps() { return true; };
+    bool hasUserDefinedMaps() { return true; };
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..355ed2d269291c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3528,44 +3528,45 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
 }
 
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<Attribute, 3> indexingMapsAttr;
-  Attribute mapAttr;
-  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
-    if (parser.parseEqual())
-      return failure();
+FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+  if (parser.parseOptionalKeyword("indexing_maps"))
+    return {nullptr}; // Success in case indexing_maps was not provided.
 
-    if (parser.parseLSquare())
+  SmallVector<Attribute> indexingMaps;
+
+  auto parseIndexingMap = [&]() -> ParseResult {
+    AffineMapAttr affineMapAttr;
+    if (parser.parseAttribute(affineMapAttr))
       return failure();
+    indexingMaps.push_back(affineMapAttr);
+    return success();
+  };
 
-    do {
-      if (parser.parseAttribute(mapAttr))
-        return failure();
-      if (!isa<AffineMapAttr>(mapAttr)) {
-        return parser.emitError(parser.getCurrentLocation(),
-                                "expected affine map attribute");
-      }
-      indexingMapsAttr.push_back(mapAttr);
+  if (parser.parseEqual() ||
+      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+                                     parseIndexingMap))
+    return failure();
 
-      if (parser.parseOptionalComma())
-        break;
-    } while (true);
+  return parser.getBuilder().getArrayAttr(indexingMaps);
+}
 
-    if (parser.parseRSquare())
-      return failure();
-  }
-  // Initialize indexingMaps, if not supplied explicitly.
-  if (indexingMapsAttr.empty()) {
-    indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr))
+    return failure();
+
+  if (*indexingMapsAttr == nullptr) {
+    auto indexingMapAttrs = llvm::map_to_vector(
         MatmulOp::getDefaultIndexingMaps(parser.getContext()),
         [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
   }
-  result.addAttribute("indexing_maps",
-                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
 
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
+
 void MatmulOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 3> elidedAttrs = {
       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3599,6 +3600,7 @@ LogicalResult MatmulOp::verify() {
 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
 }
+
 void MatmulOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  // On well-formed IR, indexing_maps is non-empty, the domains of the contained
+  // affine_maps are all the same, and each map is a projected permutation.
+  // Each dim in the domain must occur for at least one operand and is
+  // classified as either batch, N-like, M-like, or K-like. Only the latter
+  // corresponds to a reduction _and_ it is the only dim-kind which does not
+  // occur for the output operand. We use this fact for fast inference.
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+  for (auto result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  for (auto dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+
+  return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castSignedness = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = attr.getValue();
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Type outType = block.getArgument(2).getType();
+  Value lhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+  Value rhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      productAtOutType);
+  helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_map' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Maps iter space dim (as index) to num of occurrences in inputs and output.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+
+  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+                                   bool isInput) -> LogicalResult {
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
+
+    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+      if (affineMap.getNumResults() != shapedType.getRank())
+        return emitError("ranks of shaped operand and co-domain of "
+                         "corresponding affine_map differ");
+    } else if (affineMap.getNumResults() != 0) {
+      return emitError("affine_map specifies shaped access while operand has "
+                       "non-shaped type");
+    }
+
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
+
+    for (AffineExpr affineExpr : affineMap.getResults()) {
+      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+      if (!affineDimExpr)
+        llvm_unreachable("affine_map is a projected permutation");
+
+      if (isInput)
+        inOccurrences[affineDimExpr.getPosition()] += 1;
+      else
+        outOccurrences[affineDimExpr.getPosition()] += 1;
+    }
+
+    return success();
+  };
+
+  for (auto &&[affineMap, operandType, isInput] :
+       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+                 SmallVector<bool>{true, true, false}))
+    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+      return failure(); // NOTE: checking lambda will emit error.
+
+  bool hasContractingDim = false;
+  for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+    if (inOccCount == 0)
+      return emitError("iteration space dim not used by either input");
+
+    // NB: A dim which occurs for only one input operand and not for the output.
+    //     In terms of einsum semantics, such dims have a sensible meaning -
+    //     namely an additional reduction per such dim - though this can also
+    //     always be expressed through an additional op. Additionally, at time
+    //     of writing, vector.contract's verifier accepts these dims but many of
+    //     its lowerings do not handle these kinds of dims. Hence...
+    // TODO: Remove following once we have comprehensive support for input-only
+    //       reduction dims, at both the linalg- and vector-dialect levels.
+    if (inOccCount == 1 && outOccCount != 1)
+      return emitError("iter type of dim is not one of M, N, K or batch");
+  }
+
+  if (!hasContractingDim)
+    return emitError("'indexing_maps' do not specify a contracting dimension");
+
+  return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ContractOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..f7e570d5ce38f0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -996,9 +994,173 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>,
+                    affine_map<(d0, d1, d2) -> (d1, d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @contract_batch_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<9x3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+                  outs(%arg2: memref<9x3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @contract_batch_reduce_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_batch_reduce_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<9x5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<9x7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(%arg0: memref<9x5x3xf32>, %arg1: memref<9x7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<9x5x3xf32>, memref<9x7x5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL:   func.func @contract_dot
+// CHECK-SAME:                           (%[[VAL_0:.*]]: memref<9xf32>, %[[VAL_1:.*]]: memref<9xf32>, %[[VAL_2:.*]]: memref<f32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memref<f32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0) -> (d0)>,
+                    affine_map<(d0) -> (d0)>,
+                    affine_map<(d0) -> ()>
+                  ]
+                  ins(%arg0, %arg1 : memref<9xf32>, memref<9xf32>)
+                  outs(%arg2: memref<f32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index c170c5be4abff9..9acb7562f96ee0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -120,6 +120,56 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
 
 // -----
 
+func.func @generalize_matmul_as_contraction_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(d0, d1, d2) -> (d0, d2)>,
+                         affine_map<(d0, d1, d2) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>
+                       ]
+                       ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+                       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// Verify floating point extension and truncation.
+// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func.func @generalize_matmul_as_contract_with_ext_and_trunc(%arg0: tensor<24x12xf16>,
+                            %arg1: tensor<12x25xf16>,
+                            %arg2: tensor<24x25xf32>) -> tensor<24x25xf16> {
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(m, n, k) -> (m, k)>,
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>
+                       ]
+                       ins(%arg0, %arg1 : tensor<24x12xf16>, tensor<12x25xf16>)
+                       outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
+  func.return %1 : tensor<24x25xf16>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// Verify floating point extension and truncation.
+// CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+// CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<24x25xf32>
+// CHECK-NEXT:   %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
+
+// -----
+
 func.func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..59eeb3953e548f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -529,6 +529,103 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
 
 // -----
 
+func.func @invalid_indexing_maps_placement_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{custom op 'linalg.contract' expected 'indexing_map' attribute}}
+  linalg.contract ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>)
+                  indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+  return
+}
+
+// -----
+
+func.func @invalid_affine_map_in_indexing_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{provided affine_map is not a projected permutation}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @differing_iteration_space_of_affine_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration spaces of provided affine_maps differ}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @mismatched_ranks_affine_map_and_operand_contraction(%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+// -----
+
+func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{affine_map specifies shaped access while operand has non-shaped type}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1) -> (d0)>,
+                    affine_map<(d0, d1) -> (d0, d1)>,
+                    affine_map<(d0, d1) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : f32, tensor<4x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration space dim not used by either input}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @invalid_dim_iter_type_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iter type of dim is not one of M, N, K or batch}}
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+                  ]
+                  ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
+                  outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
 func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
   // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6286a11c11a21f..0a83750b81dea4 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -627,6 +627,53 @@ func.func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f
 //----------------------------------------------------------------------------//
 // Named ops to loops.
 //----------------------------------------------------------------------------//
+func.func @batch_reduce_matmul_as_contract(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                  ]
+                  ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                  outs(%C : memref<?x?xf32>)
+  return
+}
+// CHECK-LABEL: @batch_reduce_matmul_as_contract
+//  CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECK: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+//       CHECK: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+//       CHECK: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+//       CHECK: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+//       CHECK: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+//       CHECK:   scf.for %[[m:.*]] = %{{.*}} to %[[M]]
+//       CHECK:     scf.for %[[n:.*]] = %{{.*}} to %[[N]]
+//       CHECK:       scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+//       CHECK:       %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+//       CHECK:       %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+//       CHECK:       %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//       CHECK:       store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @batch_reduce_matmul_as_contract
+//  CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECKPARALLEL: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+//       CHECKPARALLEL:   scf.parallel (%[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[M]], %[[N]]) step ({{.*}}) {
+//       CHECKPARALLEL:     scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+//       CHECKPARALLEL:         %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+//       CHECKPARALLEL:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:         store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
 func.func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
   linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%C : memref<?x?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 68aa5a85b5e0e6..3c68e7d394642e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1245,7 +1245,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -1259,7 +1258,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -1509,6 +1507,27 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func @contract
+//       CHECK:   linalg.contract
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                  ]
+                  ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
+                  outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @mmt4d
 func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.mmt4d
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..99cbb6647effbe 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -277,22 +277,33 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
 
 // -----
 
-
 func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
-  -> (tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 {
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%c3: memref<?x?x?xf32>)
+  linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                                   affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                                   affine_map<(batch, m, n, k) -> (batch, m, n)>]
+                  ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                  outs(%c3: memref<?x?x?xf32>)
   %res1 = linalg.batch_matmul
                       ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                      outs(%tc3: tensor<?x?x?xf32>)
                   -> tensor<?x?x?xf32>
-  return %res1 : tensor<?x?x?xf32>
+  %res2 = linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                                           affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                                           affine_map<(batch, m, n, k) -> (batch, m, n)>]
+                          ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+                          outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @named_ops
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.contract
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.contract
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 8f13c690704572..1de1863d6deb13 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -37,6 +37,52 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK:       #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-NEXT:  #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-NEXT:  #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul_as_contract_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @matmul_as_contract_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.contract indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:                                  outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(d0, d1, d2) -> (d0, d2)>,
+                         affine_map<(d0, d1, d2) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>
+                       ]
+                       ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                       outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+//      CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 3, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors_with_size_zeros(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d4..2d30d62039642a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -54,6 +54,39 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @vectorize_matmul_as_contract
+// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
+func.func @vectorize_matmul_as_contract(%arg0: tensor<24x12xf32>,
+                            %arg1: tensor<12x25xf32>,
+                            %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
+  // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
+  // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
+  %0 = linalg.contract indexing_maps = [
+                         affine_map<(m, n, k) -> (m, k)>,
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>
+                       ]
+                       ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>)
+                       outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @vectorize_copy_memref
 // CHECK-SAME: %[[A:.*]]: memref<100x100xf32>,
 // CHECK-SAME: %[[B:.*]]: memref<100x100xf32>

>From 3d0d5b307d7326929e5a8da0942c51f4f133543e Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 20 Jan 2025 15:46:48 -0800
Subject: [PATCH 2/2] Remove extraneous line that allowed adding a region

---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4277bd34f3946..8c9a4e56ce00a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -742,7 +742,6 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     AffineMapArrayAttr:$indexing_maps
   );
   let results = (outs Variadic<AnyShaped>:$result_tensors);
-  let regions = (region SizedRegion<1>:$combiner);
 
   let skipDefaultBuilders = 1;
   let builders = [



More information about the Mlir-commits mailing list