[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Renato Golin llvmlistbot at llvm.org
Thu Feb 6 06:29:00 PST 2025


================
@@ -3450,6 +3467,95 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+  AffineExpr exp = bcastMap.getResult(0);
+  return exp.isFunctionOfDim(0);
+}
+
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+  enum Indices { batchPos, mPos, nPos };
+  AffineExpr exp0 = outputMap.getResult(batchPos);
+  AffineExpr exp1 = outputMap.getResult(mPos);
+  AffineExpr exp2 = outputMap.getResult(nPos);
+  return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+         exp2.isFunctionOfDim(nPos);
+}
+
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap,
+                                     AffineMap defaultIndexingMap, bool isLHS) {
+  // Check the result dims are valid.
+  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected dim expression in map result.";
+
+  // Check for valid number of result dims of input maps.
+  if (opIndexingMap.getNumResults() > 3)
+    return batchMatmulOp->emitOpError()
+           << "no. of result dim expression cannot exceed 3.";
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
----------------
rengolin wrote:

I think we need to be conservative about what broadcast / transpose means for the `matmul` variants, so that we don't introduce cases that weren't supported before and change behaviour. None of these concerns were addressed by the old op.

Right now, we have enough tests to show what it means to broadcast in this particular case. So I think we can merge this now, continue the work on `batch_reduce_matmul` for the final PR in this round, and address the more general affine map semantics in operations like `contract` and `element_wise`.

https://github.com/llvm/llvm-project/pull/122275

