[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Thu May 1 09:07:35 PDT 2025


================
@@ -1065,6 +1060,134 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 }
 
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+                          AttrSizedOperandSegments,
+                          LinalgContractionOpInterface]> {
+    
+  let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+    The partial multiplication results are reduced into a 2D output.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+    arguments if specified.
+
+    Example Transpose:
+    ```mlir
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```mlir
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (k)>,         // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```mlir
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>,        // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                         affine_map<(batch, m, n, k) -> (m, n)>]
+        ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+        outs(%arg2: memref<3x7xf32>)
+    ```
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the default batch_reduce_matmul indexing characteristic.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
----------------
shahidact wrote:

Sure, I will update the comment. I think lhs/rhs is the conventional naming for operands.

https://github.com/llvm/llvm-project/pull/130944


More information about the Mlir-commits mailing list