[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)

Renato Golin llvmlistbot at llvm.org
Tue Jun 25 01:54:54 PDT 2024


================
@@ -58,6 +68,195 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+// Identifies a linalg.generic that is essentially a named op of the form:
+//    `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+//
+// A linalg.generic may implement a matmul, but not in a straightforward way;
+// e.g., below is a matrix multiply over some slice:
+// ```
+//  %0 = linalg.generic {
+//          indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
+//                           affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
+//                           affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
+//          iterator_types = ["parallel", "parallel", "parallel"]}
+//          ins(%A, %B : tensor<20x20x20xf32>,  tensor<20x20x20xf32>)
+//          outs(%C : tensor<20x20x20xf32>) {
+//             ^bb0(%a: f32, %b: f32, %c : f32):
+//                %mul = arith.mulf %a, %b : f32
+//                %add = arith.addf %mul, %c : f32
+//                linalg.yield %add : f32
+//       } -> tensor<20x20x20xf32>
+// ```
+// It is not possible to represent the above as a named op; e.g.,
+// linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
+// not the same as the linalg.generic above.
+namespace {
+enum class IndexMatchResult {
+  /// The two matrix dimensions appear in the expected order (identity map).
+  Match,
+  /// The same two dimensions appear in swapped order (transposed map).
+  Transposed,
+  /// Neither of the above.
+  Mismatch
+};
+
+// Consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. Below, we
+// check whether the index map of A is identity (match), transposed, or
+// something completely different (mismatch).
+// The naming and explanation are in terms of A, but the function
+// effectively checks the maps of all of A, B and C, i.e. <M,K>, <K,N>, <M,N>.
+//
+// The two results examined are at positions `batchSize` and `batchSize + 1`,
+// i.e. just past the leading batch results. They are a Match if they are the
+// plain dims `expectedPosOfM` and `expectedPosOfK` in that order, and
+// Transposed if they are the same two dims in swapped order. Anything else,
+// including a map with too few results, is reported as Mismatch.
+static IndexMatchResult matchOperandMap(AffineMap map, unsigned batchSize,
+                                        unsigned expectedPosOfM,
+                                        unsigned expectedPosOfK) {
+  // Bail out on maps too short to hold both matrix indices; the subscripts
+  // below would otherwise read past the end of the results.
+  auto results = map.getResults();
+  if (results.size() < batchSize + 2)
+    return IndexMatchResult::Mismatch;
+
+  // Get the matrix multiply indices. They are past the batch indices.
+  auto exprOfM = results[batchSize];
+  auto exprOfK = results[batchSize + 1];
+
+  // They should be pure dim ids.
+  auto dimOfM = dyn_cast<AffineDimExpr>(exprOfM);
+  auto dimOfK = dyn_cast<AffineDimExpr>(exprOfK);
+  if (!dimOfM || !dimOfK)
+    return IndexMatchResult::Mismatch;
+
+  unsigned posM = dimOfM.getPosition();
+  unsigned posK = dimOfK.getPosition();
+
+  if (posM == expectedPosOfM && posK == expectedPosOfK)
+    return IndexMatchResult::Match;
+
+  if (posM == expectedPosOfK && posK == expectedPosOfM)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+// Replaces `op` with a new `Variant` op built from the generic's first two
+// DPS inputs and its first DPS init, and returns the new op. Every variant
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?` has two inputs and
+// one init, so this single builder call serves all of them.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  auto inputs = op.getDpsInputs();
+  auto lhs = inputs[0];
+  auto rhs = inputs[1];
+  auto init = op.getDpsInits()[0];
+  return rewriter.replaceOpWithNewOp<Variant>(op, ValueRange{lhs, rhs},
+                                              ValueRange{init});
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+    return failure();
+
+  // Linalg generic contraction can be across multiple axis but for matmul
+  // variants it must be one.
+  if (genericOp.getNumReductionLoops() != 1)
+    return failure();
+
+  // Must be projected permutations.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(mapRange,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return failure();
+
+  //  matmul contractions are of the form:
+  //  %0 = <elemwise>(permutation-of(cu(block-argument-0),
+  //                                 cu(block-argument-1)))
+  //  %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+  //
+  //  where <elemwise> and <reduce> are binary operations constituting a
+  //  contraction (in the canonical case, <elemwise> is a multiplication and
+  //  <reduce> is an addition). All operands of all operations may be supplied
+  //  through a chain of side effect-free unary operations, such as casts,
+  //  which is denoted as `cu` above.
+  if (!mlir::linalg::detail::isContractionBody(
----------------
rengolin wrote:

Does this check that the first and second ops are linked in a chain?

I.e., that at least one operand in the second is the return value of the first?

And that the return value of the second is the actual `yield`ed value?

https://github.com/llvm/llvm-project/pull/95656


More information about the Mlir-commits mailing list