[Mlir-commits] [mlir] 760ec2c - [MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (#127614)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 19 06:15:06 PST 2025


Author: Md Asghar Ahmad Shahid
Date: 2025-02-19T14:15:02Z
New Revision: 760ec2c38e0cd01c016c403301e8dc081e0fc85c

URL: https://github.com/llvm/llvm-project/commit/760ec2c38e0cd01c016c403301e8dc081e0fc85c
DIFF: https://github.com/llvm/llvm-project/commit/760ec2c38e0cd01c016c403301e8dc081e0fc85c.diff

LOG: [MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (#127614)

As linalg.batch_matmul has been moved into tablegen from OpDSL, its
derived Python wrapper no longer exists. This patch adds the required
Python wrapper.

Also refactors the BatchMatmulOp printer to make it consistent with its
parser.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/__init__.py
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6a439bfb09078..a5725d6f1507e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -858,7 +858,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     let arguments = (ins
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
     let regions = (region AnyRegion:$region);
@@ -884,9 +888,10 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
-            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+          "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         $_state.addOperands(operands);
+        $_state.addAttribute("cast", cast);
         $_state.addAttributes(attributes);
         $_state.addTypes(resultTensorTypes);
         (void)$_state.addRegion(),

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b756a67f3ba7a..42ea0e1197ef1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3951,11 +3951,18 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
   auto toType = block.getArgument(2).getType();
-  Value castValA =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
-  Value castValB =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
+  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
   Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
   Value addVal =
       helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
@@ -4004,11 +4011,6 @@ ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void BatchMatmulOp::print(OpAsmPrinter &p) {
-  SmallVector<StringRef, 3> elidedAttrs = {
-      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
-  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
-                           elidedAttrs);
-
   SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
       BatchMatmulOp::getDefaultIndexingMaps(getContext()),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -4018,6 +4020,11 @@ void BatchMatmulOp::print(OpAsmPrinter &p) {
                           [&](Attribute attr) { p.printAttribute(attr); });
     p << "]";
   }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
 }
 
 /// Verify the user defined indexing maps.

diff  --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 5cda4769d593f..c5fbb833ee399 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -149,7 +149,8 @@ def __init__(
 generic = region_op(GenericOp_, terminator=YieldOp)
 
 
-def matmul(
+def create_op(
+    op_type,
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
     init = _get_op_result_or_value(outs[0])
     result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
 
-    op = MatmulOp(
+    op = op_type(
         result_tensors=result_types,
         inputs=ins,
         outputs=[init],
@@ -172,24 +173,32 @@ def matmul(
     return op
 
 
+def matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
+
+
+def batch_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return create_op(
+        BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+    )
+
+
 def contract(
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
     indexing_maps: Sequence[AffineMapAttr],
     cast: Optional[Union[TypeFn, Attribute]] = None,
 ):
-    ins = [_get_op_result_or_value(input) for input in ins]
-    if len(outs) > 1:
-        raise ValueError(f"{outs=} must have length 1.")
-    init = _get_op_result_or_value(outs[0])
-    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
-
-    op = ContractOp(
-        result_tensors=result_types,
-        inputs=ins,
-        outputs=[init],
-        indexing_maps=indexing_maps,
-        cast=cast,
+    return create_op(
+        ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
     )
-    fill_builtin_region(op.operation)
-    return op

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 8474eeac0db5b..1bd9c8825b05e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1497,7 +1497,7 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 // CHECK-SAME:                                    %[[VAL_0:.*]]: memref<5xf32>,
 // CHECK-SAME:                                    %[[VAL_1:.*]]: memref<2x5x7xf32>,
 // CHECK-SAME:                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1520,7 +1520,7 @@ func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %ar
 // CHECK-SAME:                                              %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME:                                              %[[VAL_1:.*]]: memref<2x5x7xf32>,
 // CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1543,7 +1543,7 @@ func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<
 // CHECK-SAME:                                                    %[[VAL_0:.*]]: memref<2x3x5xf32>,
 // CHECK-SAME:                                                    %[[VAL_1:.*]]: memref<5xf32>,
 // CHECK-SAME:                                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
@@ -1566,7 +1566,7 @@ func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1:
 // CHECK-SAME:                                              %[[VAL_0:.*]]: memref<2x3x5xf32>,
 // CHECK-SAME:                                              %[[VAL_1:.*]]: memref<5x7xf32>,
 // CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
@@ -1622,7 +1622,7 @@ func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: me
 // CHECK-SAME:                                                %[[VAL_0:.*]]: memref<3x5xf32>,
 // CHECK-SAME:                                                %[[VAL_1:.*]]: memref<2x7x5xf32>,
 // CHECK-SAME:                                                %[[VAL_2:.*]]: memref<2x3x7xf32>) {
-// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           linalg.batch_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 94f8ea4faf4a8..307a88709ad52 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -466,3 +466,103 @@ def matmul_as_contract_op(
                 )
 
         print(module)
+
+
+# CHECK-LABEL: TEST: testBatchMatmulOp
+@run
+def testBatchMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (2, 4, 8)
+            b_shape = (2, 8, 12)
+            b_transposed_shape = (2, 12, 8)
+            c_shape = (2, 4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimN])
+
+            # CHECK: func.func @batch_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<2x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<2x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<2x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<2x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<2x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<2x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<2x4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<2x4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x4x8xf32>, tensor<2x8x12xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<2x4x8xf32>, tensor<2x12x8xf32>) outs(%[[C]] : tensor<2x4x12xf32>)
+                res = linalg.batch_matmul(
+                    A,
+                    Btransposed,
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul ins(%[[Amem]], %[[Bmem]] : memref<2x4x8xf32>, memref<2x8x12xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(Amem, Bmem, outs=(Cmem,))
+
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                res = linalg.BatchMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<2x4x8xf32>, memref<2x12x8xf32>) outs(%[[Cmem]] : memref<2x4x12xf32>)
+                linalg.batch_matmul(
+                    Amem,
+                    Btransposedmem,
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+    print(module)


        


More information about the Mlir-commits mailing list