[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)

Maksim Levental llvmlistbot at llvm.org
Mon Apr 14 06:03:13 PDT 2025


================
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
     print(module)
 
 
+# CHECK-LABEL: TEST: testBatchReduceMatmulOp
+@run
+def testBatchReduceMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (5, 4, 8)
+            b_shape = (5, 8, 12)
+            b_transposed_shape = (5, 12, 8)
+            c_shape = (4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimM, dimN])
+
+            # CHECK: func.func @batch_reduce_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<5x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<5x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<5x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<5x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<5x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<5x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_reduce_matmul_op(
+                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+            ):
+                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.batch_reduce_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
----------------
makslevental wrote:

Anyway, don't worry about it — I'll patch it up after you land this (and add a test).

https://github.com/llvm/llvm-project/pull/130944


More information about the Mlir-commits mailing list