[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Jan 15 11:53:08 PST 2025
================
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+ /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+ let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+ let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+ 'indexing_maps' as shown below. This is a list attribute, so the list must include all
+ the maps if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast and transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+}];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ /// Returns a list of AffineMap with the typical batch_matmul indexing characteristics.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ bool hasDynamicIndexingMaps() { return true; }
+ std::string getLibraryCallName();
+ /// Checks if the op has broadcast and/or transpose semantics. Returns true if the
+ /// user-defined indexing maps are not equal to the default maps.
+ bool hasUserDefinedMaps();
----------------
banach-space wrote:
The name of this hook is misleading; it should be more like `hasDefaultMaps`. I've left a similar comment here:
* https://github.com/llvm/llvm-project/pull/122753/files#r1917202562
Perhaps just leave a TODO to rename it?
https://github.com/llvm/llvm-project/pull/122275
More information about the Mlir-commits
mailing list