[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Rolf Morel llvmlistbot at llvm.org
Wed Jan 29 07:39:16 PST 2025


================
@@ -3611,5 +3607,192 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the iterator type of each iteration-space dim from the last entry of
+/// `indexing_maps`, i.e. the affine_map of the output operand.
+///
+/// Returns one entry per dim of that map's domain: `parallel` when the dim
+/// appears among the map's results, `reduction` otherwise.
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  // domains are all the same, and each implements a projected permutation.
+  // Each iteration space dim must occur for at least one operand and either
+  // takes part in a contraction/reduction or else has parallel iteration type.
+  // We have that a dim is a contraction/reduction dim if and only if the dim
+  // does NOT occur for the output operand. We use this fact for fast inference:
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
+  for (AffineExpr result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  // One iterator type per dim; the final size is known, so allocate once.
+  SmallVector<utils::IteratorType> iteratorTypes;
+  iteratorTypes.reserve(dimsInOutput.size());
+  for (bool dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+
+  return iteratorTypes;
+}
+
+/// Number of block arguments of the op's region; always three.
+unsigned ContractOp::getNumRegionArgs() {
+  constexpr unsigned kNumRegionArgs = 3;
+  return kNumRegionArgs;
+}
+
+/// Region builder invoked by 'fillStructuredOpRegion'.
+///
+/// Expects a block with exactly three arguments. Arguments 0 and 1 are passed
+/// through `buildTypeFn` with the type of argument 2, the two results are
+/// multiplied, the product is added to argument 2, and the sum is yielded.
+///
+/// \param b      Builder handed to the RegionBuilderHelper.
+/// \param block  Block to populate.
+/// \param attrs  Attributes of the op; an entry named "cast" holding a
+///               TypeFnAttr selects the cast function (default: cast_signed).
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+
+  // Only the first attribute named "cast" is inspected; if its value is not a
+  // TypeFnAttr the default is kept.
+  TypeFn castFn = TypeFn::cast_signed;
+  for (const NamedAttribute &namedAttr : attrs) {
+    if (namedAttr.getName() == "cast") {
+      if (auto typeFnAttr = llvm::dyn_cast<TypeFnAttr>(namedAttr.getValue()))
+        castFn = typeFnAttr.getValue();
+      break;
+    }
+  }
+
+  RegionBuilderHelper helper(b, block);
+  Value lhs = block.getArgument(0);
+  Value rhs = block.getArgument(1);
+  Value acc = block.getArgument(2);
+
+  // TODO: Support fields with operators besides mult & add.
+  Type accType = acc.getType();
+  Value lhsCast = helper.buildTypeFn(castFn, accType, lhs);
+  Value rhsCast = helper.buildTypeFn(castFn, accType, rhs);
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhsCast, rhsCast);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+  helper.yieldOutputs({sum});
+}
+
+/// Parses the custom assembly of the op: first the `indexing_maps` attribute,
+/// which is mandatory, then the remainder via `parseNamedStructuredOp`.
+///
+/// Emits "expected 'indexing_maps' attribute" and fails when the attribute
+/// could not be parsed or is absent.
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> parsedMaps = parseIndexingMapsAttr(parser);
+  // Short-circuit keeps the dereference from running on a failed parse.
+  const bool mapsMissing = failed(parsedMaps) || *parsedMaps == nullptr;
+  if (mapsMissing) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  }
+
+  ArrayAttr indexingMaps = *parsedMaps;
+  result.addAttribute("indexing_maps", indexingMaps);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+/// Prints the custom assembly of the op: the `indexing_maps` list up front,
+/// then the rest through `printNamedStructuredOp`, with `indexing_maps` and
+/// `operandSegmentSizes` elided from the trailing attribute dictionary.
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleave(
+      getIndexingMaps(), [&](Attribute mapAttr) { p.printAttribute(mapAttr); },
+      [&] { p << ", "; });
+  p << "]";
+
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
+  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
+  // access an input operand (so occurrence count can be at most 2) and
+  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+
+  // For each operand's affine_map and type, check that the rank of the
+  // affine_map's domain is the same as those seen prior, check that the
+  // affine_map's co-domain rank is the same as that of the corresponding type,
+  // check that the affine_map is a projected permutation, and, finally, update
+  // inputs and output occurrence counts for dims in the co-domains.
----------------
rolfmorel wrote:

Good call! Have now broken it up into appropriate inline comments.

https://github.com/llvm/llvm-project/pull/123618


More information about the Mlir-commits mailing list