[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Andrzej WarzyƄski llvmlistbot at llvm.org
Tue Jan 28 00:52:45 PST 2025


================
@@ -3611,5 +3607,183 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  // domains are all the same, and each implements a projected permutation.
+  // Each iteration space dim must occur for at least one operand and either
+  // takes part in a contraction/reduction or else has parallel iteration type.
+  // We have that a dim is a contraction/reduction dim if and only if the dim
+  // occurs for the output operand. We use this fact for fast inference:
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
+  for (auto result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  for (auto dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+
+  return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castSignedness = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = attr.getValue();
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Type outType = block.getArgument(2).getType();
+  Value lhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+  Value rhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      productAtOutType);
+  helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Maps iter space dim (as index) to num of occurrences in inputs and output.
----------------
banach-space wrote:

Just because this is a vector rather than a map, could you add a small example? Is it e.g. `inOccurances[0]` for dim-0?

https://github.com/llvm/llvm-project/pull/123618


More information about the Mlir-commits mailing list