[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Renato Golin
llvmlistbot at llvm.org
Mon Apr 14 04:36:37 PDT 2025
================
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
print(module)
+# CHECK-LABEL: TEST: testBatchReduceMatmulOp
+ at run
+def testBatchReduceMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (5, 4, 8)
+ b_shape = (5, 8, 12)
+ b_transposed_shape = (5, 12, 8)
+ c_shape = (4, 12)
+
+ dimBatch = ir.AffineDimExpr.get(0)
+ dimM = ir.AffineDimExpr.get(1)
+ dimN = ir.AffineDimExpr.get(2)
+ dimK = ir.AffineDimExpr.get(3)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+ b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+ c_map = ir.AffineMap.get(4, 0, [dimM, dimN])
+
+ # CHECK: func.func @batch_reduce_matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<5x12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def batch_reduce_matmul_op(
+ A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+ ):
+ # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.batch_reduce_matmul(A, B, outs=(C,))
+
+ # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
----------------
rengolin wrote:
I think the main reason is that the people implementing those wrappers (us) are not their main users, so we don't know exactly what you need. Feel free to propose changes, or just submit a PR with fix-ups on top of this one, whichever is easier for you.
https://github.com/llvm/llvm-project/pull/130944
More information about the Mlir-commits
mailing list