[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

Renato Golin llvmlistbot at llvm.org
Tue Jan 21 03:15:51 PST 2025


================
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating on top of a third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+    dimension identifiers (meant to range over valid indices), corresponding to
+    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+    the dimensions in the set `dims`.
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+      Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+      deriving from matmul terminology - is either an "M-like" dim (if in `A`
+      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+      `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` are of parallel iteration-type) and gets represented as:
+
+    ```
+    %0 = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting the dims in the co-domains of the `affine_map`s, we
+    can apply arbitrary transposes to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved through leaving out dims on either
+    input operand.
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting them to the same data type as the accumulator/output.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  let regions = (region SizedRegion<1>:$combiner);
----------------
rengolin wrote:

I'm curious as to what kind of errors you get. I agree with @MaheshRavishankar: if we don't need regions, we shouldn't declare them, and we should fix the resulting errors. But it may not be a trivial fix (for example, interfaces and passes require linalg ops to have regions).

IIRC, in OpDSL, **all** operations have regions, even if they're not parsed/printed. That seems to have been a requirement from the time before OpDSL? Maybe it's something in the transforms that operate on these ops, or in the interfaces that the ops implement. It would be good to identify the problem and understand the fix.

It's quite possible that the fix will be as big as this PR. If that's the case, I'd recommend leaving the region as is, marking it with a FIXME, and addressing it after this PR.
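
For reference while discussing the op, here is a minimal Python sketch of the semantics stated in the quoted description: how the iteration type of each dim is inferred from the three multi-indices, and what `D[H] = (SUM A[I] * B[J]) + C[H]` computes. It is only an illustration, not MLIR code and not part of the PR. The names `classify_dims` and `contract`, the dict-based tensor representation, and the choice to reject dims that match neither rule are all assumptions of this sketch.

```python
from itertools import product


def classify_dims(a_dims, b_dims, c_dims):
    """Infer each dim's iteration type from the multi-indices of A, B and C."""
    kinds = {}
    # dict.fromkeys de-duplicates while keeping first-seen order.
    for d in dict.fromkeys((*a_dims, *b_dims, *c_dims)):
        in_a, in_b, in_c = d in a_dims, d in b_dims, d in c_dims
        if in_a and in_b and not in_c:
            kinds[d] = "reduction"  # contracted, i.e. summed over
        elif in_c and (in_a or in_b):
            kinds[d] = "parallel"   # M-like, N-like or batch dim
        else:
            # Assumption of this sketch: anything else is rejected.
            raise ValueError(f"dim {d!r} matches neither rule")
    return kinds


def contract(a, b, c, a_dims, b_dims, c_dims, sizes):
    """Reference D[H] = (SUM A[I] * B[J]) + C[H].

    Tensors are dicts keyed by index tuples (hypothetical representation);
    `sizes` maps each dim name to its extent.
    """
    dims = list(classify_dims(a_dims, b_dims, c_dims))
    d = dict(c)  # start from the accumulator C
    for point in product(*(range(sizes[x]) for x in dims)):
        env = dict(zip(dims, point))
        key = lambda names: tuple(env[n] for n in names)
        d[key(c_dims)] += a[key(a_dims)] * b[key(b_dims)]
    return d


# Batch-matmul maps from the description: k is the only reduction dim.
print(classify_dims(("b", "m", "k"), ("b", "k", "n"), ("b", "m", "n")))
```

Iterating over the full iteration space and accumulating into `D[H]` sums over every dim that is absent from `H`, which is what the formula asks for; a dim left out of one input simply does not index that input, giving the broadcast behaviour the description mentions.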

https://github.com/llvm/llvm-project/pull/123618

