[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jun 26 12:24:11 PDT 2024


================
@@ -58,6 +68,195 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+/// Identifies a linalg.generic that is essentially a named op of the form:
+//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// It is possible that a linalg.generic implements a matmul but not in a
+// straightforward way, e.g. below is a matrix multiply over some slice:
+// ```
+//  %0 = linalg.generic {
+//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
+//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
+//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
+//          iterator_types = ["parallel", "parallel", "parallel"]}
+//          ins(%A, %B : tensor<20x20x20xf32>,  tensor<20x20x20xf32>)
+//          outs(%C : tensor<20x20x20xf32>) {
+//             ^bb0(%a: f32, %b: f32, %c : f32):
+//                %mul = arith.mulf %a, %b : f32
+//                %add = arith.addf %mul, %c : f32
+//                linalg.yield %add : f32
+//       } -> tensor<20x20x20xf32>
+// ```
+// It is not possible to represent the above as a named op, e.g.
+// linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is not the same
+// as the linalg.generic above.
+namespace {
+enum class IndexMatchResult {
+  Match = 0,  // identity map.
+  Transposed, // transposed map.
+  Mismatch    // none of the above.
+};
+
+// Consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
+// check whether the index map of A is identity (match), transposed, or
+// something completely different (mismatch).
+// The naming and explanation are in terms of A, but the function effectively
+// checks the maps of all of A, B and C, i.e. <M,K>, <K,N> and <M,N>.
----------------
banach-space wrote:

I understand what you are trying to achieve here, but I still find the documentation confusing 😅 I appreciate that that's a subjective matter, but IMHO a specific example should be used to complement a generic description rather than replace one.

Here's a complete suggestion. Note, I replaced `Dim1` and `Dim2` from my original suggestion with `RowDim` and `ColDim` (IIUC, that's what these are):
```cpp
// Checks whether the input affine `map` contains two consecutive dims that can
// be interpreted as accessing a 2D matrix. It is assumed that the row and
// column dimensions are located next to each other (in this order) and start
// at `rowDimIdx` in the input map.
//
// YOUR SPECIFIC EXAMPLE WITH MATRIX A <<HERE>>
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
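  // Get the row and column dims; any leading (e.g. batch) dims are skipped.
  assert(rowDimIdx + 1 < map.getNumResults() &&
         "expected at least two results starting at rowDimIdx");
  auto exprOfRowDim = map.getResult(rowDimIdx);
  auto exprOfColDim = map.getResult(rowDimIdx + 1);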

  // They should be pure dim ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
```
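To make the three outcomes concrete without pulling in MLIR, here is a minimal standalone sketch of the same classification. It assumes a simplified stand-in for `AffineMap`: the hypothetical `SimpleMap`, a list of dim positions where `std::nullopt` marks any result that is not a plain dim (e.g. a constant). It is not the MLIR API and not code from the PR, just the suggested logic with the MLIR types stripped out.

```cpp
#include <cassert>
#include <optional>
#include <vector>

enum class IndexMatchResult { Match = 0, Transposed, Mismatch };

// Hypothetical stand-in for an AffineMap's results (assumption for this
// sketch): each entry is the position N of a plain dim expression `dN`, or
// std::nullopt for anything else (constants, compound expressions, ...).
using SimpleMap = std::vector<std::optional<unsigned>>;

// Same decision logic as the suggested matchOperandMap, minus MLIR types.
IndexMatchResult matchOperandMap(const SimpleMap &map, unsigned rowDimIdx,
                                 unsigned expectedPosOfRowDim,
                                 unsigned expectedPosOfColDim) {
  assert(rowDimIdx + 1 < map.size() && "map too short for a 2D access");
  const std::optional<unsigned> &row = map[rowDimIdx];
  const std::optional<unsigned> &col = map[rowDimIdx + 1];

  // Both results must be plain dims, not constants or compound expressions.
  if (!row || !col)
    return IndexMatchResult::Mismatch;

  if (*row == expectedPosOfRowDim && *col == expectedPosOfColDim)
    return IndexMatchResult::Match;

  if (*row == expectedPosOfColDim && *col == expectedPosOfRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
```

For `linalg.matmul` with loops `(d0, d1, d2) = (M, N, K)`, A is read through `(d0, d2)`, so `matchOperandMap(SimpleMap{0u, 2u}, 0, 0, 2)` yields `Match`, while the swapped `(d2, d0)` yields `Transposed`. A constant in the row slot, as in the `(3, d1, d0)` map from the slice example, yields `Mismatch` with `rowDimIdx = 0`. For `linalg.batch_matmul` with `(d0, d1, d2, d3) = (batch, M, N, K)`, the same call with `rowDimIdx = 1` and expected positions `(1, 3)` classifies A's `(d0, d1, d3)` map, without the hook knowing anything about batching.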

Feel free to re-use (and/or change). I just feel that we shouldn't be referring to `batchSize` and/or "dimension M"/"dimension K" in such a generic hook. For example, from the point of view of this hook it doesn't matter what the batch size is, nor what "M" and "K" are.
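Following that split of responsibilities, the batch/M/N/K interpretation would live in the caller instead. Below is a sketch of such a caller. `pickMatmulVariant` is a hypothetical, illustrative name, not a function from the PR. The sketch assumes that A, B and C were each classified by the generic hook, and it maps those results onto the `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?` family named in the patch's own comment.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Mirrors the IndexMatchResult enum from the patch under review.
enum class IndexMatchResult { Match = 0, Transposed, Mismatch };

// Hypothetical caller: all knowledge of batch dims and of which operand is
// A, B or C lives here, not in the generic per-operand hook.
// Returns the matching named-op name, or std::nullopt if none applies.
std::optional<std::string> pickMatmulVariant(bool hasBatchDim,
                                             IndexMatchResult a,
                                             IndexMatchResult b,
                                             IndexMatchResult c) {
  // The output C[M,N] must be accessed in plain (row, col) order.
  if (c != IndexMatchResult::Match)
    return std::nullopt;

  std::string prefix = hasBatchDim ? "linalg.batch_matmul" : "linalg.matmul";
  if (a == IndexMatchResult::Match && b == IndexMatchResult::Match)
    return prefix;
  if (a == IndexMatchResult::Transposed && b == IndexMatchResult::Match)
    return prefix + "_transpose_a";
  if (a == IndexMatchResult::Match && b == IndexMatchResult::Transposed)
    return prefix + "_transpose_b";

  // Both transposed, or any mismatch: no single named op in this family.
  return std::nullopt;
}
```

Keeping the name-building in one place also makes the "both transposed" case explicit: the pattern in the patch admits at most one of `_transpose_a`/`_transpose_b`, so the sketch rejects that combination rather than guessing.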

HTH

https://github.com/llvm/llvm-project/pull/95656

