[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Tue Apr 29 23:51:07 PDT 2025
https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/130944
>From 2c55e73ad54111be15fe4c9c3508c91c416ae4cd Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 17 Feb 2025 03:05:48 -0800
Subject: [PATCH 1/5] [MLIR][Linalg] Introduce transpose/broadcast semantic to
linalg.batch_reduce_matmul.
This patch exposes broadcast and transpose semantics on
'batch_reduce_matmul'. This is the last of the three matmul variants,
following the same change to the other two matmul ops.
Broadcast and transpose semantics can be applied by specifying the
explicit attribute 'indexing_maps' as shown below. This is a list
attribute, so it must include maps for all arguments if specified.
Example Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<3x7xf32>)
```
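For reference, when 'indexing_maps' is omitted the op keeps its default
semantics, C[m, n] += A[b, m, k] * B[b, k, n], reducing over both the
batch dimension b and the contraction dimension k (iterator types
[reduction, parallel, parallel, reduction]). A minimal sketch of the
default form, with shapes chosen to match the default maps:
```
// Equivalent to the explicit default maps:
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,  // A: (batch, m, k)
//   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,  // B: (batch, k, n)
//   affine_map<(d0, d1, d2, d3) -> (d1, d2)>       // C: (m, n)
linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
                           outs(%arg2: memref<3x7xf32>)
```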
RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 70 -----
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 131 +++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 261 ++++++++++++++++--
.../Dialect/Linalg/generalize-named-ops.mlir | 30 ++
mlir/test/Dialect/Linalg/invalid.mlir | 202 ++++++++++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 165 +++++++++++
6 files changed, 761 insertions(+), 98 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_reduce_matmul
- cpp_class_name: BatchReduceMatmulOp
- doc: |-
- Performs a batch-reduce matrix multiplication of two 3D inputs.
- The partial multiplication results are reduced into a 2D output.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
- iterator_types:
- - reduction
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index b9edcc92e81a9..e43112cccbd39 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1065,6 +1065,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so it must include maps for all
+    arguments if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and Transpose:
+ ```
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<
+ AffineMapArrayAttr,
+ "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+ >:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, BatchReduceMatmulOp::getRegionBuilder(),
+ BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMap with the typical batch_reduce_matmul indexing characteristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps() { return true; };
+    /// Returns true if the user-defined indexing maps are not equal to the default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 089ccc6680e48..404db532454c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,6 +220,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+  // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3485,19 +3502,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap,
AffineMap defaultIndexingMap, bool isLHS) {
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
// Check the result dims are valid.
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Unexpected result dim expression (outside the set of default "
"result dims).";
// Check for valid number of result dims of input maps.
if (opIndexingMap.getNumResults() > 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "no. of result dim expressions exceeds 3.";
auto hasValidBatchDim = [](AffineMap map) {
@@ -3507,60 +3529,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
// Check if the requested broadcast is valid.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+ return batchVariantMatmulOp->emitOpError()
+ << "Invalid broadcast requested.";
} else if (!hasValidBatchDim(opIndexingMap)) {
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid batch dimension expression.";
}
return success();
}
/// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
AffineMap opIndexingMap) {
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
+ assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+ isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+ "Expected BatchMatmulOp or BatchReduceMatmulOp");
+ if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 3) {
+
+ return batchVariantMatmulOp->emitOpError()
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
<< ").";
+ }
+ if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+ opIndexingMap.getNumResults() != 2) {
+ return batchVariantMatmulOp->emitOpError()
+ << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+ << ").";
+ }
- auto areValidOutputResultDim = [](AffineMap outputMap) {
- return outputMap.getResult(0).isFunctionOfDim(0) &&
- outputMap.getResult(1).isFunctionOfDim(1) &&
- outputMap.getResult(2).isFunctionOfDim(2);
+ auto areValidOutputResultDim = [&](AffineMap outputMap) {
+ return isa<BatchMatmulOp>(batchVariantMatmulOp)
+ ? outputMap.getResult(0).isFunctionOfDim(0) &&
+ outputMap.getResult(1).isFunctionOfDim(1) &&
+ outputMap.getResult(2).isFunctionOfDim(2)
+ : outputMap.getResult(0).isFunctionOfDim(1) &&
+ outputMap.getResult(1).isFunctionOfDim(2);
};
- if (!areValidOutputResultDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
+ if (!areValidOutputResultDim(opIndexingMap)) {
+ return batchVariantMatmulOp->emitOpError()
<< "Invalid output map result dimension.";
+ }
return success();
}
/// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
- unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+ unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps =
- batchMatmulOp.getIndexingMapsArray();
+ batchVariantMatmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+ batchVariantMatmulOp.getDefaultIndexingMaps(
+ batchVariantMatmulOp->getContext());
if (opIndexingMaps.size() != 3)
- return batchMatmulOp->emitOpError()
+ return batchVariantMatmulOp->emitOpError()
<< "Indexing_map attribute must have 3 affine maps.";
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
- if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+ if (opIndex == 2 &&
+ failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
return failure();
- if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
- opIndex == 0)))
+ if (opIndex != 2 &&
+ failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+ defaultIndexingMap, opIndex == 0)))
return failure();
return success();
@@ -4045,7 +4090,7 @@ LogicalResult BatchMatmulOp::verify() {
return success();
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
- if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
return failure();
}
return success();
@@ -5366,6 +5411,166 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::reduction, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+ return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+ bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ auto toType = block.getArgument(2).getType();
+ Value castValA =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+ Value castValB =
+ helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value addVal =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ yields.push_back(addVal);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = llvm::map_to_vector(
+ BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ return ::parseNamedStructuredOp(parser, result,
+ BatchReduceMatmulOp::getNumRegionArgs(),
+ BatchReduceMatmulOp::getRegionBuilder());
+}
+
+void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+}
+
+/// Verify the user-defined indexing maps.
+LogicalResult BatchReduceMatmulOp::verify() {
+ // Verification of pure batch_reduce_matmul is handled by
+ // verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
+ if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void BatchReduceMatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 0ec71c35497b1..feb627e84beca 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1024,6 +1024,36 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<3x7xf32>) {
+// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK: linalg.yield %[[VAL_8]] : f32
+// CHECK: } -> tensor<3x7xf32>
+// CHECK: return %[[VAL_3]] : tensor<3x7xf32>
+// CHECK: }
+
+func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
+ %0 = linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+ outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
+ return %0 : tensor<3x7xf32>
+}
+
+// -----
+
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 90ceadebbc1fa..ed51e5784a713 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1484,6 +1484,208 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
}
+// -----
+
+func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ ,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
// -----
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1bd9c8825b05e..93ef4f0a4a956 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1637,6 +1637,171 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_dim_A(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_dim_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_explicit_transpose_A(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x3xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_explicit_transpose_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_A_transpose_B(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK: return
+// CHECK: }
+func.func @batch_reduce_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.batch_reduce_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @batchmatmul_transpose_a
// CHECK: linalg.batch_matmul_transpose_a
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
>From 7a7f7c3f1588fc02002671b8254b2958be7e3c51 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 2 Apr 2025 06:19:29 -0700
Subject: [PATCH 2/5] 1. Improved indentation in op definition description and
test cases.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 118 +++++------
.../Dialect/Linalg/generalize-named-ops.mlir | 26 ++-
mlir/test/Dialect/Linalg/invalid.mlir | 191 +++++++++---------
3 files changed, 159 insertions(+), 176 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e43112cccbd39..0100a91228451 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -690,34 +690,32 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
Example Transpose:
```
- linalg.matmul indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+ linalg.matmul
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
- linalg.matmul indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2)>, // broadcast
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+ linalg.matmul
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and transpose:
```
- linalg.matmul indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
- affine_map<(d0, d1, d2) -> (d2)>, // broadcast
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+ linalg.matmul
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+ affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
+ outs(%arg2: memref<3x7xf32>)
```
}];
@@ -954,35 +952,32 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
Example Transpose:
```
- linalg.batch_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
- outs(%arg2: memref<2x3x7xf32>)
+ linalg.batch_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
```
Example Broadcast:
```
- linalg.batch_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
- outs(%arg2: memref<2x3x7xf32>)
+ linalg.batch_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
```
Example Broadcast and Transpose:
```
- linalg.batch_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
- outs(%arg2: memref<2x3x7xf32>)
+ linalg.batch_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
```
}];
@@ -1074,7 +1069,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
LinalgContractionOpInterface]> {
let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
-The partial multiplication results are reduced into a 2D output.}];
+ The partial multiplication results are reduced into a 2D output.}];
let description = [{
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
@@ -1085,35 +1080,32 @@ The partial multiplication results are reduced into a 2D output.}];
Example Transpose:
```
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and Transpose:
```
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
- outs(%arg2: memref<3x7xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
```
}];
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index feb627e84beca..e15dfde02b7cc 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1024,22 +1024,20 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-LABEL: func.func @batch_reduce_matmul(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x3x5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x5x7xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
-// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<3x7xf32>) {
-// CHECK: ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
-// CHECK: %[[VAL_7:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
-// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
-// CHECK: linalg.yield %[[VAL_8]] : f32
-// CHECK: } -> tensor<3x7xf32>
-// CHECK: return %[[VAL_3]] : tensor<3x7xf32>
-// CHECK: }
+// CHECK-SAME: %[[ARG_A:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[ARG_B:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
%0 = linalg.batch_reduce_matmul indexing_maps = [
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index ed51e5784a713..a29bbb4c559b5 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1486,16 +1486,15 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
// -----
-func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
- return
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, n, k)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
}
// -----
@@ -1503,12 +1502,11 @@ func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf3
func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
- return
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
}
@@ -1518,11 +1516,10 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %a
// expected-error @+1 {{expected attribute value}}
linalg.batch_reduce_matmul indexing_maps = [
,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2 :memref<?x?xf32>)
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
return
}
@@ -1530,12 +1527,12 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %a
func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
return
}
@@ -1543,12 +1540,12 @@ func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg
func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, m)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
return
}
@@ -1556,12 +1553,12 @@ func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg
func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
return
}
@@ -1569,12 +1566,12 @@ func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memr
func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
return
}
@@ -1582,12 +1579,12 @@ func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?x
func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (k, batch)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
return
}
@@ -1595,12 +1592,12 @@ func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x
func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>)
+ outs(%arg2: memref<?x?xf32>)
return
}
@@ -1608,12 +1605,12 @@ func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1:
func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (m, batch, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
return
}
@@ -1621,12 +1618,12 @@ func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %ar
func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (n, k, batch)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
return
}
@@ -1634,56 +1631,52 @@ func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %ar
func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
- return
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
}
// -----
func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
- return
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n, k)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
}
// -----
func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1)>
- ]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
- return
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m)>]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
}
// -----
func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
- linalg.batch_reduce_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d3)>
- ]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
- return
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, k)>]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
}
// -----
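For reference, each negative test above perturbs exactly one aspect of the op's default indexing maps. A minimal sketch of the default form these tests deviate from (same named dims as the tests; illustrative only, not part of the patch):

```mlir
linalg.batch_reduce_matmul
    indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,  // A
                     affine_map<(batch, m, n, k) -> (batch, k, n)>,  // B
                     affine_map<(batch, m, n, k) -> (m, n)>]         // C
    ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
    outs(%C : memref<?x?xf32>)
```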
>From 7acedcf48c0fb8e50b0f9a1e690d52468a190740 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Sat, 12 Apr 2025 22:06:26 -0700
Subject: [PATCH 3/5] -Added python test for linalg.batch_reduce_matmul
---
mlir/python/mlir/dialects/linalg/__init__.py | 13 +++
mlir/test/python/dialects/linalg/ops.py | 101 +++++++++++++++++++
2 files changed, 114 insertions(+)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 63586a5bb8bbb..9a8a7b40e25e4 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -203,6 +203,19 @@ def batch_matmul(
)
+def batch_reduce_matmul(
+ *ins: Union[Operation, OpView, Value],
+ outs: Sequence[Union[Operation, OpView, Value]],
+ indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+ cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+ return _get_op_result_or_op_results(
+ _create_matmul_like_op(
+ BatchReduceMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+ )
+ )
+
+
def contract(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
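The new wrapper follows the same builder path as batch_matmul above. As a rough sketch (shapes borrowed from the test below), a call such as `linalg.batch_reduce_matmul(A, B, outs=(C,))` is expected to build:

```mlir
linalg.batch_reduce_matmul
    ins(%A, %B : tensor<5x4x8xf32>, tensor<5x8x12xf32>)
    outs(%C : tensor<4x12xf32>) -> tensor<4x12xf32>
```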
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..79bb576ff6738 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
print(module)
+# CHECK-LABEL: TEST: testBatchReduceMatmulOp
+ at run
+def testBatchReduceMatmulOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_shape = (5, 4, 8)
+ b_shape = (5, 8, 12)
+ b_transposed_shape = (5, 12, 8)
+ c_shape = (4, 12)
+
+ dimBatch = ir.AffineDimExpr.get(0)
+ dimM = ir.AffineDimExpr.get(1)
+ dimN = ir.AffineDimExpr.get(2)
+ dimK = ir.AffineDimExpr.get(3)
+
+ # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+ # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+ b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+ c_map = ir.AffineMap.get(4, 0, [dimM, dimN])
+
+ # CHECK: func.func @batch_reduce_matmul_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>,
+ RankedTensorType.get(a_shape, f32),
+ # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>,
+ MemRefType.get(a_shape, f32),
+ # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>,
+ RankedTensorType.get(b_shape, f32),
+ # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>,
+ MemRefType.get(b_shape, f32),
+ # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>,
+ RankedTensorType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[BTransmem:.*]]: memref<5x12x8xf32>,
+ MemRefType.get(b_transposed_shape, f32),
+ # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
+ RankedTensorType.get(c_shape, f32),
+ # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
+ MemRefType.get(c_shape, f32),
+ )
+ def batch_reduce_matmul_op(
+ A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+ ):
+ # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, B),
+ outputs=(C,),
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.batch_reduce_matmul(A, B, outs=(C,))
+
+ # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=(C.type,),
+ inputs=(A, Btransposed),
+ outputs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+ res = linalg.batch_reduce_matmul(
+ A,
+ Btransposed,
+ outs=(C,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+
+ # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Bmem),
+ outputs=(Cmem,),
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.batch_reduce_matmul(Amem, Bmem, outs=(Cmem,))
+
+ # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ res = linalg.BatchReduceMatmulOp(
+ result_tensors=[],
+ inputs=(Amem, Btransposedmem),
+ outputs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+ linalg.fill_builtin_region(res.operation)
+ # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+ linalg.batch_reduce_matmul(
+ Amem,
+ Btransposedmem,
+ outs=(Cmem,),
+ indexing_maps=[a_map, b_transposed_map, c_map],
+ )
+
+ print(module)
+
+
# CHECK-LABEL: TEST: testPackUnPackOp
@run
def testPackUnPackOp():
>From abfacb41afcc3cd3153338167f704ea5776cd720 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 28 Apr 2025 10:29:16 -0700
Subject: [PATCH 4/5] -Fixed invalid broadcast test. -Added consistent variable
and function naming in test cases. -Improved ops indexing_maps description.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 58 ++++-----
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 ++-
.../Dialect/Linalg/generalize-named-ops.mlir | 12 +-
mlir/test/Dialect/Linalg/invalid.mlir | 112 +++++++++---------
mlir/test/Dialect/Linalg/named-ops.mlir | 102 ++++++++--------
5 files changed, 158 insertions(+), 144 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 0100a91228451..a41b406e02655 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -691,9 +691,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
Example Transpose:
```
linalg.matmul
- indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>]
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
@@ -701,9 +701,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
Example Broadcast:
```
linalg.matmul
- indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, // broadcast
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>]
+ indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
@@ -711,9 +711,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
Example Broadcast and transpose:
```
linalg.matmul
- indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
- affine_map<(d0, d1, d2) -> (d2)>, // broadcast
- affine_map<(d0, d1, d2) -> (d0, d1)>]
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+ affine_map<(m, n, k) -> (k)>, // broadcast
+ affine_map<(m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
outs(%arg2: memref<3x7xf32>)
```
@@ -773,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
- /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+  /// Returns a list of AffineMap with the default matmul indexing characteristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
@@ -953,9 +953,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
Example Transpose:
```
linalg.batch_matmul
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
@@ -963,9 +963,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
Example Broadcast:
```
linalg.batch_matmul
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
@@ -973,9 +973,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
Example Broadcast and Transpose:
```
linalg.batch_matmul
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
+ affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
@@ -1081,9 +1081,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
Example Transpose:
```
linalg.batch_reduce_matmul
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
@@ -1091,9 +1091,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
Example Broadcast:
```
linalg.batch_reduce_matmul
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
@@ -1101,9 +1101,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
+ affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+ affine_map<(batch, m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<3x7xf32>)
```
@@ -1163,7 +1163,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
- /// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
+  /// Returns a list of AffineMap with the default batch_reduce_matmul indexing characteristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 404db532454c7..9dd4575240f6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3996,9 +3996,12 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
} else if (bcastMap.getNumResults() == 2) {
AffineExpr exp0 = bcastMap.getResult(0);
AffineExpr exp1 = bcastMap.getResult(1);
- isValid = isLHS
- ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
- : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ isValid =
+ isLHS
+ ? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
+ exp1.isFunctionOfDim(kPos))
+ : ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
+ (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
}
return isValid;
}
@@ -5459,9 +5462,12 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
} else if (bcastMap.getNumResults() == 2) {
AffineExpr exp0 = bcastMap.getResult(0);
AffineExpr exp1 = bcastMap.getResult(1);
- isValid = isLHS
- ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
- : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ isValid =
+ isLHS
+ ? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
+ exp1.isFunctionOfDim(kPos))
+ : ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
+ (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
}
return isValid;
}
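The relaxed check above means a rank-2 LHS broadcast map no longer has to be exactly (m, k): its leading result may now be the batch dim (broadcasting m) as well as m (broadcasting batch), provided the trailing result is k; the RHS is relaxed symmetrically. A sketch of a map the new check accepts (shapes illustrative only):

```mlir
linalg.batch_reduce_matmul
    indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,   // m broadcast on A
                     affine_map<(batch, m, n, k) -> (batch, k, n)>,
                     affine_map<(batch, m, n, k) -> (m, n)>]
    ins(%A, %B : memref<2x5xf32>, memref<2x5x7xf32>)
    outs(%C : memref<3x7xf32>)
```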
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e15dfde02b7cc..ae07b1b82228c 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1029,9 +1029,9 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-LABEL: func.func @batch_reduce_matmul(
-// CHECK-SAME: %[[ARG_A:.*]]: tensor<2x3x5xf32>,
-// CHECK-SAME: %[[ARG_B:.*]]: tensor<2x5x7xf32>,
-// CHECK-SAME: %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK-SAME: %[[A:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
@@ -1039,14 +1039,14 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
// CHECK: arith.addf
// CHECK: linalg.yield
-func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
+func.func @batch_reduce_matmul(%A: tensor<2x3x5xf32>, %B: tensor<2x5x7xf32>, %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
%0 = linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
- outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
+ ins(%A, %B: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+ outs(%C: tensor<3x7xf32>) -> tensor<3x7xf32>
return %0 : tensor<3x7xf32>
}
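For readers cross-checking the CHECK lines, the generalized form this test expects looks roughly like the following (a sketch reconstructed from the CHECK directives, not verbatim test output):

```mlir
#mapA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#mapB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#mapC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
%0 = linalg.generic
       {indexing_maps = [#mapA, #mapB, #mapC],
        iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
       ins(%A, %B : tensor<2x3x5xf32>, tensor<2x5x7xf32>)
       outs(%C : tensor<3x7xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %1 = arith.mulf %a, %b : f32
    %2 = arith.addf %c, %1 : f32
    linalg.yield %2 : f32
  } -> tensor<3x7xf32>
```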
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a29bbb4c559b5..04c59777d9d7a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1364,10 +1364,10 @@ func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x
// -----
-func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+func.func @invalid_single_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
linalg.batch_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
@@ -1377,14 +1377,14 @@ func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %
// -----
-func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+func.func @invalid_single_dim_bcast_expr_batch_matmul_B(%A: memref<?x?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?x?xf32>) {
// expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?xf32>) outs(%C: memref<?x?x?xf32>)
return
}
@@ -1486,7 +1486,11 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
// -----
-func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+//===----------------------------------------------------------------------===//
+// linalg.batch_reduce_matmul
+//===----------------------------------------------------------------------===//
+
+func.func @missing_one_indexing_map(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
linalg.batch_reduce_matmul
@@ -1499,7 +1503,7 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
// -----
-func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+func.func @missing_two_indexing_map(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
linalg.batch_reduce_matmul
@@ -1512,7 +1516,7 @@ func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
// -----
-func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+func.func @missing_indexing_map(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
// expected-error @+1 {{expected attribute value}}
linalg.batch_reduce_matmul indexing_maps = [
,
@@ -1525,157 +1529,157 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %a
// -----
-func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+func.func @invalid_dim_expr_A(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
// expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2 :memref<?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C :memref<?x?xf32>)
return
}
// -----
-func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+func.func @invalid_dim_expr_B(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
// expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (batch, k, m)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2 :memref<?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C :memref<?x?xf32>)
return
}
// -----
-func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_bcast_A(%A: memref<?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid broadcast requested}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B : memref<?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_multi_dim_bcast_expr_A(%A: memref<?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid broadcast requested}}
linalg.batch_reduce_matmul
- indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,
+ indexing_maps = [affine_map<(batch, m, n, k) -> (k, batch)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B : memref<?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_multi_dim_bcast_expr_B(%A: memref<?x?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid broadcast requested}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (k, batch)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_bcast_B(%A: memref<?x?x?xf32>, %B: memref<?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid broadcast requested}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (n)>,
affine_map<(batch, m, n, k) -> (batch, m, n)>]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+func.func @invalid_batch_dim_A(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid batch dimension expression}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (m, batch, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2 :memref<?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C :memref<?x?xf32>)
return
}
// -----
-func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+func.func @invalid_batch_dim_B(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid batch dimension expression}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (n, k, batch)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2 :memref<?x?xf32>)
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C :memref<?x?xf32>)
return
}
// -----
-func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+func.func @invalid_A_map_result_num(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{no. of result dim expressions exceeds 3.}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+func.func @invalid_B_map_result_num(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{no. of result dim expressions exceeds 3.}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n, k)>,
affine_map<(batch, m, n, k) -> (m, n)>]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
+func.func @invalid_C_map_result_num(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{expects 2 dims, but got (1).}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m)>]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
// -----
-func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
+func.func @invalid_C_map_result_dim(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Invalid output map result dimension.}}
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (m, k)>]
- ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%arg2: memref<?x?xf32>)
+ ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 93ef4f0a4a956..470bc1c78640c 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1637,25 +1637,29 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre
// -----
+//===----------------------------------------------------------------------===//
+// linalg.batch_reduce_matmul
+//===----------------------------------------------------------------------===//
+
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME: %[[A:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_k_to_fill_missing_dims_A(%A: memref<5xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<5xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>)
return
}
@@ -1665,21 +1669,21 @@ func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf3
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_dim_A(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @bcast_batch_dim_A(
+// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_batch_dim_A(%A: memref<3x5xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<3x5xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>)
return
}
@@ -1689,21 +1693,21 @@ func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1:
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @bcast_batch_and_n_dim_B(
+// CHECK-SAME: %[[A:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_batch_and_n_dim_B(%A: memref<2x3x5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<2x3x5xf32>, memref<5xf32>) outs(%C: memref<3x7xf32>)
return
}
@@ -1713,21 +1717,21 @@ func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>,
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_batch_dim_B(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @bcast_batch_dim_B(
+// CHECK-SAME: %[[A:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_batch_dim_B(%A: memref<2x3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<2x3x5xf32>, memref<5x7xf32>) outs(%C: memref<3x7xf32>)
return
}
@@ -1737,20 +1741,20 @@ func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_explicit_transpose_A(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x3xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @explicit_transpose_A(
+// CHECK-SAME: %[[A:.*]]: memref<2x5x3xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @explicit_transpose_A(%A: memref<2x5x3xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>)
return
}
@@ -1760,20 +1764,20 @@ func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %a
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_explicit_transpose_B(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @explicit_transpose_B(
+// CHECK-SAME: %[[A:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+func.func @explicit_transpose_B(%A: memref<2x3x5xf32>, %B: memref<2x7x5xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%C: memref<3x7xf32>)
return
}
@@ -1783,20 +1787,20 @@ func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %a
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @batch_reduce_matmul_bcast_A_transpose_B(
-// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
-// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL: func.func @bcast_A_transpose_B(
+// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[C]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
-func.func @batch_reduce_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_A_transpose_B(%A: memref<3x5xf32>, %B: memref<2x7x5xf32>, %C: memref<3x7xf32>) {
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<3x5xf32>, memref<2x7x5xf32>) outs(%C: memref<3x7xf32>)
return
}
>From 88ce27b875343de1674ed952798942e27525f0f7 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 29 Apr 2025 23:50:27 -0700
Subject: [PATCH 5/5] -Formats example IR for the linalg.*matmul ops for syntax
highlighting.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a41b406e02655..43c23bd9d154f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -689,7 +689,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
the maps if specified.
Example Transpose:
- ```
+ ```mlir
linalg.matmul
indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
affine_map<(m, n, k) -> (k, n)>,
@@ -699,7 +699,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
```
Example Broadcast:
- ```
+ ```mlir
linalg.matmul
indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
affine_map<(m, n, k) -> (k, n)>,
@@ -709,7 +709,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
```
Example Broadcast and transpose:
- ```
+ ```mlir
linalg.matmul
indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
affine_map<(m, n, k) -> (k)>, // broadcast
@@ -951,7 +951,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
arguments if specified.
Example Transpose:
- ```
+ ```mlir
linalg.batch_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -961,7 +961,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
```
Example Broadcast:
- ```
+ ```mlir
linalg.batch_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -971,7 +971,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
```
Example Broadcast and Transpose:
- ```
+ ```mlir
linalg.batch_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
@@ -1079,7 +1079,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
arguments if specified.
Example Transpose:
- ```
+ ```mlir
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -1089,7 +1089,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
```
Example Broadcast:
- ```
+ ```mlir
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (k)>, // broadcast
affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -1099,7 +1099,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
```
Example Broadcast and Transpose:
- ```
+ ```mlir
linalg.batch_reduce_matmul
indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>, // broadcast
affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose