[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)
Rolf Morel
llvmlistbot at llvm.org
Tue Jan 21 16:17:05 PST 2025
================
@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the iterator type of every iteration-space dim from the indexing
+/// map of the output operand alone.
+///
+/// On well-formed IR, indexing_maps is non-empty, the contained affine_maps
+/// all share the same domain, and each one is a projected permutation. Each
+/// dim in the domain is classified as either batch, M-like, N-like, or
+/// K-like. Only K-like dims are reductions _and_ they are the only dims that
+/// do not occur for the output operand, which permits fast inference.
+/// NB: Should dims that occur solely for one input be allowed, the above
+/// still holds: per the einsum semantics, these are reduction dims as well.
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  // The last indexing map belongs to the output operand.
+  AffineMap outputMap = getIndexingMapsArray().pop_back_val();
+
+  // Treat every dim as a reduction until the output map is seen to use it.
+  SmallVector<utils::IteratorType> iteratorTypes(
+      outputMap.getNumDims(), utils::IteratorType::reduction);
+  for (AffineExpr expr : outputMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+    assert(dimExpr && "affine_map is a projected permutation");
+    iteratorTypes[dimExpr.getPosition()] = utils::IteratorType::parallel;
+  }
+
+  return iteratorTypes;
+}
+
+/// Number of block arguments of the op's region; 'regionBuilder' asserts
+/// that it receives exactly this many.
+unsigned ContractOp::getNumRegionArgs() {
+  return 3;
+}
+
+/// Block region builder, called by 'fillStructuredOpRegion'.
+///
+/// Emits the multiply-accumulate body: the first two block arguments are
+/// converted to the type of the third, multiplied, added to the third block
+/// argument, and the sum is yielded.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  // Signed casting is the default; the first attribute named "cast" may
+  // override it, provided it actually holds a TypeFnAttr.
+  TypeFn castFn = TypeFn::cast_signed;
+  for (const NamedAttribute &namedAttr : attrs) {
+    if (namedAttr.getName() == "cast") {
+      if (auto typeFnAttr = llvm::dyn_cast<TypeFnAttr>(namedAttr.getValue()))
+        castFn = typeFnAttr.getValue();
+      break;
+    }
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Value lhs = block.getArgument(0);
+  Value rhs = block.getArgument(1);
+  Value acc = block.getArgument(2);
+  Type accType = acc.getType();
+
+  // Keep each build call in its own statement so that the ops are created in
+  // a fixed order: cast lhs, cast rhs, multiply, add.
+  Value castedLhs = helper.buildTypeFn(castFn, accType, lhs);
+  Value castedRhs = helper.buildTypeFn(castFn, accType, rhs);
+  Value product = helper.buildBinaryFn(BinaryFn::mul, castedLhs, castedRhs);
+  Value sum = helper.buildBinaryFn(BinaryFn::add, acc, product);
+  helper.yieldOutputs({sum});
+}
+
+/// Parses the custom assembly form of the op: a mandatory leading
+/// `indexing_maps` attribute, followed by the syntax handled by
+/// 'parseNamedStructuredOp'.
+///
+/// Returns a failure (with a diagnostic at the current parser location) when
+/// the indexing maps cannot be parsed or are absent.
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  // Both a parse failure and a null attribute are rejected. The diagnostic
+  // names the attribute exactly as it is stored and printed: `indexing_maps`.
+  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+/// Prints the custom assembly form of the op: the `indexing_maps` list
+/// first, then the output of 'printNamedStructuredOp'. The attributes
+/// listed in `elidedAttrs` are passed along so they are not printed again.
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  bool needsSeparator = false;
+  for (Attribute mapAttr : getIndexingMaps()) {
+    if (needsSeparator)
+      p << ", ";
+    p.printAttribute(mapAttr);
+    needsSeparator = true;
+  }
+  p << "]";
+
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+ int iterationSpaceDims = -1;
+ // Maps iter space dim (as index) to num of occurrences in inputs and output.
+ SmallVector<size_t> inOccurrences;
+ SmallVector<size_t> outOccurrences;
+
+ auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+ bool isInput) -> LogicalResult {
+ if (iterationSpaceDims == -1) {
+ iterationSpaceDims = affineMap.getNumDims();
+ inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+ return emitError("iteration spaces of provided affine_maps differ");
+ }
+
+ if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+ if (affineMap.getNumResults() != shapedType.getRank())
+ return emitError("ranks of shaped operand and co-domain of "
+ "corresponding affine_map differ");
+ } else if (affineMap.getNumResults() != 0) {
+ return emitError("affine_map specifies shaped access while operand has "
+ "non-shaped type");
+ }
+
+ if (!affineMap.isProjectedPermutation())
+ return emitError("provided affine_map is not a projected permutation");
+
+ for (AffineExpr affineExpr : affineMap.getResults()) {
+ auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+ if (!affineDimExpr)
+ llvm_unreachable("affine_map is a projected permutation");
+
+ if (isInput)
+ inOccurrences[affineDimExpr.getPosition()] += 1;
+ else
+ outOccurrences[affineDimExpr.getPosition()] += 1;
+ }
+
+ return success();
+ };
+
+ for (auto &&[affineMap, operandType, isInput] :
+ llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+ SmallVector<bool>{true, true, false}))
+ if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+ return failure(); // NOTE: checking lambda will emit error.
+
+ bool hasContractingDim = false;
+ for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+ hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+ if (inOccCount == 0)
+ return emitError("iteration space dim not used by either input");
+
+ // NB: A dim which occurs for only one input operand and not for the output.
+ // In terms of einsum semantics, such dims have a sensible meaning -
+ // namely an additional reduction per such dim - though this can also
+ // always be expressed through an additional op. Additionally, at time
+ // of writing, vector.contract's verifier accepts these dims but many of
+ // its lowerings do not handle these kinds of dims. Hence...
+ // TODO: Remove following once we have comprehensive support for input-only
+ // reduction dims, at both the linalg- and vector-dialect levels.
+ if (inOccCount == 1 && outOccCount != 1)
+ return emitError("iter type of dim is not one of M, N, K or batch");
----------------
rolfmorel wrote:
P.S. I also improved the preceding comments a bit.
https://github.com/llvm/llvm-project/pull/123618
More information about the Mlir-commits
mailing list