[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jan 28 00:52:45 PST 2025
================
@@ -680,6 +680,125 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+ let summary = [{
+ Perform a contraction on two inputs, accumulating into the third.
+ }];
+ let description = [{
+ The semantics of contracting inputs `A` and `B` on top of `C` to produce
+ output `D` is given by
+
+ `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+ where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+ dimension identifiers (meant to range over valid indices), corresponding to
+ the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+ and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+ the dimensions in the set `dims`.
+
+ The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+ domain of each of the `affine_map`s. As with einsums, the iteration type of
+ each dim is inferred and is either:
+
+ - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+ Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+ - parallel: the dim occurs in `C` and in at least one of `A` and `B`, and,
+ borrowing matmul terminology, is either an "M-like" dim (if in `A` and `C`
+ but not `B`), an "N-like" dim (if in `B` and `C` but not `A`) or a
+ "batch" dim (if in `A`, `B` and `C`).
+
+ For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+ `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+ `n` and `b` are of parallel iteration-type) and gets represented as:
+
+ ```
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ ```
+
+ Note that by permuting the dims in the co-domains of the `affine_map`s, we
+ can apply arbitrary transposes to the inputs and output. Similarly,
+ arbitrary broadcasts can be achieved by leaving dims out of the `affine_map`
+ of either input operand.
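To make the quoted semantics concrete, here is a naive pure-Python reference evaluator for `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`. It is only a sketch, not how Linalg actually lowers `linalg.contract`: tensors are modelled as dicts from index tuples to numbers, and the helpers `dense` and `contract` are hypothetical names. It runs the batch-matmul example from the description and checks that storing `A` transposed (co-domain `⟨ b, k, m ⟩`) gives the same result, matching the point about permuting co-domain dims.

```python
from itertools import product


def dense(shape, fn):
    """Hypothetical dense-tensor layout: a dict from index tuples to values."""
    return {idx: fn(*idx) for idx in product(*(range(s) for s in shape))}


def contract(a, b, c, idx_a, idx_b, idx_c, sizes):
    """Naive reference for D[H] = (SUM_{(I u J) \\ H} A[I] * B[J]) + C[H].

    idx_a, idx_b and idx_c play the roles of I, J and H: the dim names in the
    co-domain of each operand's indexing map. sizes maps dim name -> extent.
    """
    # Iteration space: every dim that appears in I, J or H (the maps' domain).
    dims = list(dict.fromkeys(idx_a + idx_b + idx_c))
    d = dict(c)  # D starts as a copy of the accumulator C
    for point in product(*(range(sizes[x]) for x in dims)):
        env = dict(zip(dims, point))
        # Dims in H pick the output element; all other dims are summed over.
        d[tuple(env[x] for x in idx_c)] += (
            a[tuple(env[x] for x in idx_a)] * b[tuple(env[x] for x in idx_b)]
        )
    return d


# Batch-matmul: I = <b, m, k>, J = <b, k, n>, H = <b, m, n>.
sizes = {"b": 2, "m": 2, "n": 3, "k": 4}
A = dense((2, 2, 4), lambda b, m, k: b + m * k)
B = dense((2, 4, 3), lambda b, k, n: k - n + b)
C = dense((2, 2, 3), lambda b, m, n: 0)
D = contract(A, B, C, ("b", "m", "k"), ("b", "k", "n"), ("b", "m", "n"), sizes)

# A transpose is just a different co-domain order: store A as <b, k, m>.
At = {(b, k, m): v for (b, m, k), v in A.items()}
Dt = contract(At, B, C, ("b", "k", "m"), ("b", "k", "n"), ("b", "m", "n"), sizes)
assert Dt == D
```

Because every dim of `H` also occurs in `I` or `J` under the quoted rules, walking the full iteration space visits each `(H, (I ∪ J) \ H)` combination exactly once, which is why a single accumulation loop suffices.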
----------------
banach-space wrote:
Question about broadcasts.
Iteration type is inferred as described above. It feels like that "inference" and broadcasting should be related, right? As in, both bits of logic "look at" the dims missing from the indexing maps?
Mostly making sure my mental model is correct. And that the docs are complete :)
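One way to check this mental model against the quoted rules: iteration-type inference and broadcasting can both be phrased as the same membership test, namely which operands' maps contain a given dim. The hypothetical helper `classify_dims` below is a sketch that applies only the rules as written in the description and also reports which inputs omit each dim (i.e. do not vary along it). It assumes the description is complete; dims it does not cover (e.g. one that appears only in `C`) are rejected rather than guessed at.

```python
def classify_dims(idx_a, idx_b, idx_c):
    """Apply the quoted inference rules to each dim in the iteration space.

    Returns {dim: (iterator_type, role, omitted_by)}, where omitted_by lists
    the inputs whose map leaves the dim out (i.e. are broadcast along it).
    """
    result = {}
    for d in dict.fromkeys(idx_a + idx_b + idx_c):
        in_a, in_b, in_c = d in idx_a, d in idx_b, d in idx_c
        # An input whose co-domain omits d does not vary along d.
        omitted_by = tuple(n for n, present in (("A", in_a), ("B", in_b)) if not present)
        if in_a and in_b and not in_c:
            result[d] = ("reduction", "contracted", omitted_by)
        elif in_c and in_a and in_b:
            result[d] = ("parallel", "batch", omitted_by)
        elif in_c and in_a:
            result[d] = ("parallel", "M-like", omitted_by)
        elif in_c and in_b:
            result[d] = ("parallel", "N-like", omitted_by)
        else:
            # E.g. a dim only in C, or only in one input and not in C.
            raise ValueError(f"dim {d!r} is not covered by the quoted rules")
    return result


# Batch-matmul from the description: no input omits `b`, so it is a batch dim.
print(classify_dims(("b", "m", "k"), ("b", "k", "n"), ("b", "m", "n"))["b"])
# -> ('parallel', 'batch', ())
# Broadcast B along the batch dim by dropping `b` from its map.
print(classify_dims(("b", "m", "k"), ("k", "n"), ("b", "m", "n"))["b"])
# -> ('parallel', 'M-like', ('B',))
```

Under this reading, dropping `b` from `B`'s map turns the batch dim into an M-like parallel dim with `B` broadcast along it, so the inference and the broadcast do appear to key off the same missing dims. Whether the docs should state that connection explicitly is the open question here.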
https://github.com/llvm/llvm-project/pull/123618