[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantics to linalg.batch… (PR #130944)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Wed Mar 12 04:07:28 PDT 2025
https://github.com/shahidact created https://github.com/llvm/llvm-project/pull/130944
…_reduce_matmul.
This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. It is the last in the series, following the same extension to the other two matmul variants.
Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so it must include maps for all arguments if specified.
Example Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<3x7xf32>)
```
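When 'indexing_maps' is omitted, the op falls back to its default maps (see getDefaultIndexingMaps in the patch): d0 is the reduced batch dimension, d1/d2 are the parallel m/n dimensions, and d3 is the reduction k dimension. A minimal sketch of the default form, assuming plain 2x3x5 and 2x5x7 inputs:
```
// Default form (no 'indexing_maps' attribute); the defaults are:
//   A: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // (batch, m, k)
//   B: affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // (batch, k, n)
//   C: affine_map<(d0, d1, d2, d3) -> (d1, d2)>      // (m, n)
linalg.batch_reduce_matmul
  ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
  outs(%arg2 : memref<3x7xf32>)
```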
RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275
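For reference, generalization rewrites the default form into a linalg.generic with iterator types ["reduction", "parallel", "parallel", "reduction"]. The sketch below mirrors the test case added to generalize-named-ops.mlir in this patch (the #map names are illustrative):
```
#map_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map_c = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

%0 = linalg.generic {indexing_maps = [#map_a, #map_b, #map_c],
                     iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
  ins(%arg0, %arg1 : tensor<2x3x5xf32>, tensor<2x5x7xf32>)
  outs(%arg2 : tensor<3x7xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  // Multiply-accumulate body: C += A * B.
  %mul = arith.mulf %a, %b : f32
  %add = arith.addf %c, %mul : f32
  linalg.yield %add : f32
} -> tensor<3x7xf32>
```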
From b5468fcf0c2149d006f62f536bb24f053c6778df Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 17 Feb 2025 03:05:48 -0800
Subject: [PATCH] [MLIR][Linalg] Introduce transpose/broadcast semantics to
linalg.batch_reduce_matmul.
This patch exposes broadcast and transpose semantics on
'batch_reduce_matmul'. It is the last in the series, following
the same extension to the other two matmul variants.
Broadcast and Transpose semantics can be applied by specifying the
explicit attribute 'indexing_maps' as shown below. This is a list
attribute, so it must include maps for all arguments if specified.
Example Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<3x7xf32>)
```
RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 70 -----
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 131 +++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 262 ++++++++++++++++--
.../Dialect/Linalg/generalize-named-ops.mlir | 30 ++
mlir/test/Dialect/Linalg/invalid.mlir | 202 ++++++++++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 165 +++++++++++
6 files changed, 762 insertions(+), 98 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_reduce_matmul
- cpp_class_name: BatchReduceMatmulOp
- doc: |-
- Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
- iterator_types:
- - reduction
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..5191a658bbf26 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+ 'indexing_maps' as shown below. This is a list attribute, so it must include maps for all
+ arguments if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<
+ AffineMapArrayAttr,
+ "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+ >:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ /// Returns a list of AffineMaps with the typical batch_reduce_matmul indexing characteristics.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps() { return true; };
+ /// Returns true if the user-defined indexing maps are not equal to the default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..d46fbf988d762 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -218,6 +218,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3464,19 +3481,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap,
AffineMap defaultIndexingMap, bool isLHS) {
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
// Check the result dims are valid.
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Unexpected result dim expression (outside the set of default "
"result dims).";
// Check for valid number of result dims of input maps.
if (opIndexingMap.getNumResults() > 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "no. of result dim expressions exceeds 3.";
auto hasValidBatchDim = [](AffineMap map) {
@@ -3486,60 +3508,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
// Check if the requested broadcast is valid.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+ return batchVariantMatmulOp->emitOpError()
+ << "Invalid broadcast requested.";
} else if (!hasValidBatchDim(opIndexingMap)) {
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid batch dimension expression.";
}
return success();
}
/// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap) {
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
+ if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 3) {
+
+ return batchVariantMatmulOp->emitOpError()
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
<< ").";
+ }
+ if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 2) {
+ return batchVariantMatmulOp->emitOpError()
+ << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+ << ").";
+ }
- auto areValidOutputResultDim = [](AffineMap outputMap) {
- return outputMap.getResult(0).isFunctionOfDim(0) &&
- outputMap.getResult(1).isFunctionOfDim(1) &&
- outputMap.getResult(2).isFunctionOfDim(2);
+ auto areValidOutputResultDim = [&](AffineMap outputMap) {
+ return isa<BatchMatmulOp>(batchVariantMatmulOp)
+ ? outputMap.getResult(0).isFunctionOfDim(0) &&
+ outputMap.getResult(1).isFunctionOfDim(1) &&
+ outputMap.getResult(2).isFunctionOfDim(2)
+ : outputMap.getResult(0).isFunctionOfDim(1) &&
+ outputMap.getResult(1).isFunctionOfDim(2);
};
- if (!areValidOutputResultDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
+ if (!areValidOutputResultDim(opIndexingMap)) {
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid output map result dimension.";
+ }
return success();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
- unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+ unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps =
- batchMatmulOp.getIndexingMapsArray();
+ batchVariantMatmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+ batchVariantMatmulOp.getDefaultIndexingMaps(
+ batchVariantMatmulOp->getContext());
if (opIndexingMaps.size() != 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Indexing_map attribute must have 3 affine maps.";
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
- if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+ if (opIndex == 2 &&
+ failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
return failure();
- if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
- opIndex == 0)))
+ if (opIndex != 2 &&
+ failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+ defaultIndexingMap, opIndex == 0)))
return failure();
return success();
@@ -4035,7 +4080,7 @@ LogicalResult BatchMatmulOp::verify() {
return success();
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
- if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
return failure();
}
return success();
@@ -5340,6 +5385,167 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::reduction, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+ return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+ bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ auto toType = block.getArgument(2).getType();
+ Value castValA =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+ Value castValB =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value addVal =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ yields.push_back(addVal);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = llvm::map_to_vector(
+ BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ return ::parseNamedStructuredOp(parser, result,
+ BatchReduceMatmulOp::getNumRegionArgs(),
+ BatchReduceMatmulOp::getRegionBuilder());
+}
+
+void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchReduceMatmulOp::verify() {
+ // Verification of pure batch_reduce_matmul is handled by
+ // verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void BatchReduceMatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 0ec71c35497b1..feb627e84beca 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1024,6 +1024,36 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<3x7xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK: linalg.yield %[[VAL_8]] : f32
+// CHECK: } -> tensor<3x7xf32>
+// CHECK: return %[[VAL_3]] : tensor<3x7xf32>
+// CHECK: }
+
+func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
+ %0 = linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+ outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
+ return %0 : tensor<3x7xf32>
+}
+
+// -----
+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f2283db8f89b2..978df2acc8a10 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1462,6 +1462,208 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
}
+// -----
+
+func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ ,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
// -----
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1bd9c8825b05e..93ef4f0a4a956 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1637,6 +1637,171 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_dim_A(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_dim_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_explicit_transpose_A(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x3xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_explicit_transpose_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_A_transpose_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+func.func @batch_reduce_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @batchmatmul_transpose_a
// CHECK: linalg.batch_matmul_transpose_a
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)