[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Tue Jan 21 01:45:03 PST 2025

@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+// ContractOp
+/// Infers the iterator type of every iteration-space dim from the indexing map
+/// of the output operand (the last entry of indexing_maps): a dim that occurs
+/// among that map's results is `parallel`, every other dim is `reduction`.
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().back();
+  // On well-formed IR, indexing_maps is non-empty, the contained affine_maps'
+  // domains are all the same, and each implements a projected permutation.
+  // Each dim in the domain must occur for at least one operand and is
+  // classified as either batch, N-like, M-like, or K-like. Only the latter
+  // corresponds to a reduction _and_ it is the only dim-kind which does not
+  // occur for the output operand. We use this fact for fast inference.
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
+  for (AffineExpr result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+  SmallVector<utils::IteratorType> iteratorTypes;
+  iteratorTypes.reserve(dimsInOutput.size());
+  for (bool dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+/// The op's region block takes three arguments; regionBuilder asserts exactly
+/// this count.
+unsigned ContractOp::getNumRegionArgs() {
+  return 3;
+}
+/// Implements the block region builder, which is called by
+/// 'fillStructuredOpRegion'. Populates `block` with the multiply-accumulate
+/// body `acc + cast(lhs) * cast(rhs)` and yields it; both inputs are converted
+/// to the type of the third (output) block argument before being multiplied.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+  // Signed casting is the default; the first attribute named "cast" overrides
+  // it, provided that attribute actually holds a TypeFnAttr.
+  TypeFn castSignedness = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto typeFnAttr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = typeFnAttr.getValue();
+  }
+  // TODO: Support fields with operators besides mult & add.
+  Value lhs = block.getArgument(0);
+  Value rhs = block.getArgument(1);
+  Value acc = block.getArgument(2);
+  Type outType = acc.getType();
+  Value lhsAtOutType = helper.buildTypeFn(castSignedness, outType, lhs);
+  Value rhsAtOutType = helper.buildTypeFn(castSignedness, outType, rhs);
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, acc, productAtOutType);
+  helper.yieldOutputs({result});
+}
+/// Parses the custom assembly form: the `indexing_maps` attribute comes first
+/// and is mandatory, the remainder is handled by parseNamedStructuredOp.
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  // Both a parse failure and an absent (null) attribute are rejected. The
+  // diagnostic names the attribute exactly as it is registered below.
+  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+/// Prints the custom assembly form: ` indexing_maps = [...]` first, then the
+/// named-structured-op syntax with the attributes already covered
+/// (`indexing_maps`, `operandSegmentSizes`) elided.
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute mapAttr) { p.printAttribute(mapAttr); });
+  p << "]";
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Maps iter space dim (as index) to num of occurrences in inputs and output.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+                                   bool isInput) -> LogicalResult {
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
+    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+      if (affineMap.getNumResults() != shapedType.getRank())
+        return emitError("ranks of shaped operand and co-domain of "
+                         "corresponding affine_map differ");
+    } else if (affineMap.getNumResults() != 0) {
+      return emitError("affine_map specifies shaped access while operand has "
+                       "non-shaped type");
+    }
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
+    for (AffineExpr affineExpr : affineMap.getResults()) {
+      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+      if (!affineDimExpr)
+        llvm_unreachable("affine_map is a projected permutation");
+      if (isInput)
+        inOccurrences[affineDimExpr.getPosition()] += 1;
+      else
+        outOccurrences[affineDimExpr.getPosition()] += 1;
+    }
+    return success();
+  };
+  for (auto &&[affineMap, operandType, isInput] :
+       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+                 SmallVector<bool>{true, true, false}))
+    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+      return failure(); // NOTE: checking lambda will emit error.
+  bool hasContractingDim = false;
+  for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+    if (inOccCount == 0)
+      return emitError("iteration space dim not used by either input");
+    // NB: A dim which occurs for only one input operand and not for the output.
+    //     In terms of einsum semantics, such dims have a sensible meaning -
+    //     namely an additional reduction per such dim - though this can also
+    //     always be expressed through an additional op. Additionally, at time
+    //     of writing, vector.contract's verifier accepts these dims but many of
+    //     its lowerings do not handle these kinds of dims. Hence...
+    // TODO: Remove following once we have comprehensive support for input-only
+    //       reduction dims, at both the linalg- and vector-dialect levels.
+    if (inOccCount == 1 && outOccCount != 1)
+      return emitError("iter type of dim is not one of M, N, K or batch");
adam-smnk wrote:

nit: I feel this error is a bit unintuitive (especially in the context of arbitrary N dims), but maybe it's just me.


