[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batch_matmul' (PR #122275)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Jan 13 04:03:04 PST 2025


================
@@ -3450,6 +3467,73 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+/// Returns true when the leading result of \p bcastMap depends on loop
+/// dimension d0 (the batch dimension), i.e. the operand map keeps the batch
+/// dimension in its first position.
+/// Precondition: \p bcastMap has exactly three results (asserted).
+static bool isValidBatchDim(AffineMap bcastMap) {
+  assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+  // The same position (0) is used both as the result index and as the loop
+  // dimension the result must be a function of.
+  constexpr unsigned batchDim = 0;
+  return bcastMap.getResult(batchDim).isFunctionOfDim(batchDim);
+}
+
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions for a batch matmul. The map must have exactly three results
+/// (batch, m, n), and result i must be a function of loop dimension d_i.
+/// Returns false for a map with any other number of results.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+  enum Indices { batchPos, mPos, nPos, numOutputDims };
+  // Guard the getResult() calls below. Indexing past the last result would be
+  // out of bounds, and a map with a different arity cannot be (batch, m, n).
+  if (outputMap.getNumResults() != numOutputDims)
+    return false;
+  AffineExpr exp0 = outputMap.getResult(batchPos);
+  AffineExpr exp1 = outputMap.getResult(mPos);
+  AffineExpr exp2 = outputMap.getResult(nPos);
+  return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+         exp2.isFunctionOfDim(nPos);
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opIndexingMaps =
+      batchMatmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap, 3> defaultIndexingMaps =
+      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+  auto opIndexingMap = opIndexingMaps[opIndex];
+  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+  // Check general validity of indexing map results.
+  if (opIndex < 2) {
+    if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+      return batchMatmulOp->emitOpError()
+             << "Unexpected dim expression in map result.";
+    // Check if the requested broadcast is valid.
+    if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+      if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
+                                                   opIndex == 0)) {
+        return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+      }
+    } else {
+      // Check for valid number of result dims of input maps.
+      if (opIndexingMap.getNumResults() != 3)
+        return batchMatmulOp->emitOpError()
+               << "no. of result dim expression cannot exceed 3.";
+
+      if (!isValidBatchDim(opIndexingMap))
+        return batchMatmulOp->emitOpError()
+               << "Invalid batch dimension expression.";
+    }
+  } else {
+    // Check for valid number of result dims of output map.
+    if (opIndexingMap.getNumResults() != 3)
+      return batchMatmulOp->emitOpError()
+             << "no. of result dim expression cannot exceed 3.";
+
+    if (!isValidOutputResultDim(opIndexingMap))
+      return batchMatmulOp->emitOpError()
+             << "Invalid output map result dimension.";
----------------
adam-smnk wrote:

nit: I think these cases could be refactored a bit, especially the repeated check on the number of results (`getNumResults() != 3`).

https://github.com/llvm/llvm-project/pull/122275


More information about the Mlir-commits mailing list