[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Jun 17 04:41:25 PDT 2024


================
@@ -58,6 +68,175 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
   return swapped;
 }
 
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+// Identifies a linalg.generic that is essentially a named op of the form:
+//    ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// A linalg.generic may implement one of the matmul variants in a way that is
+// not straightforward to recognize, or its per-operand affine maps may
+// capture more semantics than the named op can express (a named op's maps are
+// implied by its name rather than spelled out).
+//
+// However, a named linalg matmul variant that was 'generalized' should be
+// convertible back to the named op here.
+//
+namespace {
+/// Outcome of comparing the dimension order used by an operand's indexing
+/// map against the order expected for a matmul operand.
+enum class IndexMatchResult {
+  /// Dimensions appear in the expected order (identity map).
+  Match = 0,
+  /// The two dimensions appear swapped (transposed map).
+  Transposed,
+  /// Neither identity nor transposed.
+  Mismatch
+};
+
+// Looks at the affine map of an operand and works out whether the generic
+// accesses the element as an identity map, a transposed map, or neither.
+// This check skips the `offset` batch indices and focuses on the matmul part.
+//
+// Results `offset` and `offset + 1` of `m` must be plain dimension
+// expressions: (d_i, d_j) classifies as Match, (d_j, d_i) as Transposed, and
+// anything else -- including a map with fewer than `offset + 2` results --
+// as Mismatch.
+static IndexMatchResult matchOperandMap(AffineMap m, unsigned offset,
+                                        unsigned i, unsigned j) {
+  // Do not index past the end of the result list. Callers are expected to
+  // have checked the operand rank already, but this helper must not rely on
+  // that for memory safety.
+  if (m.getNumResults() < offset + 2)
+    return IndexMatchResult::Mismatch;
+
+  auto exprI = dyn_cast<AffineDimExpr>(m.getResult(offset));
+  auto exprJ = dyn_cast<AffineDimExpr>(m.getResult(offset + 1));
+  if (!exprI || !exprJ)
+    return IndexMatchResult::Mismatch;
+
+  unsigned posI = exprI.getPosition();
+  unsigned posJ = exprJ.getPosition();
+
+  if (posI == i && posJ == j)
+    return IndexMatchResult::Match;
+
+  if (posI == j && posJ == i)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+/// Replaces `op` with a newly created named op of type `Variant`. The first
+/// two DPS inputs of `op` become the named op's inputs and its first DPS init
+/// becomes the output; all the variants
+/// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?` share this
+/// two-input/one-init form. Operands are indexed without bounds checks, so
+/// callers must ensure `op` has at least two inputs and one init.
+template <typename Variant>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  auto inputs = op.getDpsInputs();
+  auto init = op.getDpsInits()[0];
+  return rewriter.replaceOpWithNewOp<Variant>(
+      op, ValueRange{inputs[0], inputs[1]}, ValueRange{init});
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+    return failure();
+
+  // Linalg generic contraction can be across multiple axis but for matmul
+  // variants it must be one.
+  if (genericOp.getNumReductionLoops() != 1)
+    return failure();
+
+  // Must be projected permutations.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(mapRange,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return failure();
+
+  //  matmul contractions are of the form:
+  //  %0 = <elemwise>(permutation-of(cu(block-argument-0),
+  //                                 cu(block-argument-1)))
+  //  %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+  //
+  //  where <elemwise> and <reduce> are binary operations constituting a
+  //  contraction (in the canonical case, <elemwise> is a multiplication and
+  //  <reduce> is an addition). All operands of all operations may be supplied
+  //  through a chain of side effect-free unary operations, such as casts,
+  //  which is denoted as `cu` above.
+  if (!mlir::linalg::detail::isContractionBody(
+          *genericOp.getBlock(), [](Operation *first, Operation *second) {
+            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+              return true;
+            return false;
+          }))
+    return failure();
+
+  // Finds 2 parallel (m and n) and 1 reduction (k) dimension candidates that
+  // form a matmul subcomputation. These dimensions are such that:
+  //   1. The m dimension is involved in an outer-product along LHS
+  //      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+  //   2. The n dimension is involved in an outer-product along RHS
+  //      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+  //   3. The k dimension appears as a permutation on LHS and RHS.
+  //   4. m, n and k appear only once in any given indexing.
+  //   5. Optional batch dimensions that appear in all operands are captured.
+  auto res = inferContractionDims(genericOp);
+  assert(succeeded(res) && "unexpected failure to infer contraction dims");
+  auto dims = *res;
+
+  // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+    return failure();
+
+  // Check rank of operands
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+        return m.getResults().size() !=
+               dims.batch.size() + 2 /*two from {m,n,k}*/;
+      }))
+    return failure();
+
+  auto batchSize = dims.batch.size();
+  if (indexingMaps[0].getNumDims() != batchSize + 3) {
+  }
+  if (batchSize) {
+    // Each operand in a linalg generic contraction  could express different
+    // permutations for its batch dimension. But for named op it must be
+    // identity since separate maps are not specified.
+    if (llvm::any_of(indexingMaps, [batchSize](AffineMap m) {
+          for (unsigned i = 0; i < batchSize; ++i) {
+            auto expr = dyn_cast<AffineDimExpr>(m.getResults()[i]);
+            if (!expr || expr.getPosition() != i)
+              return true;
+          }
+          return false;
+        }))
+      return failure();
+  }
+
+  auto a = matchOperandMap(indexingMaps[0], batchSize, dims.m[0], dims.k[0]);
+  auto b = matchOperandMap(indexingMaps[1], batchSize, dims.k[0], dims.n[0]);
+  auto c = matchOperandMap(indexingMaps[2], batchSize, dims.m[0], dims.n[0]);
+
+  if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
+        return r == IndexMatchResult::Mismatch;
+      }))
+    return failure();
+
+  if (c != IndexMatchResult::Match ||
+      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
+    return failure();
+
+  /// Codegen the different matmul variants.
+  if (batchSize) {
+    if (a == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
+                                                               genericOp);
+    if (b == IndexMatchResult::Transposed)
+      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
+                                                               genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+  }
+
+  if (a == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
+  if (b == IndexMatchResult::Transposed)
+    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
----------------
adam-smnk wrote:

There's also `linalg.matmul_unsigned`. I guess its generalized form simply won't match?

https://github.com/llvm/llvm-project/pull/95656


More information about the Mlir-commits mailing list