[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jun 19 01:41:18 PDT 2024


================
@@ -58,6 +68,176 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies a linalg.generic that is essentially a named op of the form:
+//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// It is possible that a linalg.generic may be implementing one of the matmul
+// variants but not in a straightforward way, or the linalg.generic's
+// per-operand affine maps may capture more semantics than is possible with a
+// named op (which has an implicit map interpreted via its name).
+//
+// But a named linalg matmul variant that was 'generalized' should be
+// convertible back to a named op here.
+//
+namespace {
+enum class IndexMatchResult {
+  Match = 0,  // identity map.
+  Transposed, // transposed map.
+  Mismatch    // none of the above.
+};
+
+// Looks at the affine map of an operand and works out if generic accesses
+// the element as identity-map, transposed, or 'can't work out'.
+// This check skips the `offset` batch indices and focuses on the matmul part.
----------------
banach-space wrote:

> Looks at the affine map of an operand 

This function doesn't take any operands as inputs ;-) 

> unsigned i, unsigned j

What are `i` and `j`? Are you assuming that the dims of A and B in A * B are `I x K` and `K x J`, respectively? If yes, please document. In various other places we would assume `N x K` and `M x K` instead, so it's very helpful when these things are made clear.
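To make the question about `i` and `j` concrete, here is a minimal, self-contained C++ sketch of the kind of classification the diff describes. It does not use MLIR: `SimpleDimMap`, `matchOperandMapSketch` and its `batchSize` parameter are hypothetical, affine maps are modelled as plain lists of dim positions, and the `I x K` / `K x J` convention with loop dims ordered (batch..., I, J, K) is an assumption chosen for illustration, not necessarily what the PR uses.

```cpp
#include <cassert>
#include <vector>

// Mirrors the enum in the diff above.
enum class IndexMatchResult { Match = 0, Transposed, Mismatch };

// Hypothetical stand-in for mlir::AffineMap: every result is assumed to be a
// plain dimension expression d_k, stored as its position k.
struct SimpleDimMap {
  unsigned numDims;
  std::vector<unsigned> results;
};

// Hypothetical sketch. Assumed convention: C = A * B with A : I x K,
// B : K x J, C : I x J, and loop dims ordered (batch..., I, J, K).
// `i` and `j` are the loop-dim positions that the operand's two non-batch
// results must use in the untransposed case:
//   A -> (i, j) = (pos of I, pos of K)
//   B -> (i, j) = (pos of K, pos of J)
//   C -> (i, j) = (pos of I, pos of J)
IndexMatchResult matchOperandMapSketch(const SimpleDimMap &map,
                                       unsigned batchSize, unsigned i,
                                       unsigned j) {
  assert(i != j && "row and column dims must differ");
  // A (batch_)matmul loop nest has the batch dims plus I, J and K; each
  // operand map has the batch results plus two matmul results.
  if (map.numDims != batchSize + 3 || map.results.size() != batchSize + 2)
    return IndexMatchResult::Mismatch;
  // Skip the leading batch results, but require them to be d0, d1, ...
  for (unsigned b = 0; b < batchSize; ++b)
    if (map.results[b] != b)
      return IndexMatchResult::Mismatch;
  unsigned row = map.results[batchSize];
  unsigned col = map.results[batchSize + 1];
  if (row == i && col == j)
    return IndexMatchResult::Match;
  if (row == j && col == i)
    return IndexMatchResult::Transposed;
  return IndexMatchResult::Mismatch;
}
```

With no batch dims, operand A of a plain matmul has the map (d0, d1, d2) -> (d0, d2), so `matchOperandMapSketch(SimpleDimMap{3, {0, 2}}, 0, 0, 2)` classifies it as `Match`, while (d0, d1, d2) -> (d2, d0) would be `Transposed`. Spelling the convention out in the comment is what answers "what are `i` and `j`" for the reader.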

Separately, it feels like this method should be able to reuse helpers from AffineMap.h. For example: https://github.com/llvm/llvm-project/blob/891ec2af45c02718c65f539cb6dad1758f079e73/mlir/include/mlir/IR/AffineMap.h#L137-L140

In particular, why aren't we verifying that `getNumDims() == getNumResults()`?
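A minimal sketch of why the dim-count check matters, again using a hypothetical dims-only model (`DimOnlyMap`, `looksLikeIdentityWithoutDimCheck` and `isIdentitySketch` are illustrative names, not MLIR API). Our understanding is that `mlir::AffineMap::isIdentity()` first rejects maps whose number of dims differs from the number of results and then requires result `r` to be `d_r`; the sketch models that behaviour rather than calling MLIR.

```cpp
#include <vector>

// Hypothetical stand-in for mlir::AffineMap whose results are all plain
// dimension expressions, stored by position (d_k -> k).
struct DimOnlyMap {
  unsigned numDims;
  std::vector<unsigned> results;
};

// Per-result check only: "result r is d_r". It silently ignores dims that
// no result uses, e.g. it accepts (d0, d1, d2) -> (d0, d1).
bool looksLikeIdentityWithoutDimCheck(const DimOnlyMap &map) {
  for (unsigned r = 0; r < map.results.size(); ++r)
    if (map.results[r] != r)
      return false;
  return true;
}

// Same per-result check plus numDims == numResults, which is how we
// understand mlir::AffineMap::isIdentity() to behave.
bool isIdentitySketch(const DimOnlyMap &map) {
  return map.numDims == map.results.size() &&
         looksLikeIdentityWithoutDimCheck(map);
}
```

Note that a full matmul operand map such as (d0, d1, d2) -> (d0, d2) has three dims but only two results, so an identity check of this kind would presumably be applied to the matmul part of the map rather than to the whole map.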

https://github.com/llvm/llvm-project/pull/95656

