[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Mon Jan 13 08:04:08 PST 2025
================
@@ -3450,6 +3467,73 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+ assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+ AffineExpr exp = bcastMap.getResult(0);
+ return exp.isFunctionOfDim(0);
+}
+
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+ enum Indices { batchPos, mPos, nPos };
+ AffineExpr exp0 = outputMap.getResult(batchPos);
+ AffineExpr exp1 = outputMap.getResult(mPos);
+ AffineExpr exp2 = outputMap.getResult(nPos);
+ return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+ exp2.isFunctionOfDim(nPos);
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+ unsigned opIndex) {
+ SmallVector<AffineMap, 3> opIndexingMaps =
+ batchMatmulOp.getIndexingMapsArray();
+ SmallVector<AffineMap, 3> defaultIndexingMaps =
+ batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+ auto opIndexingMap = opIndexingMaps[opIndex];
----------------
shahidact wrote:
If I understand correctly, we already have such a test in "invalid.mlir", shown below. Could you share the case that crashes for you?
func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
  // expected-error @+1 {{expected attribute value}}
  linalg.batch_matmul indexing_maps = [
                        ,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                      ]
                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%arg2 : memref<?x?x?xf32>)
  return
}
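
For readers following the hunk above, the two dimension checks can be sketched outside of MLIR. This is a simplified standalone model, not the MLIR API: `ToyExpr`, `ToyMap`, and the free function `isFunctionOfDim` are hypothetical stand-ins for `AffineExpr`, `AffineMap`, and `AffineExpr::isFunctionOfDim`, and each result expression is reduced to the set of dimension positions it depends on.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical stand-in for one AffineMap result expression: the set of
// input dimension positions (d0 -> 0, d1 -> 1, ...) the expression uses.
using ToyExpr = std::set<unsigned>;

// Hypothetical stand-in for an AffineMap: one ToyExpr per result.
using ToyMap = std::vector<ToyExpr>;

// Models AffineExpr::isFunctionOfDim(pos): true if the expression
// depends on the dimension at position `pos`.
inline bool isFunctionOfDim(const ToyExpr &expr, unsigned pos) {
  return expr.count(pos) != 0;
}

// Mirrors isValidBatchDim from the hunk: the first result must depend
// on the first dimension (the batch dimension).
inline bool isValidBatchDim(const ToyMap &map) {
  assert(map.size() == 3 && "Expected three result dim expr.");
  return isFunctionOfDim(map[0], 0);
}

// Mirrors isValidOutputResultDim from the hunk: result i must depend on
// dimension i for the batch, M, and N positions.
inline bool isValidOutputResultDim(const ToyMap &map) {
  assert(map.size() == 3 && "Expected three result dim expr.");
  for (unsigned pos = 0; pos < 3; ++pos)
    if (!isFunctionOfDim(map[pos], pos))
      return false;
  return true;
}
```

In this model, the output map from the test above, `(d0, d1, d2)`, becomes `{{0u}, {1u}, {2u}}` and passes both checks, while the second operand's map `(d0, d3, d2)` passes only the batch-dimension check.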
https://github.com/llvm/llvm-project/pull/122275