[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Apr 14 03:15:37 PDT 2025
================
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
print(module)
+# CHECK-LABEL: TEST: testBatchReduceMatmulOp
+ at run
+def testBatchReduceMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (5, 4, 8)
+ b_shape = (5, 8, 12)
+ b_transposed_shape = (5, 12, 8)
+ c_shape = (4, 12)
+
+ dimBatch = ir.AffineDimExpr.get(0)
+ dimM = ir.AffineDimExpr.get(1)
+ dimN = ir.AffineDimExpr.get(2)
+ dimK = ir.AffineDimExpr.get(3)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+ b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+ c_map = ir.AffineMap.get(4, 0, [dimM, dimN])
+
+ # CHECK: func.func @batch_reduce_matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<5x12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def batch_reduce_matmul_op(
+ A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+ ):
+ # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.batch_reduce_matmul(A, B, outs=(C,))
+
+ # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
----------------
nicolasvasilache wrote:
@makslevental is your concern about the direction in general or about the lack of default "no-indexing map == canonical identity" APIs?
If it is the latter, this is an easy fix.
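For illustration only, here is a minimal sketch of what such a default could look like. It is plain Python, independent of the MLIR bindings; `resolve_indexing_maps`, `DIMS` and `CANONICAL_MAPS` are hypothetical names, and the canonical maps are assumed to be the ones implied by the shapes in the test above (A is batch x m x k, B is batch x k x n, C is m x n).

```python
# Hypothetical sketch, not part of the MLIR Python bindings.
# Iteration dimensions of batch_reduce_matmul, in the order (d0, d1, d2, d3).
DIMS = ("batch", "m", "n", "k")

# Assumed canonical maps: A[batch, m, k], B[batch, k, n], C[m, n].
CANONICAL_MAPS = (
    ("batch", "m", "k"),
    ("batch", "k", "n"),
    ("m", "n"),
)


def resolve_indexing_maps(indexing_maps=None):
    """Return the maps to use: the canonical ones when none are given."""
    if indexing_maps is None:
        return CANONICAL_MAPS
    maps = tuple(tuple(m) for m in indexing_maps)
    if len(maps) != 3:
        raise ValueError(f"expected 3 indexing maps (A, B, C), got {len(maps)}")
    for m in maps:
        for dim in m:
            if dim not in DIMS:
                raise ValueError(f"unknown dimension {dim!r}")
    return maps


# No maps supplied: fall back to the canonical form.
default_maps = resolve_indexing_maps()

# Explicit maps for the transposed-B case exercised by the test.
transposed_b_maps = resolve_indexing_maps(
    [("batch", "m", "k"), ("batch", "n", "k"), ("m", "n")]
)
```

With a default like this, callers that want the plain form would not need to spell out any maps, and only the transposed or broadcast variants would pass `indexing_maps` explicitly.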
If it is the former, my perspective is that a bunch of folks have an itch to move matching code into the C++ verifier.
I am fine with that, as it does indeed make the IR considerably more readable with fewer generics, even if the improvement in code logic is very marginal.
It would be even more readable if the affine_map were actually pretty-printed with the names batch, m, n, k, to easily see what is what. I'd welcome a follow-up improvement to the pretty-printer (much like op results can have asm overrides), @rengolin @ftynse what say you?
While we're at pretty-printer wishlists, it would be nice to pretty-print the op in a formatted fashion. I'd love to have my named linalg ops be pretty-printed exactly like this:
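To make the idea concrete, here is a small sketch of the rendering in plain Python. It is not the MLIR printer: `format_affine_map` is a hypothetical helper, and it only handles maps whose results are bare dimensions (no symbols, no arithmetic), which is all the maps in this test need.

```python
# Hypothetical sketch of named-dimension printing; not the MLIR asm printer.
def format_affine_map(num_dims, results, names=None):
    """Render a map whose results are plain dimension indices.

    `results` lists dimension positions, e.g. [0, 2, 3] for (d0, d2, d3).
    `names` optionally replaces the default d0, d1, ... spelling.
    """
    if names is None:
        names = [f"d{i}" for i in range(num_dims)]
    if len(names) != num_dims:
        raise ValueError("need exactly one name per dimension")
    dims = ", ".join(names)
    outs = ", ".join(names[i] for i in results)
    return f"affine_map<({dims}) -> ({outs})>"


# The transposed-B map from the test, with default and with named dimensions.
print(format_affine_map(4, [0, 2, 3]))
# affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
print(format_affine_map(4, [0, 2, 3], ["batch", "m", "n", "k"]))
# affine_map<(batch, m, n, k) -> (batch, n, k)>
```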
```
func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
  linalg.batch_reduce_matmul
      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                       affine_map<(batch, m, n, k) -> (batch, n, k)>]
      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
      outs(%arg2: memref<?x?xf32>)
  return
}
```
https://github.com/llvm/llvm-project/pull/130944