[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Jan 21 07:53:38 PST 2025


================
@@ -3450,6 +3467,95 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+  AffineExpr exp = bcastMap.getResult(0);
+  return exp.isFunctionOfDim(0);
+}
+
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+  enum Indices { batchPos, mPos, nPos };
+  AffineExpr exp0 = outputMap.getResult(batchPos);
+  AffineExpr exp1 = outputMap.getResult(mPos);
+  AffineExpr exp2 = outputMap.getResult(nPos);
+  return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+         exp2.isFunctionOfDim(nPos);
+}
+
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap,
+                                     AffineMap defaultIndexingMap, bool isLHS) {
+  // Check the result dims are valid.
+  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected dim expression in map result.";
+
+  // Check for valid number of result dims of input maps.
+  if (opIndexingMap.getNumResults() > 3)
+    return batchMatmulOp->emitOpError()
+           << "no. of result dim expression cannot exceed 3.";
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
----------------
banach-space wrote:

Thanks! This description would be very helpful as a comment somewhere in the code.

Also, given the two conditions that you listed above, I think you could introduce a hook, e.g. `isIndexingMapABroadcastOf(mapToCheck, mapToCompareAgainst)`, that checks both of them. In fact, isn't that already covered by `isValidResultDimExprs` + `isBroadcasted`?

So:
```cpp
static bool isIndexingMapABroadcastOf(AffineMap mapToCheck,
                                      AffineMap mapToCompareAgainst) {
  return isValidResultDimExprs(mapToCheck, mapToCompareAgainst) &&
         isBroadcasted(mapToCheck, mapToCompareAgainst);
}
```

Something along these lines (though I would change the names of `isValidResultDimExprs` and `isBroadcasted`).
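To make the two conditions concrete, here is a minimal, MLIR-free sketch of the suggested hook. It assumes that `isValidResultDimExprs` checks that every result of the explicit map also appears among the results of the default map, and that `isBroadcasted` checks that the explicit map has fewer results than the default one. `DimMap`, `resultsAreSubsetOf`, `hasFewerResults` and `checkLhsExamples` are hypothetical stand-ins: real code would operate on `mlir::AffineMap` and the helpers from the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for an AffineMap whose results are all plain dim
// expressions: results[i] == k means that result i is `dk`.
struct DimMap {
  std::vector<unsigned> results;
};

// Assumed semantics of isValidResultDimExprs: every result of the explicit
// map also appears among the results of the default map.
static bool resultsAreSubsetOf(const DimMap &explicitMap,
                               const DimMap &defaultMap) {
  const auto &def = defaultMap.results;
  return std::all_of(explicitMap.results.begin(), explicitMap.results.end(),
                     [&](unsigned d) {
                       return std::find(def.begin(), def.end(), d) != def.end();
                     });
}

// Assumed semantics of isBroadcasted: the explicit map drops at least one
// result of the default map.
static bool hasFewerResults(const DimMap &explicitMap,
                            const DimMap &defaultMap) {
  return explicitMap.results.size() < defaultMap.results.size();
}

// The suggested hook: only dims known to the default map are used, and at
// least one of them has been dropped.
static bool isIndexingMapABroadcastOf(const DimMap &mapToCheck,
                                      const DimMap &mapToCompareAgainst) {
  return resultsAreSubsetOf(mapToCheck, mapToCompareAgainst) &&
         hasFewerResults(mapToCheck, mapToCompareAgainst);
}

// Example: the batch_matmul iteration space is (d0 = batch, d1 = m, d2 = n,
// d3 = k), so the default LHS map is (d0, d1, d3).
static void checkLhsExamples() {
  const DimMap defaultLhs{{0, 1, 3}};
  // (d1, d3): the batch dim is dropped, i.e. the LHS is broadcast over it.
  assert(isIndexingMapABroadcastOf(DimMap{{1, 3}}, defaultLhs));
  // (d2, d3): d2 (n) is not an LHS dim, so this is not a valid broadcast.
  assert(!isIndexingMapABroadcastOf(DimMap{{2, 3}}, defaultLhs));
  // (d0, d1, d3): identical to the default map, so nothing is broadcast.
  assert(!isIndexingMapABroadcastOf(DimMap{{0, 1, 3}}, defaultLhs));
}
```

Keeping the two predicates separate and combining them only in the hook means each one can still be reused on its own, e.g. when the verifier needs to report which of the two conditions failed.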

https://github.com/llvm/llvm-project/pull/122275


More information about the Mlir-commits mailing list