[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)
Rolf Morel
llvmlistbot at llvm.org
Tue Jan 28 05:51:36 PST 2025
================
@@ -680,6 +680,125 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+ let summary = [{
+ Perform a contraction on two inputs, accumulating into the third.
+ }];
+ let description = [{
+ The semantics of contracting inputs `A` and `B` on top of `C` to produce
+ output `D` is given by
+
+ `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+ where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+ dimension identifiers (meant to range over valid indices), corresponding to
+ the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+ and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+ the dimensions in the set `dims`.
+
+ The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+ domain of each of the `affine_map`s. Like for einsums, the iteration type of
+ each dim is inferred and is either:
+
+ - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+ Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+ - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+ deriving from matmul terminology - is either an "M-like" dim (if in `A`
+ and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+ `B`, and `C`).
+
+ For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+ `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+ `n` and `b` are of parallel iteration-type) and gets represented as:
+
+ ```
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ ```
+
+ Note that by permuting the dims in the co-domains of the `affine_map`s, we
+ can apply arbitrary transposes to the inputs and output. Similarly,
+ arbitrary broadcasts can be achieved through leaving out dims on either
+ input operand.
+
+ Numeric casting is performed on the operands to the inner multiplication,
+ promoting them to the same data type as the accumulator/output.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ AffineMapArrayAttr:$indexing_maps
+ );
+ let results = (outs Variadic<AnyShaped>:$result_tensors);
+ let regions = (region SizedRegion<1>:$combiner);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+ outputs, attributes, regionBuilder);
+ }]>,
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, regionBuilder);
+ }]>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare/implement functions necessary for LinalgStructuredInterface.
+
+ /// Infer iterator types for each dim in the domain of IndexingMaps.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// IndexingMaps always depends on attr associated to current Op instance.
+ bool hasDynamicIndexingMaps() { return true; };
+ bool hasUserDefinedMaps() { return true; };
----------------
rolfmorel wrote:
Indeed, there are no default maps whatsoever. This is reflected in the parser, which emits an error if the maps are not specified before `ins(..)`, and in the builders, which require the indexing maps to be provided.
I have now added the word 'mandatory' to the docs to highlight that they are non-optional. I agree this is good to point out since, as you say, this is the first structured op besides generic to require them.
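
Since the maps are mandatory, the iterator types have to be inferred from them each time. As a rough illustration of the inference rule stated in the op description (not the C++ implementation in this PR), here is a minimal Python sketch. The function name `infer_iterator_types` and the use of plain dim-name lists in place of `affine_map`s are hypothetical. Raising an error for dims that fit neither category is also an assumption, since the description does not cover that case.

```python
def infer_iterator_types(dims, a_dims, b_dims, c_dims):
    """Classify each iteration dim as 'parallel' or 'reduction'.

    dims:   all dims of the iteration space, in domain order.
    a_dims, b_dims, c_dims: dims appearing in the maps of A, B and C.
    """
    a, b, c = set(a_dims), set(b_dims), set(c_dims)
    kinds = []
    for d in dims:
        if d in c and (d in a or d in b):
            # In C and at least one input: M-like, N-like or batch dim.
            kinds.append("parallel")
        elif d in a and d in b:
            # In both inputs but not in C: contracted, i.e. reduced over.
            kinds.append("reduction")
        else:
            # Assumption: anything else is rejected in this sketch.
            raise ValueError(f"dim {d!r} is neither parallel nor reduction")
    return kinds


# Batch-matmul example from the op description.
print(infer_iterator_types(
    ["batch", "m", "n", "k"],
    ["batch", "m", "k"],
    ["batch", "k", "n"],
    ["batch", "m", "n"],
))
# → ['parallel', 'parallel', 'parallel', 'reduction']
```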
https://github.com/llvm/llvm-project/pull/123618