[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jan 28 00:52:46 PST 2025
================
@@ -3611,5 +3607,183 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the iterator type of every iteration-space dim from the output
+/// (last) indexing map alone: dims that appear in the output map are
+/// `parallel`, every other dim is a `reduction` (contraction) dim.
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  // On well-formed IR, indexing_maps is non-empty, all contained affine_maps
+  // share the same domain, and each implements a projected permutation.
+  // Every iteration-space dim occurs for at least one operand and either takes
+  // part in a contraction/reduction or else has parallel iteration type.
+  // A dim is a contraction/reduction dim if and only if it does NOT occur for
+  // the output operand; dims that do occur for the output are parallel. We use
+  // this fact for fast inference from the output map alone.
+  // NB: If dims are ever allowed to occur solely for one input, this still
+  // holds: per einsum semantics, such dims are reduction dims as well.
+  AffineMap outAffineMap = getIndexingMapsArray().back();
+
+  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
+  for (AffineExpr result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  iteratorTypes.reserve(dimsInOutput.size());
+  for (bool dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+/// Number of block arguments of the op's body region. regionBuilder consumes
+/// them as lhs element, rhs element and accumulator, in that order.
+unsigned ContractOp::getNumRegionArgs() {
+  const unsigned numRegionArgs = 3;
+  return numRegionArgs;
+}
+
+/// Block region builder invoked by 'fillStructuredOpRegion'. Emits
+///   acc + cast(lhs) * cast(rhs)
+/// where both inputs are cast to the accumulator's type using the TypeFn from
+/// an optional "cast" attribute (cast_signed when absent or not a TypeFnAttr),
+/// and yields the sum.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  // Only the first attribute named "cast" is consulted.
+  TypeFn castSignedness = TypeFn::cast_signed;
+  for (const NamedAttribute &namedAttr : attrs) {
+    if (namedAttr.getName() == "cast") {
+      if (auto typeFnAttr = llvm::dyn_cast<TypeFnAttr>(namedAttr.getValue()))
+        castSignedness = typeFnAttr.getValue();
+      break;
+    }
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Value lhs = block.getArgument(0);
+  Value rhs = block.getArgument(1);
+  Value acc = block.getArgument(2);
+  Type accType = acc.getType();
+
+  // Kept as separate statements so the casts are created lhs-then-rhs.
+  Value lhsCast = helper.buildTypeFn(castSignedness, accType, lhs);
+  Value rhsCast = helper.buildTypeFn(castSignedness, accType, rhs);
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhsCast, rhsCast);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+  helper.yieldOutputs({sum});
+}
+
+/// Parses the mandatory leading `indexing_maps = [...]` list, then delegates
+/// the remainder of the op to the shared named-structured-op parser.
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> parsedMaps = parseIndexingMapsAttr(parser);
+  const bool hasMaps = succeeded(parsedMaps) && *parsedMaps != nullptr;
+  if (!hasMaps)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  result.addAttribute("indexing_maps", *parsedMaps);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+/// Prints `indexing_maps = [...]` explicitly, then the generic named-op form
+/// with indexing_maps and operandSegmentSizes elided from the attr-dict.
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&p](Attribute map) { p.printAttribute(map); });
+  p << "]";
+
+  SmallVector<StringRef, 2> elidedAttrs = {"indexing_maps",
+                                           "operandSegmentSizes"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+LogicalResult ContractOp::verify() {
+ int iterationSpaceDims = -1;
+ // Maps iter space dim (as index) to num of occurrences in inputs and output.
+ SmallVector<size_t> inOccurrences;
+ SmallVector<size_t> outOccurrences;
+
+ auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+ bool isInput) -> LogicalResult {
+ if (iterationSpaceDims == -1) {
+ iterationSpaceDims = affineMap.getNumDims();
+ inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+ return emitError("iteration spaces of provided affine_maps differ");
+ }
+
+ if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+ if (affineMap.getNumResults() != shapedType.getRank())
+ return emitError("ranks of shaped operand and co-domain of "
+ "corresponding affine_map differ");
+ } else if (affineMap.getNumResults() != 0) {
+ return emitError("affine_map specifies shaped access while operand has "
+ "non-shaped type");
+ }
+
+ if (!affineMap.isProjectedPermutation())
+ return emitError("provided affine_map is not a projected permutation");
+
+ for (AffineExpr affineExpr : affineMap.getResults()) {
+ auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+ if (!affineDimExpr)
+ llvm_unreachable("affine_map is a projected permutation");
+
+ if (isInput)
+ inOccurrences[affineDimExpr.getPosition()] += 1;
+ else
+ outOccurrences[affineDimExpr.getPosition()] += 1;
+ }
+
+ return success();
+ };
----------------
banach-space wrote:
This is a matter of personal preference, but to me something this long deserves a proper helper function 😅
Regardless of whether this is a lambda or a free function, could you add a quick description and a summary of the conditions being checked?
https://github.com/llvm/llvm-project/pull/123618
More information about the Mlir-commits
mailing list