[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Tue Jan 21 05:47:47 PST 2025
================
@@ -3450,6 +3467,95 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+ AffineExpr exp = bcastMap.getResult(0);
+ return exp.isFunctionOfDim(0);
+}
+
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+ enum Indices { batchPos, mPos, nPos };
+ AffineExpr exp0 = outputMap.getResult(batchPos);
+ AffineExpr exp1 = outputMap.getResult(mPos);
+ AffineExpr exp2 = outputMap.getResult(nPos);
+ return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+ exp2.isFunctionOfDim(nPos);
+}
+
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+ AffineMap opIndexingMap,
+ AffineMap defaultIndexingMap, bool isLHS) {
+ // Check the result dims are valid.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+
+ // Check for valid number of result dims of input maps.
+ if (opIndexingMap.getNumResults() > 3)
+ return batchMatmulOp->emitOpError()
+ << "no. of result dim expression cannot exceed 3.";
+
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
----------------
shahidact wrote:
> 1. Is there a definition of what "broadcasting" means in the context of this file?
Here, broadcast indexing maps are defined relative to the corresponding default indexing maps of the given op. This keeps the check very simple: just compare the number of result dims. I think an inline check would be better.
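The following is a rough sketch of what "broadcast relative to the default map" amounts to. It is a simplified model under stated assumptions and does not use MLIR's AffineMap API. A map is represented only by the loop dims its results use, following the batch_matmul loop order (d0, d1, d2, d3) = (batch, m, n, k), so the default LHS map (d0, d1, d3) becomes {0, 1, 3}. The names DimMap, MapKind and classifyInputMap are hypothetical. The dim-membership rule is an assumed stand-in for isValidResultDimExprs, whose body is not shown in this hunk.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical simplified model (not the MLIR AffineMap API): a map is just
// the list of loop dims its results use, with the batch_matmul loops
// (d0, d1, d2, d3) = (batch, m, n, k). Default LHS map: {0, 1, 3}.
using DimMap = std::vector<unsigned>;

// NonBroadcast covers both the default map and transposed full-rank maps.
enum class MapKind { Invalid, NonBroadcast, Broadcast };

MapKind classifyInputMap(const DimMap &opMap, const DimMap &defaultMap) {
  // Assumed stand-in for isValidResultDimExprs: every result dim of the
  // op map must also be a result dim of the default map.
  for (unsigned dim : opMap)
    if (std::find(defaultMap.begin(), defaultMap.end(), dim) ==
        defaultMap.end())
      return MapKind::Invalid;

  // Mirrors the "cannot exceed 3" result-count check in verifyInputMaps.
  if (opMap.size() > 3)
    return MapKind::Invalid;

  // Broadcast is defined relative to the default map, so the check reduces
  // to comparing result counts: fewer results means dims were broadcast.
  return opMap.size() < defaultMap.size() ? MapKind::Broadcast
                                          : MapKind::NonBroadcast;
}
```

Take the default LHS map {0, 1, 3}. Under it, {1, 3} (batch dim dropped) classifies as Broadcast. {0, 3, 1} is a full-rank transposed map and classifies as NonBroadcast. {0, 2, 3} uses n, which is not an LHS dim, so it is Invalid. Because broadcast is defined against the default map, no per-dim pattern matching is needed to detect it, which is why an inline result-count comparison would suffice.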
https://github.com/llvm/llvm-project/pull/122275