[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batch_matmul' (PR #122275)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Jan 24 03:59:42 PST 2025


================
@@ -3450,6 +3469,88 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+/// Checks the general validity of the explicit indexing map `opIndexingMap`
+/// supplied for one input of `batchMatmulOp` (the LHS when `isLHS` is true,
+/// otherwise the RHS), relative to that operand's `defaultIndexingMap`.
+///
+/// On failure, emits an op error on `batchMatmulOp` and returns failure;
+/// otherwise returns success.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap,
+                                     AffineMap defaultIndexingMap, bool isLHS) {
+  // Check the result dims are valid: every result expression of the explicit
+  // map must be one of the default map's result expressions.
+  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected result dim expression (outside the set of default "
+              "result dims).";
+
+  // Check for valid number of result dims of input maps: an input map may
+  // have at most 3 results.
+  if (opIndexingMap.getNumResults() > 3)
+    return batchMatmulOp->emitOpError()
+           << "no. of result dim expressions exceeds 3.";
+
+  // Returns true if the first result expression of `map` is a function of
+  // d0, i.e. the batch dimension. NOTE(review): this calls getResult(0), so
+  // it assumes `map` has at least one result. Confirm that a zero-result map
+  // can never reach this lambda (e.g. that isBroadcasted always classifies
+  // it as a broadcast).
+  auto hasValidBatchDim = [](AffineMap map) {
+    AffineExpr batchDim = map.getResult(0);
+    return batchDim.isFunctionOfDim(0);
+  };
+
+  // Check if the requested broadcast is valid. A broadcast map is validated
+  // by the op's isValidLhsRhsBroadcastMap (which depends on the LHS/RHS
+  // side). A non-broadcast map must keep the batch dimension as its first
+  // result.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+  } else if (!hasValidBatchDim(opIndexingMap)) {
+    return batchMatmulOp->emitOpError()
+           << "Invalid batch dimension expression.";
+  }
+  return success();
+}
+
+/// Verifies the explicit indexing map `opIndexingMap` supplied for the output
+/// of `batchMatmulOp`. The map must have exactly 3 result dimensions, and
+/// result i must be a function of dimension i (as tested by
+/// AffineExpr::isFunctionOfDim) for i = 0, 1, 2.
+///
+/// On failure, emits an op error on `batchMatmulOp` and returns failure;
+/// otherwise returns success.
+static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap) {
+  // Reject output maps that do not have exactly 3 results.
+  if (opIndexingMap.getNumResults() != 3)
+    return batchMatmulOp->emitOpError()
+           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+
+  // Result i must be a function of dimension i. This rejects permuted
+  // output maps such as (d0, d2, d1).
+  auto areValidOutputResultDim = [](AffineMap outputMap) {
+    return outputMap.getResult(0).isFunctionOfDim(0) &&
+           outputMap.getResult(1).isFunctionOfDim(1) &&
+           outputMap.getResult(2).isFunctionOfDim(2);
+  };
+
+  if (!areValidOutputResultDim(opIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Invalid output map result dimension.";
+
+  return success();
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
----------------
banach-space wrote:

Something that was pointed out to me: MLIR doesn't use Doxygen-style commands (such as `\p`) in its comments. So, instead of `\p op`, please just write `op`. In fact, I've noticed that both styles are used within this PR, so please prioritise consistency.

https://github.com/llvm/llvm-project/pull/122275


More information about the Mlir-commits mailing list