[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Mon Jan 20 07:47:38 PST 2025


================
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+    
+  let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast and transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+}];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion(),
+        BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns a list of AffineMap with the typical batch_matmul indexing characteristic.
----------------
shahidact wrote:

Done

https://github.com/llvm/llvm-project/pull/122275


More information about the Mlir-commits mailing list