[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 24 03:59:42 PST 2025
================
@@ -3450,6 +3467,95 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+ AffineExpr exp = bcastMap.getResult(0);
+ return exp.isFunctionOfDim(0);
+}
+
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+ enum Indices { batchPos, mPos, nPos };
+ AffineExpr exp0 = outputMap.getResult(batchPos);
+ AffineExpr exp1 = outputMap.getResult(mPos);
+ AffineExpr exp2 = outputMap.getResult(nPos);
+ return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+ exp2.isFunctionOfDim(nPos);
+}
+
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+ AffineMap opIndexingMap,
+ AffineMap defaultIndexingMap, bool isLHS) {
+ // Check the result dims are valid.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+
+ // Check for valid number of result dims of input maps.
+ if (opIndexingMap.getNumResults() > 3)
+ return batchMatmulOp->emitOpError()
+ << "no. of result dim expression cannot exceed 3.";
+
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
----------------
banach-space wrote:
Apologies, that's not what I had in mind. My bad, I should've been clearer. Also, I had misread your comment.
Based on the discussion so far, this is my understanding of what "broadcast" means here:
```cpp
// `explicitMap` is the map provided by the user; `defaultMap` is the op's default map.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return areResultExprsSubsetOf(explicitMap, defaultMap) &&
         explicitMap.getNumResults() < defaultMap.getNumResults();
}
```
That is, there are two conditions for an `explicitMap` (a map provided by the user) to be a broadcast of `defaultMap`. The second condition alone wouldn't be sufficient:
```cpp
explicitMap.getNumResults() < defaultMap.getNumResults();
```
In particular, it would return `true` for these maps:
```
explicitMap = affine_map<(d0, d1, d2, d3) -> (d0)>,
defaultMap = affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
```
But in this case, `explicitMap` is not a broadcast of `defaultMap`, right?
https://github.com/llvm/llvm-project/pull/122275