[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Wed Jan 22 08:00:33 PST 2025
https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/122275
>From 0f0aa7df9029db9efea13ab0b0b817d63418b6f9 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 13 Nov 2024 01:08:25 -0800
Subject: [PATCH 1/7] [MLIR][Linalg] Introduce broadcast/transpose semantic to
'linalg.batch_matmul' operation.
Goals:
1. To add syntax and semantic to 'batch_matmul' without changing any of the
existing syntax expectations for current usage. batch_matmul is still
just batch_matmul.
2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS
infra.
Scope of this patch:
To expose broadcast and transpose semantics on the 'batch_matmul'.
The broadcast and transpose semantic is as follows:
By default 'linalg.batch_matmul' behavior will remain as is.
Broadcast and Transpose semantics can be applied by specifying the
explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list
must include all the maps if specified.
Example Transpose:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
Example Broadcast:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, //broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
Example Broadcast and transpose:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, //broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 69 ------
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 124 ++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 217 ++++++++++++++++++
.../Linalg/Transforms/DropUnitDims.cpp | 3 +-
.../linalg/opdsl/ops/core_named_ops.py | 18 --
.../Dialect/Linalg/generalize-named-ops.mlir | 23 ++
mlir/test/Dialect/Linalg/invalid.mlir | 118 ++++++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 148 ++++++++++++
8 files changed, 632 insertions(+), 88 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f76955816..496a323249e852 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul
- cpp_class_name: BatchMatmulOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul_transpose_a
cpp_class_name: BatchMatmulTransposeAOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..47b871aa322309 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+ /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+ let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+ let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+
+ Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+ 'indexing_maps' as shown below.This is a list attribute, so the list must include all
+ the maps if specified.
+
+ Example Transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+
+ Example Broadcast and transpose:
+ ```
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ ```
+}];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion(),
+ BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ bool hasDynamicIndexingMaps() { return true; }
+ std::string getLibraryCallName();
+ /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+ /// user defined indexing maps are not equal to default map.
+ bool hasUserDefinedMaps();
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..868892d1e5f5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ ArrayRef<AffineMap> indexingMaps) {
+ // Initialize indexingMaps attribute, for BatchMatmulOp.
+ SmallVector<Attribute, 4> indexingMapsAttrVal;
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3450,6 +3467,46 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
+/// Checks if the given AffineMap represents a valid batch dimension.
+/// It checks if the first result dimension is a function of the first
+/// dimension.
+static bool isValidBatchDim(AffineMap bcastMap) {
+ assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+ AffineExpr exp = bcastMap.getResult(0);
+ return exp.isFunctionOfDim(0);
+}
+
+/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+ unsigned opIndex) {
+ SmallVector<AffineMap, 3> opIndexingMaps =
+ batchMatmulOp.getIndexingMapsArray();
+ SmallVector<AffineMap, 3> defaultIndexingMaps =
+ batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+ auto opIndexingMap = opIndexingMaps[opIndex];
+ auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+ // Check general validity of indexing map results.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+ if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
+ return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ }
+ } else {
+ if (!isValidBatchDim(opIndexingMap)) {
+ return batchMatmulOp->emitOpError()
+ << "Invalid batch dimension expression.";
+ }
+ }
+ return success();
+}
+
namespace mlir {
namespace linalg {
@@ -3611,5 +3668,165 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2, d3;
+ SmallVector<AffineMap> indexingMaps;
+ bindDims(context, d0, d1, d2, d3);
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+ indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
+ return indexingMaps;
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+ assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr.");
+ bool isValid = false;
+ enum Indices { batchPos, mPos, nPos, kPos };
+ if (bcastMap.getNumResults() == 1) {
+ AffineExpr exp = bcastMap.getResult(0);
+ isValid = exp.isFunctionOfDim(kPos);
+ } else if (bcastMap.getNumResults() == 2) {
+ AffineExpr exp0 = bcastMap.getResult(0);
+ AffineExpr exp1 = bcastMap.getResult(1);
+ isValid = isLHS
+ ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+ : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+ }
+ return isValid;
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(3 > 0 && block.getNumArguments() == 3 &&
+ "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = llvm::map_to_vector(
+ BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+ return ::parseNamedStructuredOp(parser, result,
+ BatchMatmulOp::getNumRegionArgs(),
+ BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+ // Verification of pure batch_matmul is handled by
+ // verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+ if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void BatchMatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfdd..a5d4c7fe9908c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -935,7 +935,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
- if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+ attr.getName() == "indexing_maps")
continue;
collapsedOp->setAttr(attr.getName(), attr.getValue());
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..040663c882a086 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
- at linalg_structured_op
-def batch_matmul(
- A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
- """Performs a batched matrix multiplication of two 3D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]
- )
-
-
@linalg_structured_op
def batch_matmul_transpose_a(
A=TensorDef(T1, Batch, S.K, S.M),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..638238b5c38a60 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1002,3 +1002,26 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>, %[[VAL_1:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_2]] : tensor<?x?x?xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+
+// -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..e14124097fe2b6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1142,3 +1142,121 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
%0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
return %0 : tensor<2x12x11x2xf32>
}
+
+// -----
+
+func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.batch_matmul indexing_maps = [
+ ,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 68aa5a85b5e0e6..75200fc2e5b1ee 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1487,6 +1487,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_m_dim_A(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_dim_A(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_dim_B(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_a
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_b
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_A_transpose_B(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @batchmatmul_transpose_a
// CHECK: linalg.batch_matmul_transpose_a
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
>From 46d7f36c87472c96b734557ef4946619cc28e7e2 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Fri, 10 Jan 2025 10:02:47 -0800
Subject: [PATCH 2/7] -Added output map verification and corresponding tests.
-Replaced the assert on the number of dim expressions with proper error
reporting and a new test case. -Fixed typos.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 53 ++++++++++++++-----
mlir/test/Dialect/Linalg/invalid.mlir | 66 ++++++++++++++++++++++--
2 files changed, 101 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 868892d1e5f5cc..db48c2baf009a8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3476,7 +3476,18 @@ static bool isValidBatchDim(AffineMap bcastMap) {
return exp.isFunctionOfDim(0);
}
-/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+ enum Indices { batchPos, mPos, nPos };
+ AffineExpr exp0 = outputMap.getResult(batchPos);
+ AffineExpr exp1 = outputMap.getResult(mPos);
+ AffineExpr exp2 = outputMap.getResult(nPos);
+ return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+ exp2.isFunctionOfDim(nPos);
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
/// opIndex.
static LogicalResult
@@ -3490,19 +3501,35 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
// Check general validity of indexing map results.
- if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
- << "Unexpected dim expression in map result.";
- // Check if the requested broadcast is valid.
- if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ if (opIndex < 2) {
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+ if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
+ opIndex == 0)) {
+ return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ }
+ } else {
+ // Check for valid number of result dims of input maps.
+ if (opIndexingMap.getNumResults() != 3)
+ return batchMatmulOp->emitOpError()
+ << "no. of result dim expression cannot exceed 3.";
+
+ if (!isValidBatchDim(opIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Invalid batch dimension expression.";
}
} else {
- if (!isValidBatchDim(opIndexingMap)) {
+ // Check for valid number of result dims of output map.
+ if (opIndexingMap.getNumResults() != 3)
return batchMatmulOp->emitOpError()
- << "Invalid batch dimension expression.";
- }
+ << "no. of result dim expression cannot exceed 3.";
+
+ if (!isValidOutputResultDim(opIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Invalid output map result dimension.";
}
return success();
}
@@ -3724,7 +3751,7 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs) {
- assert(3 > 0 && block.getNumArguments() == 3 &&
+ assert(block.getNumArguments() == 3 &&
"BatchMatmulOp regionBuilder expects 3 (>=0) args");
RegionBuilderHelper helper(b, block);
SmallVector<Value> yields;
@@ -3806,7 +3833,7 @@ LogicalResult BatchMatmulOp::verify() {
if (!hasUserDefinedMaps())
return success();
- for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+ for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e14124097fe2b6..5faba3a815b8b9 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1145,7 +1145,7 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
// -----
-func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// expected-error @+1 {{expected attribute value}}
linalg.batch_matmul indexing_maps = [
,
@@ -1159,27 +1159,27 @@ func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: te
// -----
-func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// expected-error @+1 {{Unexpected dim expression in map result}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
- ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
return
}
// -----
-func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// expected-error @+1 {{Unexpected dim expression in map result}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
- ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
return
}
@@ -1260,3 +1260,59 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: mem
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
return
}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid output map result dimension.}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?x?xf32>)
+ return
+}
>From 5670406f4b6673634110589ab5a87295a6511e42 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 15 Jan 2025 01:05:07 -0800
Subject: [PATCH 3/7] *Added checks for extended semantics and exit gracefully
 in user passes. *Added and updated test cases. *Refactored verification logic.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 80 ++++++++++++-------
.../Linalg/Transforms/BlockPackMatmul.cpp | 10 +++
.../Linalg/Transforms/DropUnitDims.cpp | 9 +++
.../Linalg/Transforms/TransposeMatmul.cpp | 8 ++
.../Dialect/Linalg/generalize-named-ops.mlir | 15 ++--
mlir/test/Dialect/Linalg/invalid.mlir | 2 +-
6 files changed, 85 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index db48c2baf009a8..aa3be4a3763e8f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3471,7 +3471,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
/// It checks if the first result dimension is a function of the first
/// dimension.
static bool isValidBatchDim(AffineMap bcastMap) {
- assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
AffineExpr exp = bcastMap.getResult(0);
return exp.isFunctionOfDim(0);
}
@@ -3487,6 +3486,48 @@ static bool isValidOutputResultDim(AffineMap outputMap) {
exp2.isFunctionOfDim(nPos);
}
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+ AffineMap opIndexingMap,
+ AffineMap defaultIndexingMap, bool isLHS) {
+ // Check the result dims are valid.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+
+ // Check for valid number of result dims of input maps.
+ if (opIndexingMap.getNumResults() > 3)
+ return batchMatmulOp->emitOpError()
+ << "no. of result dim expression cannot exceed 3.";
+
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+ if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+ return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+ } else if (!isValidBatchDim(opIndexingMap)) {
+ return batchMatmulOp->emitOpError()
+ << "Invalid batch dimension expression.";
+ }
+ return success();
+}
+
+/// This function checks if the given AffineMap for the output of a
+/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
+/// dimensions are valid.
+static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+ AffineMap opIndexingMap) {
+ if (opIndexingMap.getNumResults() != 3)
+ return batchMatmulOp->emitOpError()
+ << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
+ << ").";
+
+ if (!isValidOutputResultDim(opIndexingMap))
+ return batchMatmulOp->emitOpError()
+ << "Invalid output map result dimension.";
+
+ return success();
+}
+
/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
/// opIndex.
@@ -3500,37 +3541,14 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
- // Check general validity of indexing map results.
- if (opIndex < 2) {
- if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
- return batchMatmulOp->emitOpError()
- << "Unexpected dim expression in map result.";
- // Check if the requested broadcast is valid.
- if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
- if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
- opIndex == 0)) {
- return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
- }
- } else {
- // Check for valid number of result dims of input maps.
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
- << "no. of result dim expression cannot exceed 3.";
-
- if (!isValidBatchDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
- << "Invalid batch dimension expression.";
- }
- } else {
- // Check for valid number of result dims of output map.
- if (opIndexingMap.getNumResults() != 3)
- return batchMatmulOp->emitOpError()
- << "no. of result dim expression cannot exceed 3.";
- if (!isValidOutputResultDim(opIndexingMap))
- return batchMatmulOp->emitOpError()
- << "Invalid output map result dimension.";
- }
+ if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+ return failure();
+
+ if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
+ opIndex == 0)))
+ return failure();
+
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 57344f986480da..de05d0f7e5b832 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -138,6 +138,16 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
const ControlBlockPackMatmulFn &controlPackMatmul) {
+ // Check to not let go the batch_matmul with extended semantic, through this
+ // transform.
+ if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
+ if (batchMatmulOp->hasUserDefinedMaps()) {
+ return rewriter.notifyMatchFailure(
+ *batchMatmulOp,
+ "only batch_matmul ops with non-extended semantics are supported");
+ }
+ }
+
if (linalgOp.hasPureBufferSemantics())
return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index a5d4c7fe9908c5..a5ebe7628accda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -906,6 +906,15 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
LogicalResult matchAndRewrite(FromOpTy contractionOp,
PatternRewriter &rewriter) const override {
+ // Check to not let go the batch_matmul with extended semantic, through this
+ // transform.
+ if (std::is_same<FromOpTy, BatchMatmulOp>::value) {
+ if (contractionOp.hasUserDefinedMaps()) {
+ return rewriter.notifyMatchFailure(
+ contractionOp,
+ "only batch_matmul ops with non-extended semantics are supported");
+ }
+ }
auto loc = contractionOp.getLoc();
auto inputs = contractionOp.getDpsInputs();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 6b934f7e8157d4..8d12f8a98dbdd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -88,6 +88,14 @@ FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
bool transposeLHS) {
+ // Check to not let go the batch_matmul with extended semantic, through this
+ // transform.
+ if (batchMatmulOp.hasUserDefinedMaps()) {
+ return rewriter.notifyMatchFailure(
+ batchMatmulOp,
+ "only batch_matmul ops with non-extended semantics are supported");
+ }
+
if (!bufferization::hasTensorSemantics(batchMatmulOp))
return rewriter.notifyMatchFailure(
batchMatmulOp, "only matmul ops with tensors are supported");
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 638238b5c38a60..42eb940006b96d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1007,21 +1007,22 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-LABEL: func.func @batch_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>, %[[VAL_1:.*]]: tensor<?x?x?xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_2]] : tensor<?x?x?xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
+// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<2x3x7xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
-func.func @batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
%0 = linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
- ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
- return %0 : tensor<?x?x?xf32>
+ ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+ outs(%arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
+ return %0 : tensor<2x3x7xf32>
}
// -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 5faba3a815b8b9..5554cb082dd85f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1292,7 +1292,7 @@ func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
// -----
func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ // expected-error @+1 {{'linalg.batch_matmul' op expects 3 dims, but got (2).}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
>From 5044162a386fc8bd394b2bb2a5f447abfbe90b0f Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 15 Jan 2025 03:20:03 -0800
Subject: [PATCH 4/7] *Added logic and tests to verify the size of supplied
indexing_map attribute.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 ++++
mlir/test/Dialect/Linalg/invalid.mlir | 27 ++++++++++++++++++++++++
2 files changed, 31 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index aa3be4a3763e8f..77d29c89d1d727 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3539,6 +3539,10 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
SmallVector<AffineMap, 3> defaultIndexingMaps =
batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+ if (opIndexingMaps.size() != 3)
+ return batchMatmulOp->emitOpError()
+ << "Indexing_map attribute must have 3 affine maps.";
+
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 5554cb082dd85f..8bf2f858ff7c97 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1145,6 +1145,33 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
// -----
+func.func @indexing_map_size_mismatch_batch_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
// expected-error @+1 {{expected attribute value}}
linalg.batch_matmul indexing_maps = [
>From 0f5614bc9767db56bc64b29cf411debf1d1f4ca1 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Thu, 16 Jan 2025 06:48:49 -0800
Subject: [PATCH 5/7] *Added logic to update the indexing_map attribute for
collapsed MatmulOp. *Updated test names and comments for consistency.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 10 ++++----
.../Linalg/Transforms/DropUnitDims.cpp | 24 ++++++++++++++-----
.../Linalg/Transforms/TransposeMatmul.cpp | 5 +---
mlir/test/Dialect/Linalg/named-ops.mlir | 8 +++----
4 files changed, 28 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 47b871aa322309..cce833b31e64e8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -692,8 +692,8 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
them to the same data type as the accumulator/output.
Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
- 'indexing_maps' as shown below.This is a list attribute, so the list must include all
- the maps if specified.
+ 'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+ arguments if specified.
Example Transpose:
```
@@ -709,7 +709,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
Example Broadcast:
```
linalg.batch_matmul indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
+ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
@@ -717,7 +717,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
outs(%arg2: memref<2x3x7xf32>)
```
- Example Broadcast and transpose:
+ Example Broadcast and Transpose:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
@@ -783,7 +783,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
return regionBuilder;
}
- /// Returns a list of AffineMap with the typical batch_matmul indexing charactristic.
+ /// Returns a list with default AffineMap(s), i.e. without broadcasts and transpositions.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index a5ebe7628accda..904ad220d5551e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include <type_traits>
namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
@@ -908,11 +909,11 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
PatternRewriter &rewriter) const override {
// Check to not let go the batch_matmul with extended semantic, through this
// transform.
- if (std::is_same<FromOpTy, BatchMatmulOp>::value) {
+ if (std::is_same<FromOpTy, BatchMatmulOp>::value ||
+ std::is_same<FromOpTy, MatmulOp>::value) {
if (contractionOp.hasUserDefinedMaps()) {
return rewriter.notifyMatchFailure(
- contractionOp,
- "only batch_matmul ops with non-extended semantics are supported");
+ contractionOp, "ops with user-defined maps are not supported");
}
}
@@ -944,10 +945,21 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
- if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
- attr.getName() == "indexing_maps")
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
continue;
- collapsedOp->setAttr(attr.getName(), attr.getValue());
+
+ // Update the indexing_maps attribute for the collapsed MatmulOp.
+ if (attr.getName() == "indexing_maps" &&
+ std::is_same<FromOpTy, BatchMatmulOp>::value &&
+ std::is_same<ToOpTy, MatmulOp>::value) {
+ SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
+ MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ collapsedOp->setAttr(attr.getName(),
+ rewriter.getArrayAttr(indexingMapsAttr));
+ } else {
+ collapsedOp->setAttr(attr.getName(), attr.getValue());
+ }
}
auto results = contractionOp.getResults();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 8d12f8a98dbdd1..e624f589917d1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -88,12 +88,9 @@ FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
bool transposeLHS) {
- // Check to not let go the batch_matmul with extended semantic, through this
- // transform.
if (batchMatmulOp.hasUserDefinedMaps()) {
return rewriter.notifyMatchFailure(
- batchMatmulOp,
- "only batch_matmul ops with non-extended semantics are supported");
+ batchMatmulOp, "ops with user-defined maps are not supported");
}
if (!bufferization::hasTensorSemantics(batchMatmulOp))
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 75200fc2e5b1ee..1dcd6a9f25af59 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1580,11 +1580,11 @@ func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memre
// -----
-// CHECK-LABEL: func @batch_matmul_explicit_transpose_a
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_A
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+func.func @batch_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
@@ -1596,11 +1596,11 @@ func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: me
// -----
-// CHECK-LABEL: func @batch_matmul_explicit_transpose_b
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_B
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @batch_matmul_explicit_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
>From 7ce1a68c015a56e198d920c4c0d6a81ab89675d0 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 21 Jan 2025 23:03:05 -0800
Subject: [PATCH 6/7] *Added logic to ensure the indexing_map attribute can be
dropped for collapsed contraction op. *Refactored some tests and methods for
better naming, comments and readability.
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 5 +-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 6 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 55 ++++++++-----------
.../Linalg/Transforms/DropUnitDims.cpp | 28 ++--------
mlir/test/Dialect/Linalg/invalid.mlir | 8 +--
mlir/test/Dialect/Linalg/named-ops.mlir | 4 +-
6 files changed, 42 insertions(+), 64 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 244db23925ab3c..98a5fd278a9977 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -710,7 +710,10 @@ def LinalgStructuredInterface
>,
InterfaceMethod<
/*desc=*/[{
- Return true if the user has supplied an explicit indexing maps for this op.
+ Returns true if the user has supplied explicit indexing maps that are
+ different from default indexing maps for this op. Returns `false` otherwise.
+ Note, if the user define maps that are identical to the default maps,
+ this method returns `false`.
}],
/*retTy=*/"bool",
/*methodName=*/"hasUserDefinedMaps",
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index cce833b31e64e8..7637727e8e5e61 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -674,8 +674,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
bool hasDynamicIndexingMaps();
- /// Check if the op has broadcast and/or transpose semantic. Returns true if the
- /// user defined indexing maps are not equal to default map.
+ /// Returns true if the user defined indexing maps are not equal to default maps.
bool hasUserDefinedMaps();
}];
}
@@ -797,8 +796,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
static unsigned getNumRegionArgs();
bool hasDynamicIndexingMaps() { return true; }
std::string getLibraryCallName();
- /// Check if the op has broadcast and/or transpose semantic. Returns true if the
- /// user defined indexing maps are not equal to default map.
+ /// Returns true if the user defined indexing maps are not equal to default maps.
bool hasUserDefinedMaps();
}];
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 77d29c89d1d727..ce6367ec46619e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3423,11 +3423,10 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
return arith::ConstantOp::materialize(builder, value, type, loc);
}
-/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
-/// defaultMap.
-static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
- auto explicitRange = explictMap.getResults();
- auto defaultRange = defaultMap.getResults();
+// Returns true if the result expressions of `subMap` are a subset of those of `fullMap`.
+static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
+ auto explicitRange = subMap.getResults();
+ auto defaultRange = fullMap.getResults();
DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
llvm::set_union(explicitSet, defaultSet);
@@ -3452,7 +3451,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
// Check general validity of indexing map results.
- if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
return matmulOp->emitOpError()
<< "Unexpected dim expression in map result.";
@@ -3467,44 +3466,31 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
return success();
}
-/// Checks if the given AffineMap represents a valid batch dimension.
-/// It checks if the first result dimension is a function of the first
-/// dimension.
-static bool isValidBatchDim(AffineMap bcastMap) {
- AffineExpr exp = bcastMap.getResult(0);
- return exp.isFunctionOfDim(0);
-}
-
-/// Checks if the given AffineMap's result dimensions are valid output result
-/// dimensions.
-static bool isValidOutputResultDim(AffineMap outputMap) {
- enum Indices { batchPos, mPos, nPos };
- AffineExpr exp0 = outputMap.getResult(batchPos);
- AffineExpr exp1 = outputMap.getResult(mPos);
- AffineExpr exp2 = outputMap.getResult(nPos);
- return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
- exp2.isFunctionOfDim(nPos);
-}
-
// Check general validity of input indexing map.
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
AffineMap opIndexingMap,
AffineMap defaultIndexingMap, bool isLHS) {
// Check the result dims are valid.
- if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
return batchMatmulOp->emitOpError()
- << "Unexpected dim expression in map result.";
+ << "Unexpected result dim expression (outside the set of default "
+ "result dims).";
// Check for valid number of result dims of input maps.
if (opIndexingMap.getNumResults() > 3)
return batchMatmulOp->emitOpError()
- << "no. of result dim expression cannot exceed 3.";
+ << "no. of result dim expressions exceeds 3.";
+
+ auto hasValidBatchDim = [](AffineMap map) {
+ AffineExpr batchDim = map.getResult(0);
+ return batchDim.isFunctionOfDim(0);
+ };
// Check if the requested broadcast is valid.
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
- } else if (!isValidBatchDim(opIndexingMap)) {
+ } else if (!hasValidBatchDim(opIndexingMap)) {
return batchMatmulOp->emitOpError()
<< "Invalid batch dimension expression.";
}
@@ -3521,7 +3507,13 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
<< ").";
- if (!isValidOutputResultDim(opIndexingMap))
+ auto areValidOutputResultDim = [](AffineMap outputMap) {
+ return outputMap.getResult(0).isFunctionOfDim(0) &&
+ outputMap.getResult(1).isFunctionOfDim(1) &&
+ outputMap.getResult(2).isFunctionOfDim(2);
+ };
+
+ if (!areValidOutputResultDim(opIndexingMap))
return batchMatmulOp->emitOpError()
<< "Invalid output map result dimension.";
@@ -3755,7 +3747,8 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
- assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr.");
+ assert(bcastMap.getNumResults() < 3 &&
+ "Expected less than 3 result dim expr.");
bool isValid = false;
enum Indices { batchPos, mPos, nPos, kPos };
if (bcastMap.getNumResults() == 1) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 904ad220d5551e..efea4dea66d2e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -907,14 +907,9 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
LogicalResult matchAndRewrite(FromOpTy contractionOp,
PatternRewriter &rewriter) const override {
- // Check to not let go the batch_matmul with extended semantic, through this
- // transform.
- if (std::is_same<FromOpTy, BatchMatmulOp>::value ||
- std::is_same<FromOpTy, MatmulOp>::value) {
- if (contractionOp.hasUserDefinedMaps()) {
- return rewriter.notifyMatchFailure(
- contractionOp, "ops with user-defined maps are not supported");
- }
+ if (contractionOp.hasUserDefinedMaps()) {
+ return rewriter.notifyMatchFailure(
+ contractionOp, "ops with user-defined maps are not supported");
}
auto loc = contractionOp.getLoc();
@@ -945,21 +940,10 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
- if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+ attr.getName() == "indexing_maps")
continue;
-
- // Update the indexing_maps attribute for the collapsed MatmulOp.
- if (attr.getName() == "indexing_maps" &&
- std::is_same<FromOpTy, BatchMatmulOp>::value &&
- std::is_same<ToOpTy, MatmulOp>::value) {
- SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
- MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
- collapsedOp->setAttr(attr.getName(),
- rewriter.getArrayAttr(indexingMapsAttr));
- } else {
- collapsedOp->setAttr(attr.getName(), attr.getValue());
- }
+ collapsedOp->setAttr(attr.getName(), attr.getValue());
}
auto results = contractionOp.getResults();
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 8bf2f858ff7c97..430484c796a0be 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1187,7 +1187,7 @@ func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: me
// -----
func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- // expected-error @+1 {{Unexpected dim expression in map result}}
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
@@ -1200,7 +1200,7 @@ func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memr
// -----
func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- // expected-error @+1 {{Unexpected dim expression in map result}}
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
@@ -1291,7 +1291,7 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: mem
// -----
func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
@@ -1305,7 +1305,7 @@ func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
// -----
func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
- // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1dcd6a9f25af59..5e7e12df8775bc 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1491,14 +1491,14 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_m_dim_A(
+// CHECK-LABEL: func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK: return
// CHECK: }
-func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
>From 7daf2134a9665b81a91d5db4e70b2056dd7d470e Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 22 Jan 2025 07:59:42 -0800
Subject: [PATCH 7/7] *Added tests to check that the DropUnitDims transform is
not applied to contraction ops having user-defined indexing_maps.
---
.../Linalg/rank-reduce-contraction-ops.mlir | 31 +++++++++++++++++++
1 file changed, 31 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index ebdbe70ff46eb7..c68a6362f52c5b 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -35,6 +35,23 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr
// -----
+func.func @negative_singleton_batch_matmul_to_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+ // CHECK-LABEL: @negative_singleton_batch_matmul_to_matmul_memref
+ // CHECK-NOT: collapse_shape
+ // CHECK-NOT: linalg.matmul
+ // CHECK-NOT: expand_shape
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+ outs(%arg2 : memref<1x?x?xf32>)
+ return
+}
+
+// -----
+
func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
// CHECK-LABEL: @singleton_batch_matvec
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
@@ -135,6 +152,20 @@ func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg
// -----
+func.func @negative_matmul_to_matvec(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+ // CHECK-LABEL: @negative_matmul_to_matvec
+ // CHECK-NOT: linalg.matvec
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+ return
+}
+
+// -----
+
func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
// CHECK-LABEL: @matmul_to_vecmat
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
More information about the Mlir-commits
mailing list