[Mlir-commits] [mlir] [MLIR][Linalg] Add pass to convert linalg.generic back to named ops (PR #95656)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jun 27 04:18:26 PDT 2024
================
@@ -58,6 +68,181 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}
+//===----------------------------------------------------------------------===//
+// Specialize linalg generic to matmul variants.
+//===----------------------------------------------------------------------===//
+// Identifies a linalg.generic that is essentially a named op of the form:
+// ` linalg.{batch_}?matmul{_transpose_a | _transpose_b}? `
+//
+// A linalg.generic may implement a matmul, but not in a straightforward
+// way; e.g., below is a matrix multiply over some slice:
+// ```
+// %0 = linalg.generic {
+// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
+// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
+// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
+// iterator_types = ["reduction", "parallel", "parallel"]}
+// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
+// outs(%C : tensor<20x20x20xf32>) {
+// ^bb0(%a: f32, %b: f32, %c : f32):
+// %mul = arith.mulf %a, %b : f32
+// %add = arith.addf %mul, %c : f32
+// linalg.yield %add : f32
+// } -> tensor<20x20x20xf32>
+// ```
+// The above cannot be represented as a named op: e.g.
+// `linalg.batch_matmul ins(%A, %B : tensor<20x20x20xf32>, ...)` is
+// not equivalent to the linalg.generic above.
+namespace {
+/// Outcome of comparing an operand's indexing map against the layout expected
+/// for the non-transposed matmul form.
+enum class IndexMatchResult {
+  /// The indices appear in the expected order (identity map).
+  Match = 0,
+  /// The two matmul indices appear in swapped order (transposed map).
+  Transposed,
+  /// Neither of the above.
+  Mismatch,
+};
+
+// Matches the positions of the indices appearing in an operand's affine map
+// against what is expected in the non-transposed case. For example, consider
+// the A matrix in `C[M,N] = A[M,K] * B[K,N]`: below, we check whether the
+// index map of A is the identity (match), transposed, or something completely
+// different (mismatch).
+// The naming and explanation are in terms of A, but the function applies
+// equally to the maps of B and C, i.e. C<M,N>, A<M,K>, B<K,N>.
+//
+// `batchSize` is the number of leading batch results in `map`; the two
+// results immediately after them are compared against `expectedPosOfM` and
+// `expectedPosOfK`. A map with fewer than `batchSize + 2` results, or whose
+// two relevant results are not plain dimension expressions, is a `Mismatch`.
+static IndexMatchResult matchOperandMap(AffineMap map, unsigned batchSize,
+                                        unsigned expectedPosOfM,
+                                        unsigned expectedPosOfK) {
+  // The matrix-multiply indices sit right after the batch indices. Bail out
+  // instead of reading past the end of the result list.
+  if (map.getNumResults() < batchSize + 2)
+    return IndexMatchResult::Mismatch;
+
+  // Both indices must be pure dim ids (no constants or compound expressions).
+  auto dimOfM = dyn_cast<AffineDimExpr>(map.getResult(batchSize));
+  auto dimOfK = dyn_cast<AffineDimExpr>(map.getResult(batchSize + 1));
+  if (!dimOfM || !dimOfK)
+    return IndexMatchResult::Mismatch;
+
+  unsigned posM = dimOfM.getPosition();
+  unsigned posK = dimOfK.getPosition();
+
+  if (posM == expectedPosOfM && posK == expectedPosOfK)
+    return IndexMatchResult::Match;
+
+  if (posM == expectedPosOfK && posK == expectedPosOfM)
+    return IndexMatchResult::Transposed;
+
+  return IndexMatchResult::Mismatch;
+}
+
+// Replaces `op` with a newly created `NamedOpTy` (supplied as a template
+// argument) and returns the new op. Every variant matched by the pseudo
+// regular expression
+//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// takes the same operands (two inputs, one init), so a single helper can
+// stamp out all of them.
+template <typename NamedOpTy>
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+  // Forward the generic's first two DPS inputs and its first DPS init as-is.
+  auto lhs = op.getDpsInputs()[0];
+  auto rhs = op.getDpsInputs()[1];
+  auto init = op.getDpsInits()[0];
+  return rewriter.replaceOpWithNewOp<NamedOpTy>(op, ValueRange{lhs, rhs},
+                                                ValueRange{init});
+}
+
+// Converts linalg.generic to named linalg.*matmul* where possible.
+static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
+ return failure();
+
+ // Linalg generic contraction can be across multiple axis but for matmul
+ // variants it must be one.
+ if (genericOp.getNumReductionLoops() != 1)
+ return failure();
+
+ // Must be projected permutations.
+ auto mapRange = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(mapRange,
+ [](AffineMap m) { return !m.isProjectedPermutation(); }))
+ return failure();
+
+ if (!mlir::linalg::detail::isContractionBody(
+ *genericOp.getBlock(), [](Operation *first, Operation *second) {
+ if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
+ (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
+ (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
+ return true;
+ return false;
+ }))
+ return failure();
+
+ auto res = inferContractionDims(genericOp);
+ assert(succeeded(res) && "unexpected failure to infer contraction dims");
+ auto dims = *res;
+
+ // Other than `batch`, other dim sizes must be 1 for linalg.*_matmul_*.
+ // Note that linalg contraction can have more than one contraction dimension.
+ if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
+ return failure();
+
+ // Check rank of operands
+ auto indexingMaps = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
+ return m.getResults().size() !=
+ dims.batch.size() + 2 /* any two of {m,n,k} */;
+ }))
+ return failure();
+
+ auto batchSize = dims.batch.size();
----------------
banach-space wrote:
To me, "batch size" would be `2` in this example:
```mlir
linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
```
Whereas `batchSize = dims.batch.size()` is "the number of batch dims", so `1`. I suggest renaming `batchSize` to `numOfBatchDims`.
Also, would the number of batch dims ever be != 1?
https://github.com/llvm/llvm-project/pull/95656
More information about the Mlir-commits
mailing list