[Mlir-commits] [mlir] 0d4efa2 - [MLIR][Linalg] Introduce linalg.contract (#123618)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 29 09:28:55 PST 2025


Author: Rolf Morel
Date: 2025-01-29T17:28:52Z
New Revision: 0d4efa27252cbbea4b5672d4d8ffc15a3ba51d83

URL: https://github.com/llvm/llvm-project/commit/0d4efa27252cbbea4b5672d4d8ffc15a3ba51d83
DIFF: https://github.com/llvm/llvm-project/commit/0d4efa27252cbbea4b5672d4d8ffc15a3ba51d83.diff

LOG: [MLIR][Linalg] Introduce linalg.contract (#123618)

A new op that allows for representing arbitrary contractions on operands
of arbitrary rank, with arbitrary transposes and arbitrary broadcasts
specified through its indexing_maps attribute.

Supports the expected lowerings to linalg.generic and to
vector.contract.

Corresponding RFC is here:
https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile-tensors.mlir
    mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..e3d122189f8b77 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,142 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating into the third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are tuples of (pairwise distinct) dimension
+    identifiers - meant to range over valid indices - corresponding to the
+    results of the mandatory (projected permutation) `indexing_maps` for `A`,
+    `B` and `C`. `SUM_{dims}` means reduce over all valid indices for the
+    dimensions in the set `dims` (with `I`, `J`, and `H` treated as _sets_ of
+    dim identifiers).
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim is used to index into `A` and `B` but not `C`. Per the
+      above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim is used to index into `C` and at least one of `A` and
+      `B`, and - deriving from matmul terminology - is either an "M-like" dim
+      (if used on `A` and `C`), an "N-like" dim (if used on `B` and `C`) or a
+      "batch"-dim (if used to index into `A`, `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` have parallel iteration-type) and gets represented as:
+
+    ```
+    %D = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting dims in the `affine_map`s' results, accesses
+    to the inputs and output can be arbitrarily transposed. Similarly, arbitrary
+    broadcasts can be achieved through leaving out dims on either input operand.
+    For example, the following is a variant of batch-matmul with a transposition
+    applied to `A` while `B`'s 2D-matrix gets broadcasted along the batch dim:
+
+    ```
+    linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
+                         affine_map<(batch, m, n, k) -> (k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
+        outs(%C: memref<?x?x?xf32>)
+    ```
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting/truncating them to the same data type as the accumulator/output.
+
+    TODO: Allow control over the combining/accumulating op and possibly the
+          multiplication op.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  // NB: The only reason this op has a region - and it gets populated at op
+  //     build time - is that currently the LinalgOp interface exposes methods
+  //     that assume a relevant region is available to be queried at any time.
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, regionBuilder);
+      }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// IndexingMaps always depends on attr associated to current Op instance.
+    bool hasDynamicIndexingMaps() { return true; };
+    bool hasUserDefinedMaps() { return true; };
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 71bae4666d6619..b33ba1cfb87dc4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3531,44 +3531,39 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
 }
 
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<Attribute, 3> indexingMapsAttr;
-  Attribute mapAttr;
-  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
-    if (parser.parseEqual())
-      return failure();
+static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+  if (parser.parseOptionalKeyword("indexing_maps"))
+    return {nullptr}; // Success in case indexing_maps was not provided.
 
-    if (parser.parseLSquare())
-      return failure();
+  ArrayAttr arrayAttr;
+  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
+    return failure();
 
-    do {
-      if (parser.parseAttribute(mapAttr))
-        return failure();
-      if (!isa<AffineMapAttr>(mapAttr)) {
-        return parser.emitError(parser.getCurrentLocation(),
-                                "expected affine map attribute");
-      }
-      indexingMapsAttr.push_back(mapAttr);
+  if (llvm::any_of(arrayAttr,
+                   [](Attribute elt) { return !isa<AffineMapAttr>(elt); }))
+    return parser.emitError(parser.getCurrentLocation())
+           << "element of indexing_maps array is not an affine_map";
 
-      if (parser.parseOptionalComma())
-        break;
-    } while (true);
+  return arrayAttr;
+}
 
-    if (parser.parseRSquare())
-      return failure();
-  }
-  // Initialize indexingMaps, if not supplied explicitly.
-  if (indexingMapsAttr.empty()) {
-    indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr))
+    return failure();
+
+  if (*indexingMapsAttr == nullptr) {
+    auto indexingMapAttrs = llvm::map_to_vector(
         MatmulOp::getDefaultIndexingMaps(parser.getContext()),
         [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
   }
-  result.addAttribute("indexing_maps",
-                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
 
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
+
 void MatmulOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 3> elidedAttrs = {
       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3602,6 +3597,7 @@ LogicalResult MatmulOp::verify() {
 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
 }
+
 void MatmulOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -3614,5 +3610,192 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  // domains are all the same, and each implements a projected permutation.
+  // Each iteration space dim must occur for at least one operand and either
+  // takes part in a contraction/reduction or else has parallel iteration type.
+  // We have that a dim is a contraction/reduction dim if and only if the dim
+  // does not occur for the output operand. We use this fact for fast inference:
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
+  for (auto result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  for (auto dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castSignedness = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = attr.getValue();
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Type outType = block.getArgument(2).getType();
+  Value lhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+  Value rhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      productAtOutType);
+  helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr))
+    return failure(); // parseIndexingMapsAttr has already emitted an error.
+  if (*indexingMapsAttr == nullptr) // Unlike matmul, maps are mandatory here.
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
+  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
+  // access an input operand (so occurrence count can be at most 2) and
+  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+
+  // A helper so that for each operand's affine_map and type we check that ...
+  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+                                   bool isInput) -> LogicalResult {
+    // ... the affine_map is a projected permutation;
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
+
+    // ... the rank of the affine_map's results and corresponding type match;
+    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+      if (affineMap.getNumResults() != shapedType.getRank())
+        return emitError("ranks of shaped operand and results of corresponding "
+                         "affine_map differ");
+    } else if (affineMap.getNumResults() != 0) {
+      return emitError("affine_map specifies shaped access while operand has "
+                       "non-shaped type");
+    }
+
+    // ... the rank of the affine_map's domain is the same as those seen prior;
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
+
+    // ... update counts of dims used to access either an input or the output.
+    for (AffineExpr affineExpr : affineMap.getResults()) {
+      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+      if (!affineDimExpr)
+        llvm_unreachable("affine_map is a projected permutation");
+
+      if (isInput)
+        inOccurrences[affineDimExpr.getPosition()] += 1;
+      else
+        outOccurrences[affineDimExpr.getPosition()] += 1;
+    }
+
+    return success();
+  };
+
+  for (auto &&[affineMap, operandType, isInput] :
+       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+                 SmallVector<bool>{true, true, false})) {
+    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+      return failure(); // NB: checkAffineMapAndType will emit relevant error.
+  }
+
+  bool hasContractingDim = false;
+  for (size_t dimIndex = 0; dimIndex < inOccurrences.size(); dimIndex++) {
+    size_t inOccCount = inOccurrences[dimIndex];
+    size_t outOccCount = outOccurrences[dimIndex];
+
+    // A dim contracts iff both inputs use it and the output does not.
+    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+    if (inOccCount == 0 && outOccCount == 0)
+      return emitError() << "iteration space dim at index " << dimIndex
+                         << " not used to access any operand";
+
+    // NB: We disallow a dim which occurs for only one input operand and not
+    //     for the output. In terms of einsum semantics such dims have a
+    //     sensible meaning - namely an additional reduction per each such dim.
+    //     By contrast, the ContractionOpInterface does not know about this
+    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
+    //     while vector.contract's verifier accepts dims of this kind many of
+    //     its lowerings give up on encountering these dims.
+    // TODO: Remove following once we have comprehensive support for input-only
+    //       reduction dims, at both the linalg- and vector-dialect levels.
+    if (inOccCount == 1 && outOccCount != 1)
+      return emitError()
+             << "iteration space dim at index " << dimIndex
+             << " is neither a contracting dim nor of parallel iteration type";
+  }
+
+  if (!hasContractingDim)
+    return emitError("'indexing_maps' do not specify a contracting dimension");
+
+  return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ContractOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..a3611b8e4ec62e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -996,9 +994,211 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
 // -----
 
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Plain matmul: d2 (k) is contracted; d0 (m) and d1 (n) are parallel.
+func.func @contract_matmul(%A: memref<3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
+      outs(%C: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME:      %[[A:.*]]: memref<5x3xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<7x5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Matmul with both A and B accessed transposed; d2 is still the contracted dim.
+func.func @contract_matmul_transpose_a_b(%A: memref<5x3xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<5x3xf32>, memref<7x5xf32>)
+      outs(%C: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @contract_batch_matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<9x3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Batch matmul: d0 indexes A, B and C (batch dim); d3 is contracted.
+func.func @contract_batch_matmul(%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<9x3x7xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
+      outs(%C: memref<9x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["reduction", "parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Batch-reduce matmul: d0 is absent from C, so it is reduced alongside d3.
+#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @contract_batch_reduce_matmul(
+    %A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [#accessA, #accessB, #accessC]
+      ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
+      outs(%C: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+// CHECK-SAME:      %[[A:.*]]: memref<9x5x3xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<9x7x5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["reduction", "parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Batch-reduce matmul with m/k swapped in A's access and k/n swapped in B's.
+#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+    %A: memref<9x5x3xf32>, %B: memref<9x7x5xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [#accessA, #accessB, #accessC]
+      ins(%A, %B : memref<9x5x3xf32>, memref<9x7x5xf32>)
+      outs(%C: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL: func.func @contract_dot(
+// CHECK-SAME:      %[[A:.*]]: memref<9xf32>, %[[B:.*]]: memref<9xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<f32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// Dot product: the sole dim d0 is contracted into a rank-0 output.
+#accessAB = affine_map<(d0) -> (d0)>
+#accessC = affine_map<(d0) -> ()>
+func.func @contract_dot(
+    %A: memref<9xf32>, %B: memref<9xf32>, %C: memref<f32>) {
+  linalg.contract
+      indexing_maps = [#accessAB, #accessAB, #accessC]
+      ins(%A, %B : memref<9xf32>, memref<9xf32>)
+      outs(%C: memref<f32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b(
+// CHECK-SAME:      %[[A:.*]]: memref<5xf32>, %[[B:.*]]: memref<5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK:         linalg.generic
+// CHECK-SAME:        indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT:    ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// A and B are 1-D over d2 (broadcast along d0 and d1); d2 is contracted.
+#accessAB = affine_map<(d0, d1, d2) -> (d2)>
+#accessC = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @contract_matmul_bcast_a_b(
+    %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+  linalg.contract
+      indexing_maps = [#accessAB, #accessAB, #accessC]
+      ins(%A, %B : memref<5xf32>, memref<5xf32>)
+      outs(%C: memref<3x7xf32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index c170c5be4abff9..bbd6e0fc8e2ccb 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -120,6 +120,58 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
 
 // -----
 
+func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
+    %A: tensor<16x8xf16>,
+    %B: tensor<8x32xf64>,
+    %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+      outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
+// CHECK:         ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// Verify A is extended (f16 -> f32) and B is truncated (f64 -> f32).
+// CHECK-NEXT:      %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:      %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    -> tensor<16x32xf32>
+
+// -----
+
+func.func @generalize_matmul_as_contract_with_ext_and_trunc(
+    %A: tensor<24x12xf16>,
+    %B: tensor<12x25xf16>,
+    %C: tensor<24x25xf32>) -> tensor<24x25xf16> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xf16>, tensor<12x25xf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
+  func.return %1 : tensor<24x25xf16>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
+// CHECK:         ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// Verify both f16 inputs are extended to f32; truncation follows the op.
+// CHECK-NEXT:      %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT:      %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+// CHECK-NEXT:      %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+// CHECK-NEXT:    -> tensor<24x25xf32>
+// CHECK-NEXT:    %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
+// -----
+
 func.func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 0853856d933035..0ea805fef5361f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -547,6 +547,104 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
 
 // -----
 
+func.func @invalid_indexing_maps_placement_contraction(
+    %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+3 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+  // NB: indexing_maps must precede ins and outs; the diagnostic is reported at the `ins` line.
+  linalg.contract
+      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>)
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+  return
+}
+
+// -----
+
+func.func @invalid_affine_map_in_indexing_maps_contraction(
+    %a: tensor<4x1xf32>, %b: tensor<1x64xf32>, %c: tensor<4x64xf32>) {
+  // expected-error @+1 {{provided affine_map is not a projected permutation}}
+  linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m + k, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%a, %b : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%c : tensor<4x64xf32>) -> tensor<4x64xf32>
+  func.return
+}
+
+// -----
+
+func.func @differing_iteration_space_of_affine_maps_contraction(
+    %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration spaces of provided affine_maps differ}}
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d1)>,  // 4 dims; the other maps have 3
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
+func.func @mismatched_ranks_affine_map_and_operand_contraction(
+    %lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{ranks of shaped operand and results of corresponding affine_map differ}}
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,  // 2 results, but %lhs has rank 3
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+// -----
+
+func.func @mismatch_type_affine_map_and_operand_contraction(
+    %scalar: f32, %b: tensor<4x64xf32>, %c: tensor<4x64xf32>) {
+  // expected-error @+1 {{affine_map specifies shaped access while operand has non-shaped type}}
+  linalg.contract
+      indexing_maps = [affine_map<(m, n) -> (m)>,
+                       affine_map<(m, n) -> (m, n)>,
+                       affine_map<(m, n) -> (m, n)>]
+      ins(%scalar, %b : f32, tensor<4x64xf32>)
+      outs(%c : tensor<4x64xf32>) -> tensor<4x64xf32>
+  func.return
+}
+
+// -----
+
+func.func @unused_iteration_space_dim_contraction(
+    %a: tensor<4x1xf32>, %b: tensor<1x64xf32>, %c: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
+  linalg.contract
+      indexing_maps = [affine_map<(m, n, k, u) -> (m, k)>,
+                       affine_map<(m, n, k, u) -> (k, n)>,
+                       affine_map<(m, n, k, u) -> (m, n)>]
+      ins(%a, %b : tensor<4x1xf32>, tensor<1x64xf32>)
+      outs(%c : tensor<4x64xf32>) -> tensor<4x64xf32>
+  func.return
+}
+
+// -----
+
+func.func @neither_contracting_nor_parallel_dim_contraction(
+    %lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+1 {{iteration space dim at index 3 is neither a contracting dim nor of parallel iteration type}}
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,  // d3 indexes only the lhs
+                       affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>]
+      ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
+      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return
+}
+
+// -----
+
 func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
   // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_nhwc_hwcf

diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6286a11c11a21f..efe8010cffc916 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -627,6 +627,53 @@ func.func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f
 //----------------------------------------------------------------------------//
 // Named ops to loops.
 //----------------------------------------------------------------------------//
+func.func @batch_reduce_matmul_as_contract(
+    %lhs: memref<?x?x?xf32>, %rhs: memref<?x?x?xf32>, %acc: memref<?x?xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
+                       affine_map<(b, m, n, k) -> (b, k, n)>,
+                       affine_map<(b, m, n, k) -> (m, n)>]
+      ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%acc : memref<?x?xf32>)
+  func.return
+}
+// CHECK-LABEL: @batch_reduce_matmul_as_contract
+//  CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECK: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+//       CHECK: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+//       CHECK: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+//       CHECK: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+//       CHECK: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+//       CHECK:   scf.for %[[m:.*]] = %{{.*}} to %[[M]]
+//       CHECK:     scf.for %[[n:.*]] = %{{.*}} to %[[N]]
+//       CHECK:       scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+//       CHECK:       %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+//       CHECK:       %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+//       CHECK:       %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//       CHECK:       store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @batch_reduce_matmul_as_contract
+//  CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//       CHECKPARALLEL: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+//       CHECKPARALLEL:   scf.parallel (%[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[M]], %[[N]]) step ({{.*}}) {
+//       CHECKPARALLEL:     scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+//       CHECKPARALLEL:         %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:         %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+//       CHECKPARALLEL:         %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+//       CHECKPARALLEL:         %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:         store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
 func.func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
   linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%C : memref<?x?x?xf32>)

diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 68aa5a85b5e0e6..578d24a550b08c 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1245,7 +1245,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -1259,7 +1258,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -1509,6 +1507,128 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func @contract
+//       CHECK:   linalg.contract
+//  CHECK-SAME:     indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @contract(
+    %lhs: memref<2x3x5xf32>, %rhs: memref<2x5x7xf32>, %acc: memref<2x3x7xf32>) {
+  linalg.contract
+      indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
+                       affine_map<(b, m, n, k) -> (b, k, n)>,
+                       affine_map<(b, m, n, k) -> (b, m, n)>]
+      ins(%lhs, %rhs : memref<2x3x5xf32>, memref<2x5x7xf32>)
+      outs(%acc : memref<2x3x7xf32>)
+  func.return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_a
+func.func @contract_matmul_bcast_a(%lhs: memref<5xf32>, %rhs: memref<5x7xf32>, %acc: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%lhs, %rhs : memref<5xf32>, memref<5x7xf32>)
+      outs(%acc : memref<3x7xf32>)
+  func.return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_b
+func.func @contract_matmul_bcast_b(%lhs: memref<3x5xf32>, %rhs: memref<5xf32>, %acc: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%lhs, %rhs : memref<3x5xf32>, memref<5xf32>)
+      outs(%acc : memref<3x7xf32>)
+  func.return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+func.func @contract_matmul_bcast_a_b(
+    %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_A]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,  // %A and %B share the same map
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>]
+      ins(%A, %B : memref<5xf32>, memref<5xf32>)
+      outs(%C: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
+func.func @contract_matmul_bcast_a_transpose_b(
+    %lhs: memref<5xf32>, %rhs: memref<7x5xf32>, %acc: memref<3x7xf32>) {
+// CHECK:  linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (k)>,
+                       affine_map<(m, n, k) -> (n, k)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%lhs, %rhs : memref<5xf32>, memref<7x5xf32>)
+      outs(%acc : memref<3x7xf32>)
+  func.return
+}
+
+// -----
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @contract_matmul_bcast_b_transpose_a
+func.func @contract_matmul_bcast_b_transpose_a(%lhs: memref<5x3xf32>, %rhs: memref<5xf32>, %acc: memref<3x7xf32>) {
+// CHECK:      linalg.contract
+// CHECK-SAME:     indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
+// CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+  linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (k, m)>,
+                       affine_map<(m, n, k) -> (k)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%lhs, %rhs : memref<5x3xf32>, memref<5xf32>)
+      outs(%acc : memref<3x7xf32>)
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @mmt4d
 func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.mmt4d

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..dc556761b09e56 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -277,22 +277,34 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
 
 // -----
 
-
+#batch_lhs = affine_map<(b, m, n, k) -> (b, m, k)>
+#batch_rhs = affine_map<(b, m, n, k) -> (b, k, n)>
+#batch_out = affine_map<(b, m, n, k) -> (b, m, n)>
 func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
-  -> (tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 {
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%c3: memref<?x?x?xf32>)
+  linalg.contract
+      indexing_maps = [#batch_lhs, #batch_rhs, #batch_out]
+      ins(%a3, %b3 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%c3 : memref<?x?x?xf32>)
   %res1 = linalg.batch_matmul
                       ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                      outs(%tc3: tensor<?x?x?xf32>)
                   -> tensor<?x?x?xf32>
-  return %res1 : tensor<?x?x?xf32>
+  %contracted = linalg.contract
+      indexing_maps = [#batch_lhs, #batch_rhs, #batch_out]
+      ins(%ta3, %tb3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+      outs(%tc3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %res1, %contracted : tensor<?x?x?xf32>, tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @named_ops
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.contract
 //       CHECK:   linalg.batch_matmul
+//       CHECK:   linalg.contract
 
 // -----
 

diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 8f13c690704572..557233d8aa3ec4 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -37,6 +37,53 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK:       #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-NEXT:  #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-NEXT:  #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+#matmul_maps = [affine_map<(m, n, k) -> (m, k)>,
+                affine_map<(m, n, k) -> (k, n)>,
+                affine_map<(m, n, k) -> (m, n)>]
+
+// CHECK-LABEL: func @matmul_as_contract_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @matmul_as_contract_tensors(
+  %lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %acc: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.contract
+// CHECK-SAME:          indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:          ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:          outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %res = linalg.contract indexing_maps = #matmul_maps
+                         ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+                         outs(%acc : tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+//      CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %res : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["linalg.contract"]} in %root : (!transform.any_op) -> !transform.any_op
+    %tiled, %loops:3 = transform.structured.tile_using_for %contract tile_sizes [2, 3, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_tensors_with_size_zeros(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>

diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index b688a677500c22..5ae3f893c2e739 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -82,6 +82,39 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @matmul_as_contract
+// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
+func.func @matmul_as_contract(%lhs: tensor<24x12xf32>,
+                              %rhs: tensor<12x25xf32>,
+                              %acc: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
+  // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
+  // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
+  %res = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%lhs, %rhs : tensor<24x12xf32>, tensor<12x25xf32>)
+      outs(%acc : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %res : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %contract = transform.structured.match ops{["linalg.contract"]} in %root : (!transform.any_op) -> !transform.any_op
+    %parent = transform.get_parent_op %contract {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %vectorized = transform.structured.vectorize_children_and_apply_patterns %parent : (!transform.any_op) -> !transform.any_op
+    // TODO: also test the other available vectorization strategies
+    transform.yield
+  }
+}
+
+// -----
+
 #matmul_trait = {
   indexing_maps = [
     affine_map<(m, n, k) -> (m, k)>,


        


More information about the Mlir-commits mailing list