[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)
Rolf Morel
llvmlistbot at llvm.org
Tue Jan 21 14:56:27 PST 2025
================
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+ let summary = [{
+ Perform a contraction on two inputs, accumulating on top of a third.
+ }];
+ let description = [{
+ The semantics of contracting inputs `A` and `B` on top of `C` to produce
+ output `D` is given by
+
+ `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+ where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+ dimension identifiers (meant to range over valid indices), corresponding to
+ the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+ and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+ the dimensions in the set `dims`.
+
+ The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+ domain of each of the `affine_map`s. Like for einsums, the iteration type of
+ each dim is inferred and is either:
+
+ - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+ Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+ - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+ deriving from matmul terminology - is either an "M-like" dim (if in `A`
+ and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+ `B`, and `C`).
+
+ For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+ `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+ `n` and `b` are of parallel iteration-type) and gets represented as:
+
+ ```
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ ```
+
+ Note that by permuting the dims in the co-domains of the `affine_map`s, we
+ can apply arbitrary transposes to the inputs and output. Similarly,
+ arbitrary broadcasts can be achieved through leaving out dims on either
+ input operand.
+
+ Numeric casting is performed on the operands to the inner multiplication,
+ promoting them to the same data type as the accumulator/output.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ AffineMapArrayAttr:$indexing_maps
+ );
+ let results = (outs Variadic<AnyShaped>:$result_tensors);
+ let regions = (region SizedRegion<1>:$combiner);
----------------
rolfmorel wrote:
Thanks both - I agree it would be nicer for linalg ops to not surreptitiously have a region attached.
Having said that, I looked at `LinalgStructuredInterface`, where we find methods like:
```
InterfaceMethod<
/*desc=*/[{
Return the single block constituting the body of the operation by
calling the getBody method on the concrete operation.
}],
/*retTy=*/"Block*",
/*methodName=*/"getBlock",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Assume the concrete operation implements the
// SingleBlockImplicitTerminator trait.
return $_op.getBody();
}]
>
...
InterfaceMethod<
/*desc=*/[{
Return the input block arguments of the region.
}],
/*retTy=*/"Block::BlockArgListType",
/*methodName=*/"getRegionInputArgs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
}]
>
...
```
The above means that the assumption that linalg ops have a region is actually baked into the interface. (Note that there are other `LinalgStructuredInterface` methods that also make use of `getBlock()` _and_ assume that it will return a non-null pointer.)
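To make that assumption concrete, here is a toy model in plain C++. It is only a sketch: `ToyBlock`, `ToyLinalgOp`, `makeOpWithRegion` and `makeRegionlessOp` are hypothetical stand-ins and not MLIR classes or functions. It mirrors only the shape of the two interface methods quoted above, and it throws where the real default implementation would dereference a null pointer.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a block; block arguments are modelled as names.
struct ToyBlock {
  std::vector<std::string> arguments;
};

// Hypothetical stand-in for a linalg op; a null body models a regionless op.
struct ToyLinalgOp {
  std::unique_ptr<ToyBlock> body;
  std::size_t numDpsInputs = 0;

  // Same shape as the interface's getBlock(), except it may return null here.
  ToyBlock *getBlock() const { return body.get(); }

  // Same shape as getRegionInputArgs(): the first numDpsInputs block args.
  // The toy throws on a missing block instead of dereferencing null.
  std::vector<std::string> getRegionInputArgs() const {
    ToyBlock *block = getBlock();
    if (!block)
      throw std::logic_error("op has no body block");
    assert(numDpsInputs <= block->arguments.size());
    auto first = block->arguments.begin();
    auto last = first + static_cast<std::ptrdiff_t>(numDpsInputs);
    return std::vector<std::string>(first, last);
  }
};

// An op with a body: two input block arguments followed by one output.
ToyLinalgOp makeOpWithRegion() {
  ToyLinalgOp op;
  op.body = std::make_unique<ToyBlock>();
  op.body->arguments = {"in0", "in1", "out0"};
  op.numDpsInputs = 2;
  return op;
}

// A regionless op: same operand count, but no body block at all.
ToyLinalgOp makeRegionlessOp() {
  ToyLinalgOp op;
  op.numDpsInputs = 2;
  return op;
}
```

In this model, `getRegionInputArgs()` on the op from `makeOpWithRegion()` yields the two input arguments, while the same call on the op from `makeRegionlessOp()` fails. Any caller written against the interface would need a null check that it currently has no reason to carry.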
The [revision that introduced `getBlock()`](https://reviews.llvm.org/D111393) is also clear: because the last linalg ops without a region had been removed, `getBlock()` could go into the interface. That every named op has a region is also relied upon by the `generalizeNamedOp` transform, per
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp#L66-72:
```
// All named ops have a region attached that can be inlined.
assert(linalgOp->getNumRegions() == 1 &&
"expect named op to have one region attached");
GenericOp genericOp = rewriter.create<GenericOp>(
linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
genericOp.getRegion().begin());
```
As such, my take is that making any linalg op regionless requires `LinalgStructuredInterface` changes and might well cause breakage downstream by violating a long-standing assumption. Such changes require due consideration, more than a discussion in PR comments.
That still leaves open the possibility that having an attached region for `linalg.contract` can actually cause issues. My answer to that is "only if you choose to go out of your way to do it." That is, the op's custom parser will _not_ accept parsing a region. The builders for the op do _not_ take arguments that can directly affect the region (just indirectly such as causing casting to be added to the body). The only "backdoor" is to get the region's block and start to actively modify it. Per the foregoing, this is possible for all (named) linalg ops.
I already checked with @rengolin - he agrees that, since the interface has an "attached region" assumption, we should deal with that separately before pushing for regionless ops.
@MaheshRavishankar, given the above, is it okay with you if we deal with regionless linalg ops in a more principled way outside of this PR?
https://github.com/llvm/llvm-project/pull/123618