[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Maksim Levental
llvmlistbot at llvm.org
Mon Apr 14 05:36:08 PDT 2025
================
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
         print(module)
+# CHECK-LABEL: TEST: testBatchReduceMatmulOp
+@run
+def testBatchReduceMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (5, 4, 8)
+            b_shape = (5, 8, 12)
+            b_transposed_shape = (5, 12, 8)
+            c_shape = (4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimM, dimN])
+
+            # CHECK: func.func @batch_reduce_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[BTransmem:.*]]: memref<5x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_reduce_matmul_op(
+                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+            ):
+                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.batch_reduce_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
----------------
makslevental wrote:
> I'd welcome others that have more familiarity with the code (and Python) to make the necessary adjustments, maybe after this PR.

Ok, but these are breaking changes, are they not? Just like the last one? I couldn't just change some Python code that affected the C++ codebase and then ask people to fix the C++ just because I'm not familiar with that code, right?
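For anyone skimming the archive: the heart of the quoted test is the three affine maps that give linalg.batch_reduce_matmul its transposed-B semantics. Below is a minimal, illustrative sketch that rebuilds only those maps with the MLIR Python bindings and prints them; it assumes the bindings are importable as the `mlir` package and mirrors the quoted diff rather than any API not shown above.

# Minimal sketch: rebuild the three affine maps from the quoted test and
# print them, so the transpose-on-B indexing is easy to see.
# Assumes the MLIR Python bindings are installed and importable as `mlir`.
from mlir import ir

with ir.Context(), ir.Location.unknown():
    batch, m, n, k = (ir.AffineDimExpr.get(i) for i in range(4))
    # A is indexed (batch, m, k): (d0, d1, d2, d3) -> (d0, d1, d3)
    a_map = ir.AffineMap.get(4, 0, [batch, m, k])
    # B is passed pre-transposed, so it is indexed (batch, n, k):
    # (d0, d1, d2, d3) -> (d0, d2, d3)
    b_transposed_map = ir.AffineMap.get(4, 0, [batch, n, k])
    # C drops the batch (reduction) dimension: (d0, d1, d2, d3) -> (d1, d2)
    c_map = ir.AffineMap.get(4, 0, [m, n])
    for name, amap in (("A", a_map), ("B^T", b_transposed_map), ("C", c_map)):
        print(name, amap)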
https://github.com/llvm/llvm-project/pull/130944