[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Rolf Morel llvmlistbot at llvm.org
Tue Jan 28 14:35:04 PST 2025


================
@@ -680,6 +680,125 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating into the third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+    dimension identifiers (meant to range over valid indices), corresponding to
+    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+    the dimensions in the set `dims`.
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+      Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+      deriving from matmul terminology - is either an "M-like" dim (if in `A`
+      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+      `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` are of parallel iteration-type) and gets represented as:
+
+    ```
+    %0 = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting the dims in the co-domains of the `affine_map`s, we
+    can apply arbitrary transposes to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved through leaving out dims on either
+    input operand.
----------------
rolfmorel wrote:

Indeed - both broadcasting and iterator type inference follow from which dims are missing on operands. Have added a comment to indicate that dims that are involved with a broadcast necessarily have parallel iter type.

I have also added another example to the docs to illustrate both broadcasting and transpositions. The example is hopefully clear (and familiar) enough for most to grasp what broadcasts will look like once folded into a `contract`.

https://github.com/llvm/llvm-project/pull/123618


More information about the Mlir-commits mailing list