[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)
Rolf Morel
llvmlistbot at llvm.org
Tue Jan 28 14:37:20 PST 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/123618
From e3716872e9b379553fa52c67cfec49392cacefad Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 19 Jan 2025 13:42:59 -0800
Subject: [PATCH 01/10] [MLIR][Linalg] Introduce linalg.contract
A new op that represents arbitrary contractions on operands of arbitrary
rank, with transposes and broadcasts specified through its indexing_maps
attribute.
Supports the expected lowerings to linalg.generic and to vector.contract.
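As a minimal sketch (mirroring the tests added below), a plain matmul is
expressed as:

```
%0 = linalg.contract
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>]
    ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
```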
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 118 +++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 224 ++++++++++++++++--
.../Dialect/Linalg/generalize-named-ops.mlir | 168 ++++++++++++-
.../generalize-named-polymorphic-ops.mlir | 50 ++++
mlir/test/Dialect/Linalg/invalid.mlir | 97 ++++++++
mlir/test/Dialect/Linalg/loops.mlir | 47 ++++
mlir/test/Dialect/Linalg/named-ops.mlir | 23 +-
mlir/test/Dialect/Linalg/roundtrip.mlir | 17 +-
mlir/test/Dialect/Linalg/tile-tensors.mlir | 46 ++++
.../Linalg/transform-op-vectorize.mlir | 33 +++
10 files changed, 789 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+ let summary = [{
+ Perform a contraction on two inputs, accumulating on top of a third.
+ }];
+ let description = [{
+ The semantics of contracting inputs `A` and `B` on top of `C` to produce
+ output `D` is given by
+
+ `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+ where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+ dimension identifiers (meant to range over valid indices), corresponding to
+ the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+ and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+ the dimensions in the set `dims`.
+
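+  For example, plain matmul has `I = ⟨ m, k ⟩`, `J = ⟨ k, n ⟩` and
+  `H = ⟨ m, n ⟩`, for which the above reads
+  `D[m, n] = (SUM_k A[m, k] * B[k, n]) + C[m, n]`.
+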
+  The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+  domain of each of the `affine_map`s. As with einsums, the iteration type of
+  each dim is inferred and is either:
+
+ - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+ Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+ - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+ deriving from matmul terminology - is either an "M-like" dim (if in `A`
+ and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+ `B`, and `C`).
+
+ For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+ `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+ `n` and `b` are of parallel iteration-type) and gets represented as:
+
+ ```
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ ```
+
+ Note that by permuting the dims in the co-domains of the `affine_map`s, we
+ can apply arbitrary transposes to the inputs and output. Similarly,
+ arbitrary broadcasts can be achieved through leaving out dims on either
+ input operand.
+
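+  For example, a batch-matmul whose `B`-operand is broadcast across the
+  `batch`-dim (a sketch of one such broadcast) is obtained by leaving `batch`
+  out of `B`'s map:
+
+  ```
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (k, n)>,
+                       affine_map<(batch, m, n, k) -> (batch, m, n)>]
+      ins(%A, %B: memref<2x3x5xf32>, memref<5x7xf32>)
+      outs(%C: memref<2x3x7xf32>)
+  ```
+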
+ Numeric casting is performed on the operands to the inner multiplication,
+ promoting them to the same data type as the accumulator/output.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ AffineMapArrayAttr:$indexing_maps
+ );
+ let results = (outs Variadic<AnyShaped>:$result_tensors);
+ let regions = (region SizedRegion<1>:$combiner);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+ outputs, attributes, regionBuilder);
+ }]>,
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+ "ArrayAttr":$indexingMaps,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("indexing_maps", indexingMaps);
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, regionBuilder);
+ }]>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare/implement functions necessary for LinalgStructuredInterface.
+ /// Infer iterator types for each dim in the domain of IndexingMaps.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// IndexingMaps always depend on the attr associated with the current
+    /// Op instance.
+    bool hasDynamicIndexingMaps() { return true; }
+    bool hasUserDefinedMaps() { return true; }
+
+ static unsigned getNumRegionArgs();
+
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ std::string getLibraryCallName() {
+ return "op_has_no_registered_library_name";
+ }
+
+ // Implement function necessary for DestinationStyleOpInterface.
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
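(Aside: a minimal C++ sketch of driving the first builder declared above;
`rewriter`, `loc`, `lhs`, `rhs` and `acc` are hypothetical names assumed to
be in scope.)

```cpp
// Build the three matmul maps (m, n, k) -> (m, k) / (k, n) / (m, n) and
// create a linalg.contract accumulating into `acc`.
MLIRContext *ctx = rewriter.getContext();
auto d = [&](unsigned i) { return getAffineDimExpr(i, ctx); };
AffineMap mapA = AffineMap::get(3, 0, {d(0), d(2)}, ctx); // A: (m, k)
AffineMap mapB = AffineMap::get(3, 0, {d(2), d(1)}, ctx); // B: (k, n)
AffineMap mapC = AffineMap::get(3, 0, {d(0), d(1)}, ctx); // C: (m, n)
auto contract = rewriter.create<linalg::ContractOp>(
    loc, TypeRange{acc.getType()}, ValueRange{lhs, rhs}, ValueRange{acc},
    rewriter.getAffineMapArrayAttr({mapA, mapB, mapC}));
```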
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..355ed2d269291c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3528,44 +3528,45 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
}
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<Attribute, 3> indexingMapsAttr;
- Attribute mapAttr;
- if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
- if (parser.parseEqual())
- return failure();
+FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+ if (parser.parseOptionalKeyword("indexing_maps"))
+ return {nullptr}; // Success in case indexing_maps was not provided.
- if (parser.parseLSquare())
+ SmallVector<Attribute> indexingMaps;
+
+ auto parseIndexingMap = [&]() -> ParseResult {
+ AffineMapAttr affineMapAttr;
+ if (parser.parseAttribute(affineMapAttr))
return failure();
+ indexingMaps.push_back(affineMapAttr);
+ return success();
+ };
- do {
- if (parser.parseAttribute(mapAttr))
- return failure();
- if (!isa<AffineMapAttr>(mapAttr)) {
- return parser.emitError(parser.getCurrentLocation(),
- "expected affine map attribute");
- }
- indexingMapsAttr.push_back(mapAttr);
+ if (parser.parseEqual() ||
+ parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+ parseIndexingMap))
+ return failure();
- if (parser.parseOptionalComma())
- break;
- } while (true);
+ return parser.getBuilder().getArrayAttr(indexingMaps);
+}
- if (parser.parseRSquare())
- return failure();
- }
- // Initialize indexingMaps, if not supplied explicitly.
- if (indexingMapsAttr.empty()) {
- indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+ if (failed(indexingMapsAttr))
+ return failure();
+
+ if (*indexingMapsAttr == nullptr) {
+ auto indexingMapAttrs = llvm::map_to_vector(
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
}
- result.addAttribute("indexing_maps",
- parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ result.addAttribute("indexing_maps", *indexingMapsAttr);
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
MatmulOp::getRegionBuilder());
}
+
void MatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3599,6 +3600,7 @@ LogicalResult MatmulOp::verify() {
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
+
void MatmulOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+ AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+ /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+ /// domains are all the same, and each implements a projected permutation.
+ /// Each dim in the domain must occur for at least one operand and is
+ /// classified as either batch, N-like, M-like, or K-like. Only the latter
+ /// corresponds to a reduction _and_ it is the only dim-kind which does not
+ /// occur for the output operand. We use this fact for fast inference:
+ // NB: In case we allow dims to occur solely for one input, the above still
+ // holds: per the einsum semantics, these are reduction dims as well.
+ auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+ for (auto result : outAffineMap.getResults()) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(result);
+ assert(dimExpr && "affine_map is a projected permutation");
+ dimsInOutput[dimExpr.getPosition()] = true;
+ }
+
+ SmallVector<utils::IteratorType> iteratorTypes;
+ for (auto dimOccursInOutput : dimsInOutput)
+ iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+ : utils::IteratorType::reduction);
+
+ return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "ContractOp regionBuilder expects 3 args");
+ RegionBuilderHelper helper(b, block);
+
+ TypeFn castSignedness = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castSignedness = attr.getValue();
+ }
+
+ // TODO: Support fields with operators besides mult & add.
+ Type outType = block.getArgument(2).getType();
+ Value lhsAtOutType =
+ helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+ Value rhsAtOutType =
+ helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+ Value productAtOutType =
+ helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+ Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+ productAtOutType);
+ helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+ FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+ if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected 'indexing_map' attribute");
+ result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+ return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+ regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ printNamedStructuredOp(
+ p, getOperation(), getInputs(), getOutputs(),
+ /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+ int iterationSpaceDims = -1;
+ // Maps iter space dim (as index) to num of occurrences in inputs and output.
+ SmallVector<size_t> inOccurrences;
+ SmallVector<size_t> outOccurrences;
+
+ auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+ bool isInput) -> LogicalResult {
+ if (iterationSpaceDims == -1) {
+ iterationSpaceDims = affineMap.getNumDims();
+ inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+ } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+ return emitError("iteration spaces of provided affine_maps differ");
+ }
+
+ if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+ if (affineMap.getNumResults() != shapedType.getRank())
+ return emitError("ranks of shaped operand and co-domain of "
+ "corresponding affine_map differ");
+ } else if (affineMap.getNumResults() != 0) {
+ return emitError("affine_map specifies shaped access while operand has "
+ "non-shaped type");
+ }
+
+ if (!affineMap.isProjectedPermutation())
+ return emitError("provided affine_map is not a projected permutation");
+
+ for (AffineExpr affineExpr : affineMap.getResults()) {
+ auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+ if (!affineDimExpr)
+ llvm_unreachable("affine_map is a projected permutation");
+
+ if (isInput)
+ inOccurrences[affineDimExpr.getPosition()] += 1;
+ else
+ outOccurrences[affineDimExpr.getPosition()] += 1;
+ }
+
+ return success();
+ };
+
+ for (auto &&[affineMap, operandType, isInput] :
+ llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+ SmallVector<bool>{true, true, false}))
+ if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+ return failure(); // NOTE: checking lambda will emit error.
+
+ bool hasContractingDim = false;
+ for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+ hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+ if (inOccCount == 0)
+ return emitError("iteration space dim not used by either input");
+
+ // NB: A dim which occurs for only one input operand and not for the output.
+ // In terms of einsum semantics, such dims have a sensible meaning -
+ // namely an additional reduction per such dim - though this can also
+ // always be expressed through an additional op. Additionally, at time
+ // of writing, vector.contract's verifier accepts these dims but many of
+ // its lowerings do not handle these kinds of dims. Hence...
+ // TODO: Remove following once we have comprehensive support for input-only
+ // reduction dims, at both the linalg- and vector-dialect levels.
+ if (inOccCount == 1 && outOccCount != 1)
+ return emitError("iter type of dim is not one of M, N, K or batch");
+ }
+
+ if (!hasContractingDim)
+ return emitError("'indexing_maps' do not specify a contracting dimension");
+
+ return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+void ContractOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..f7e570d5ce38f0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
@@ -996,9 +994,173 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @contract_batch_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<9x3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+ outs(%arg2: memref<9x3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_batch_reduce_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<9x5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<9x7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(%arg0: memref<9x5x3xf32>, %arg1: memref<9x7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<9x5x3xf32>, memref<9x7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL: func.func @contract_dot
+// CHECK-SAME: (%[[VAL_0:.*]]: memref<9xf32>, %[[VAL_1:.*]]: memref<9xf32>, %[[VAL_2:.*]]: memref<f32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memref<f32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>
+ ]
+ ins(%arg0, %arg1 : memref<9xf32>, memref<9xf32>)
+ outs(%arg2: memref<f32>)
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index c170c5be4abff9..9acb7562f96ee0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -120,6 +120,56 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
// -----
+func.func @generalize_matmul_as_contraction_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+ outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// Verify floating point extension and truncation.
+// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func.func @generalize_matmul_as_contract_with_ext_and_trunc(%arg0: tensor<24x12xf16>,
+ %arg1: tensor<12x25xf16>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf16> {
+ %0 = linalg.contract indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ]
+ ins(%arg0, %arg1 : tensor<24x12xf16>, tensor<12x25xf16>)
+ outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ %1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
+ func.return %1 : tensor<24x25xf16>
+}
+
+// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// Verify floating point extension and truncation.
+// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<24x25xf32>
+// CHECK-NEXT: %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
+
+// -----
+
func.func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
%0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..59eeb3953e548f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -529,6 +529,103 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
// -----
+func.func @invalid_indexing_maps_placement_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{custom op 'linalg.contract' expected 'indexing_map' attribute}}
+ linalg.contract ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>)
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ return
+}
+
+// -----
+
+func.func @invalid_affine_map_in_indexing_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{provided affine_map is not a projected permutation}}
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return
+}
+
+// -----
+
+func.func @differing_iteration_space_of_affine_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{iteration spaces of provided affine_maps differ}}
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return
+}
+
+// -----
+
+func.func @mismatched_ranks_affine_map_and_operand_contraction(%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return
+}
+// -----
+
+func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{affine_map specifies shaped access while operand has non-shaped type}}
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : f32, tensor<4x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return
+}
+
+// -----
+
+func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{iteration space dim not used by either input}}
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return
+}
+
+// -----
+
+func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+1 {{iter type of dim is not one of M, N, K or batch}}
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return
+}
+
+// -----
+
func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
// expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6286a11c11a21f..0a83750b81dea4 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -627,6 +627,53 @@ func.func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f
//----------------------------------------------------------------------------//
// Named ops to loops.
//----------------------------------------------------------------------------//
+func.func @batch_reduce_matmul_as_contract(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+ ]
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C : memref<?x?xf32>)
+ return
+}
+// CHECK-LABEL: @batch_reduce_matmul_as_contract
+// CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+// CHECK: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+// CHECK: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+// CHECK: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+// CHECK: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+// CHECK: scf.for %[[m:.*]] = %{{.*}} to %[[M]]
+// CHECK: scf.for %[[n:.*]] = %{{.*}} to %[[N]]
+// CHECK: scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+// CHECK: %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+// CHECK: %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+// CHECK: %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+// CHECK: %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+// CHECK: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+// CHECK: store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
+// CHECKPARALLEL-LABEL: @batch_reduce_matmul_as_contract
+// CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECKPARALLEL: %[[B:.*]] = memref.dim %[[mA]], %c0 : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[M:.*]] = memref.dim %[[mA]], %c1 : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[K:.*]] = memref.dim %[[mA]], %c2 : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[N:.*]] = memref.dim %[[mB]], %c2 : memref<?x?x?xf32>
+// CHECKPARALLEL: scf.for %[[b:.*]] = %{{.*}} to %[[B]]
+// CHECKPARALLEL: scf.parallel (%[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[M]], %[[N]]) step ({{.*}}) {
+// CHECKPARALLEL: scf.for %[[k:.*]] = %{{.*}} to %[[K]]
+// CHECKPARALLEL: %[[va:.*]] = memref.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vb:.*]] = memref.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+// CHECKPARALLEL: %[[vc:.*]] = memref.load %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+// CHECKPARALLEL: %[[inc:.*]] = arith.mulf %[[va]], %[[vb]] : f32
+// CHECKPARALLEL: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+// CHECKPARALLEL: store %[[res]], %[[mC]][%[[m]], %[[n]]] : memref<?x?xf32>
+
func.func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%C : memref<?x?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 68aa5a85b5e0e6..3c68e7d394642e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1245,7 +1245,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
@@ -1259,7 +1258,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
outs(%arg2: memref<3x7xf32>)
-
return
}
@@ -1509,6 +1507,27 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func @contract
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
+ outs(%arg2: memref<2x3x7xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @mmt4d
func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
// CHECK: %{{.+}} = linalg.mmt4d
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..99cbb6647effbe 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -277,22 +277,33 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
// -----
-
func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
%ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
- -> (tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
{
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%c3: memref<?x?x?xf32>)
+ linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%c3: memref<?x?x?xf32>)
%res1 = linalg.batch_matmul
ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
- return %res1 : tensor<?x?x?xf32>
+ %res2 = linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// CHECK-LABEL: func @named_ops
// CHECK: linalg.batch_matmul
+// CHECK: linalg.contract
// CHECK: linalg.batch_matmul
+// CHECK: linalg.contract
// -----
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 8f13c690704572..1de1863d6deb13 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -37,6 +37,52 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-NEXT: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-NEXT: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul_as_contract_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @matmul_as_contract_tensors(
+ %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTD:.*]] = linalg.contract indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
+ %0 = linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+// CHECK: return %[[TD0]] : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 3, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @matmul_tensors_with_size_zeros(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d4..2d30d62039642a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -54,6 +54,39 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @vectorize_matmul_as_contract
+// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
+func.func @vectorize_matmul_as_contract(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
+ // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
+ // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
+ %0 = linalg.contract indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ]
+ ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>)
+ outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @vectorize_copy_memref
// CHECK-SAME: %[[A:.*]]: memref<100x100xf32>,
// CHECK-SAME: %[[B:.*]]: memref<100x100xf32>
From 3d0d5b307d7326929e5a8da0942c51f4f133543e Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 20 Jan 2025 15:46:48 -0800
Subject: [PATCH 02/10] Remove extraneous line that allowed adding a region
---
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4277bd34f3946..8c9a4e56ce00a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -742,7 +742,6 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
AffineMapArrayAttr:$indexing_maps
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
- let regions = (region SizedRegion<1>:$combiner);
let skipDefaultBuilders = 1;
let builders = [
From 0537c438a0baa4cb4bd5c7c85a7e852eff6aa23e Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 20 Jan 2025 16:14:07 -0800
Subject: [PATCH 03/10] Revert "Remove extraneous line that allowed adding a
region"
This reverts commit 3d0d5b307d7326929e5a8da0942c51f4f133543e.
Actually, that line is needed...
---
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8c9a4e56ce00a0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -742,6 +742,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
AffineMapArrayAttr:$indexing_maps
);
let results = (outs Variadic<AnyShaped>:$result_tensors);
+ let regions = (region SizedRegion<1>:$combiner);
let skipDefaultBuilders = 1;
let builders = [
From dbcb847b4d5ca4fde71b46df75b867df85ced6b7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 21 Jan 2025 16:12:37 -0800
Subject: [PATCH 04/10] Address @adam-smnk's comments
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 1 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 27 ++++++++++++-------
mlir/test/Dialect/Linalg/invalid.mlir | 4 +--
3 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4277bd34f3946..d4b3cd9172b6c3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -769,6 +769,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare/implement functions necessary for LinalgStructuredInterface.
+
/// Infer iterator types for each dim in the domain of IndexingMaps.
SmallVector<utils::IteratorType> getIteratorTypesArray();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 355ed2d269291c..9f515c73657afb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3743,22 +3743,29 @@ LogicalResult ContractOp::verify() {
return failure(); // NOTE: checking lambda will emit error.
bool hasContractingDim = false;
- for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+ for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
+ size_t inOccCount = inOccurrences[dimIndex];
+ size_t outOccCount = outOccurrences[dimIndex];
+
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
if (inOccCount == 0)
- return emitError("iteration space dim not used by either input");
-
- // NB: A dim which occurs for only one input operand and not for the output.
- // In terms of einsum semantics, such dims have a sensible meaning -
- // namely an additional reduction per such dim - though this can also
- // always be expressed through an additional op. Additionally, at time
- // of writing, vector.contract's verifier accepts these dims but many of
- // its lowerings do not handle these kinds of dims. Hence...
+ return emitError() << "iteration space dim at index " << dimIndex
+ << " not used by either input";
+
+ // NB: We disallow a dim which occurs for only one input operand and not
+ // for the output. In terms of einsum semantics such dims have a
+  //     sensible meaning - namely an additional reduction over each such dim.
+ // By contrast, the ContractionOpInterface does not know about this
+ // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
+  //     while vector.contract's verifier accepts dims of this kind, many of
+ // its lowerings give up on encountering these dims.
// TODO: Remove following once we have comprehensive support for input-only
// reduction dims, at both the linalg- and vector-dialect levels.
if (inOccCount == 1 && outOccCount != 1)
- return emitError("iter type of dim is not one of M, N, K or batch");
+ return emitError()
+ << "iteration space dim at index " << dimIndex
+ << " is neither a contracting dim nor of parallel iteration type";
}
if (!hasContractingDim)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 59eeb3953e548f..57e004a048620e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -599,7 +599,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
// -----
func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
- // expected-error @+1 {{iteration space dim not used by either input}}
+ // expected-error @+1 {{iteration space dim at index 3 not used by either input}}
linalg.contract indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
@@ -613,7 +613,7 @@ func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: t
// -----
func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
- // expected-error @+1 {{iter type of dim is not one of M, N, K or batch}}
+ // expected-error @+1 {{iteration space dim at index 3 is neither a contracting dim nor of parallel iteration type}}
linalg.contract indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
From 62ae30b8ff22c4376503d900b37331d9eedcc001 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 23 Jan 2025 12:29:25 -0800
Subject: [PATCH 05/10] Address Adam's comments, round 2
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 21 ++--
.../Dialect/Linalg/generalize-named-ops.mlir | 25 +++++
mlir/test/Dialect/Linalg/invalid.mlir | 2 +-
mlir/test/Dialect/Linalg/named-ops.mlir | 99 +++++++++++++++++++
4 files changed, 136 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9f515c73657afb..b3bc7fedc1ad80 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3619,15 +3619,15 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
- /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
- /// domains are all the same, and each implements a projected permutation.
- /// Each dim in the domain must occur for at least one operand and is
- /// classified as either batch, N-like, M-like, or K-like. Only the latter
- /// corresponds to a reduction _and_ it is the only dim-kind which does not
- /// occur for the output operand. We use this fact for fast inference:
+ // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+ // domains are all the same, and each implements a projected permutation.
+ // Each iteration space dim must occur for at least one operand and either
+ // takes part in a contraction/reduction or else has parallel iteration type.
+  // We have that a dim is a contraction/reduction dim if and only if the
+  // dim does not occur for the output operand. We use this fact for fast
+  // inference:
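+  // E.g., for batch-reduce matmul - maps (b, m, k), (b, k, n) and (m, n) -
+  // dims b and k do not occur in the output's map and hence are inferred as
+  // reductions, yielding [reduction, parallel, parallel, reduction].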
// NB: In case we allow dims to occur solely for one input, the above still
// holds: per the einsum semantics, these are reduction dims as well.
- auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+ SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
for (auto result : outAffineMap.getResults()) {
auto dimExpr = dyn_cast<AffineDimExpr>(result);
assert(dimExpr && "affine_map is a projected permutation");
@@ -3738,9 +3738,10 @@ LogicalResult ContractOp::verify() {
for (auto &&[affineMap, operandType, isInput] :
llvm::zip(getIndexingMapsArray(), getOperandTypes(),
- SmallVector<bool>{true, true, false}))
+ SmallVector<bool>{true, true, false})) {
if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
return failure(); // NOTE: checking lambda will emit error.
+ }
bool hasContractingDim = false;
for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
@@ -3749,9 +3750,9 @@ LogicalResult ContractOp::verify() {
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
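+    // E.g., matmul's K-dim: accessed by both inputs (hence count 2) while
+    // absent from the output's map (count 0).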
- if (inOccCount == 0)
+ if (inOccCount == 0 && outOccCount == 0)
return emitError() << "iteration space dim at index " << dimIndex
- << " not used by either input";
+ << " not used to access any operand";
// NB: We disallow a dim which occurs for only one input operand and not
// for the output. In terms of einsum semantics such dims have a
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index f7e570d5ce38f0..3b21467ca45fef 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1164,3 +1164,28 @@ func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memre
outs(%arg2: memref<f32>)
return
}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+// CHECK-SAME: (%[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-NEXT: arith.mulf
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+
+func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 57e004a048620e..4e06342df2af76 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -599,7 +599,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
// -----
func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
- // expected-error @+1 {{iteration space dim at index 3 not used by either input}}
+ // expected-error @+1 {{iteration space dim at index 3 not used for any operand}}
linalg.contract indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 3c68e7d394642e..6defa827b77c48 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1528,6 +1528,105 @@ func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: m
// -----
+func.func @contract_matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_a
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_b
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
+// CHECK: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
+// CHECK: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @contract_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.contract indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_b_transpose_a
+// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
+// CHECK: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
// CHECK-LABEL: func @mmt4d
func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
// CHECK: %{{.+}} = linalg.mmt4d
>From a5726a4902d2345f667ba096951d580818ab25df Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 24 Jan 2025 00:34:43 -0800
Subject: [PATCH 06/10] Fix up error message in test case
---
mlir/test/Dialect/Linalg/invalid.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4e06342df2af76..039057d8e502d6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -599,7 +599,7 @@ func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: ten
// -----
func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
- // expected-error @+1 {{iteration space dim at index 3 not used for any operand}}
+ // expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
linalg.contract indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
>From ea8f018d03fc095733d933409e6bbe3705222a9b Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 27 Jan 2025 14:33:42 -0800
Subject: [PATCH 07/10] Address @banach-space's comments
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 2 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 24 +-
.../Dialect/Linalg/generalize-named-ops.mlir | 241 +++++++++---------
.../generalize-named-polymorphic-ops.mlir | 74 +++---
mlir/test/Dialect/Linalg/invalid.mlir | 114 ++++-----
mlir/test/Dialect/Linalg/loops.mlir | 16 +-
mlir/test/Dialect/Linalg/named-ops.mlir | 163 ++++++------
mlir/test/Dialect/Linalg/roundtrip.mlir | 21 +-
mlir/test/Dialect/Linalg/tile-tensors.mlir | 21 +-
.../Linalg/transform-op-vectorize.mlir | 33 ---
.../Linalg/vectorization-with-patterns.mlir | 32 +++
11 files changed, 378 insertions(+), 363 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d4b3cd9172b6c3..5bc04ac4ceb3eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -688,7 +688,7 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
AttrSizedOperandSegments,
LinalgContractionOpInterface]> {
let summary = [{
- Perform a contraction on two inputs, accumulating on top of a third.
+ Perform a contraction on two inputs, accumulating into the third.
}];
let description = [{
The semantics of contracting inputs `A` and `B` on top of `C` to produce
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b3bc7fedc1ad80..c6ea207323019d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3532,22 +3532,16 @@ FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
if (parser.parseOptionalKeyword("indexing_maps"))
return {nullptr}; // Success in case indexing_maps was not provided.
- SmallVector<Attribute> indexingMaps;
-
- auto parseIndexingMap = [&]() -> ParseResult {
- AffineMapAttr affineMapAttr;
- if (parser.parseAttribute(affineMapAttr))
- return failure();
- indexingMaps.push_back(affineMapAttr);
- return success();
- };
-
- if (parser.parseEqual() ||
- parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
- parseIndexingMap))
+ ArrayAttr arrayAttr;
+ if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
return failure();
- return parser.getBuilder().getArrayAttr(indexingMaps);
+ if (llvm::any_of(arrayAttr,
+ [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
+ return parser.emitError(parser.getCurrentLocation())
+ << "element of indexing_maps array is not an affine_map";
+
+ return arrayAttr;
}
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -3677,7 +3671,7 @@ ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
return parser.emitError(parser.getCurrentLocation(),
- "expected 'indexing_map' attribute");
+ "expected 'indexing_maps' attribute");
result.addAttribute("indexing_maps", *indexingMapsAttr);
return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
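To illustrate the stricter parsing above: the attribute is now read as a plain array and each element is checked to be an affine_map afterwards, so a mixed array gets a dedicated diagnostic. A minimal sketch (hypothetical IR, not one of the patch's test cases):

```
// The array parses as a generic ArrayAttr, but the bare integer element is
// not an AffineMapAttr, so parsing fails with
// "element of indexing_maps array is not an affine_map".
linalg.contract indexing_maps = [
                    affine_map<(d0, d1, d2) -> (d0, d2)>,
                    42,
                    affine_map<(d0, d1, d2) -> (d0, d1)>]
  ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
  outs(%C : memref<3x7xf32>)
```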
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 3b21467ca45fef..7ee25ac25ac6c9 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -999,193 +999,206 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK-LABEL: func.func @contract_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
-
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @contract_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
return
}
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME: %[[A:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2, d0)>,
- affine_map<(d0, d1, d2) -> (d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
- outs(%arg2: memref<3x7xf32>)
-
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
return
}
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-LABEL: func.func @contract_batch_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<9x3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<9x5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<9x3x7xf32>) {
+// CHECK-LABEL: func.func @contract_batch_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<9x3x7xf32>) {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
- outs(%arg2: memref<9x3x7xf32>)
-
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
+ outs(%arg2: memref<9x3x7xf32>)
return
}
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<9x3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<9x5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<9x3x5xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<9x5x7xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_batch_reduce_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @contract_batch_reduce_matmul(
+ %A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<3x7xf32>) {
+ linalg.contract
+ indexing_maps = [#accessA, #accessB, #accessC]
+ ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<9x5x3xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<9x7x5xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+// CHECK-SAME: %[[A:.*]]: memref<9x5x3xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<9x7x5xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(%arg0: memref<9x5x3xf32>, %arg1: memref<9x7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<9x5x3xf32>, memref<9x7x5xf32>)
- outs(%arg2: memref<3x7xf32>)
+#accessA = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+#accessB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#accessC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @contract_batch_reduce_matmul_permute_m_with_k_and_k_with_n(
+ %A: memref<9x5x3xf32>, %B: memref<9x7x5xf32>, %C: memref<3x7xf32>) {
+ linalg.contract
+ indexing_maps = [#accessA, #accessB, #accessC]
+ ins(%A, %B : memref<9x5x3xf32>, memref<9x7x5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> ()>
+// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0) -> ()>
-// CHECK-LABEL: func.func @contract_dot
-// CHECK-SAME: (%[[VAL_0:.*]]: memref<9xf32>, %[[VAL_1:.*]]: memref<9xf32>, %[[VAL_2:.*]]: memref<f32>) {
+// CHECK-LABEL: func.func @contract_dot(
+// CHECK-SAME: %[[A:.*]]: memref<9xf32>, %[[B:.*]]: memref<9xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<f32>) {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_dot(%arg0: memref<9xf32>, %arg1: memref<9xf32>, %arg2: memref<f32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> ()>
- ]
- ins(%arg0, %arg1 : memref<9xf32>, memref<9xf32>)
- outs(%arg2: memref<f32>)
+#accessAB = affine_map<(d0) -> (d0)>
+#accessC = affine_map<(d0) -> ()>
+func.func @contract_dot(
+ %A: memref<9xf32>, %B: memref<9xf32>, %C: memref<f32>) {
+ linalg.contract
+ indexing_maps = [#accessAB, #accessAB, #accessC]
+ ins(%A, %B : memref<9xf32>, memref<9xf32>)
+ outs(%C: memref<f32>)
return
}
// -----
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: #[[$ACCESS_A_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
-// CHECK-SAME: (%[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>, %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b(
+// CHECK-SAME: %[[A:.*]]: memref<5xf32>, %[[B:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<3x7xf32>) {
-// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-NEXT: ^{{.+}}(
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A_B]], #[[$ACCESS_A_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NEXT: ^{{.+}}(
// CHECK-NEXT: arith.mulf
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>)
- outs(%arg2: memref<3x7xf32>)
+#accessAB = affine_map<(d0, d1, d2) -> (d2)>
+#accessC = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @contract_matmul_bcast_a_b(
+ %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+ linalg.contract
+ indexing_maps = [#accessAB, #accessAB, #accessC]
+ ins(%A, %B : memref<5xf32>, memref<5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 9acb7562f96ee0..a192a2f72b286a 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -120,53 +120,55 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
// -----
-func.func @generalize_matmul_as_contraction_tensor_f16f64f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
- %0 = linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
- outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
+ %A : tensor<16x8xf16>,
+ %B: tensor<8x32xf64>,
+ %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf64>)
+ outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
// CHECK-LABEL: @generalize_matmul_as_contraction_tensor_f16f64f32
-// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
-// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
-// CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
-// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT: linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<16x32xf32>
-
-// -----
-
-func.func @generalize_matmul_as_contract_with_ext_and_trunc(%arg0: tensor<24x12xf16>,
- %arg1: tensor<12x25xf16>,
- %arg2: tensor<24x25xf32>) -> tensor<24x25xf16> {
- %0 = linalg.contract indexing_maps = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
- ]
- ins(%arg0, %arg1 : tensor<24x12xf16>, tensor<12x25xf16>)
- outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func.func @generalize_matmul_as_contract_with_ext_and_trunc(
+ %A: tensor<24x12xf16>,
+ %B: tensor<12x25xf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf16> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xf16>, tensor<12x25xf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
%1 = arith.truncf %0 : tensor<24x25xf32> to tensor<24x25xf16>
func.return %1 : tensor<24x25xf16>
}
// CHECK-LABEL: @generalize_matmul_as_contract_with_ext_and_trunc
-// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
// Verify floating point extension and truncation.
-// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
-// CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
-// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-// CHECK-NEXT: linalg.yield %[[ADD]] : f32
-// CHECK-NEXT: -> tensor<24x25xf32>
-// CHECK-NEXT: %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
+// CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+// CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<24x25xf32>
+// CHECK-NEXT: %[[RES:.+]] = arith.truncf {{.*}} : tensor<24x25xf32> to tensor<24x25xf16>
// -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 039057d8e502d6..f6fd60766d1574 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -529,98 +529,98 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
// -----
-func.func @invalid_indexing_maps_placement_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
- // expected-error @+1 {{custom op 'linalg.contract' expected 'indexing_map' attribute}}
- linalg.contract ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
- outs(%init : tensor<4x64xf32>)
- indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
+func.func @invalid_indexing_maps_placement_contraction(
+ %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+2 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+ linalg.contract
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>)
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
return
}
// -----
-func.func @invalid_affine_map_in_indexing_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @invalid_affine_map_in_indexing_maps_contraction(
+ %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{provided affine_map is not a projected permutation}}
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
- outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d2, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
return
}
// -----
-func.func @differing_iteration_space_of_affine_maps_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @differing_iteration_space_of_affine_maps_contraction(
+ %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{iteration spaces of provided affine_maps differ}}
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
- outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
return
}
// -----
-func.func @mismatched_ranks_affine_map_and_operand_contraction(%lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @mismatched_ranks_affine_map_and_operand_contraction(
+ %lhs: tensor<4x1x2xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{ranks of shaped operand and co-domain of corresponding affine_map differ}}
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
- outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%lhs, %rhs : tensor<4x1x2xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
return
}
// -----
-func.func @mismatch_type_affine_map_and_operand_contraction(%lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
+func.func @mismatch_type_affine_map_and_operand_contraction(
+ %lhs: f32, %rhs: tensor<4x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{affine_map specifies shaped access while operand has non-shaped type}}
- linalg.contract indexing_maps = [
- affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>
- ]
- ins(%lhs, %rhs : f32, tensor<4x64xf32>)
- outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>]
+ ins(%lhs, %rhs : f32, tensor<4x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
return
}
// -----
-func.func @unused_iteration_space_dim_contraction(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @unused_iteration_space_dim_contraction(
+ %lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{iteration space dim at index 3 not used to access any operand}}
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1)>
- ]
- ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
- outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
return
}
// -----
-func.func @unused_iteration_space_dim_contraction(%lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+func.func @unused_iteration_space_dim_contraction(
+ %lhs: tensor<8x4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
// expected-error @+1 {{iteration space dim at index 3 is neither a contracting dim nor of parallel iteration type}}
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1)>
- ]
- ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
- outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>]
+ ins(%lhs, %rhs : tensor<8x4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
return
}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 0a83750b81dea4..efe8010cffc916 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -627,14 +627,14 @@ func.func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f
//----------------------------------------------------------------------------//
// Named ops to loops.
//----------------------------------------------------------------------------//
-func.func @batch_reduce_matmul_as_contract(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d1, d2)>
- ]
- ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%C : memref<?x?xf32>)
+func.func @batch_reduce_matmul_as_contract(
+ %A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+ ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C : memref<?x?xf32>)
return
}
// CHECK-LABEL: @batch_reduce_matmul_as_contract
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 6defa827b77c48..06e0bb801b5444 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1515,115 +1515,120 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
// CHECK-SAME: indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
-func.func @contract(%arg0: memref<2x3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- ]
- ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x5x7xf32>)
- outs(%arg2: memref<2x3x7xf32>)
+func.func @contract(
+ %A: memref<2x3x5xf32>, %B: memref<2x5x7xf32>, %C: memref<2x3x7xf32>) {
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ ins(%A, %B : memref<2x3x5xf32>, memref<2x5x7xf32>)
+ outs(%C: memref<2x3x7xf32>)
return
}
// -----
-func.func @contract_matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_a
+func.func @contract_matmul_bcast_a(%A: memref<5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B : memref<5xf32>, memref<5x7xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func @contract_matmul_bcast_a
-// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-
// -----
-func.func @contract_matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @contract_matmul_bcast_b
+func.func @contract_matmul_bcast_b(%A: memref<3x5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B : memref<3x5xf32>, memref<5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func @contract_matmul_bcast_b
-// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
-// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
-
// -----
-func.func @contract_matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
+func.func @contract_matmul_bcast_a_b(
+ %A: memref<5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_A]], #[[$ACCESS_B]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B : memref<5xf32>, memref<5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
-// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
-// CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5xf32>)
-// CHECK: outs(%{{.+}} : memref<3x7xf32>)
-
// -----
-func.func @contract_matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
+func.func @contract_matmul_bcast_a_transpose_b(
+ %A: memref<5xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B : memref<5xf32>, memref<7x5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func.func @contract_matmul_bcast_a_transpose_b
-// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-// CHECK: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<7x5xf32>)
-// CHECK: outs(%{{.+}} : memref<3x7xf32>)
// -----
-func.func @contract_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
- linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d2, d0)>,
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract_matmul_bcast_b_transpose_a
+func.func @contract_matmul_bcast_b_transpose_a(%A: memref<5x3xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
+// CHECK: linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]]
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+ linalg.contract
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
+ ins(%A, %B : memref<5x3xf32>, memref<5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func.func @contract_matmul_bcast_b_transpose_a
-// CHECK: linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
-// CHECK: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5xf32>)
-// CHECK: outs(%{{.+}} : memref<3x7xf32>)
// -----
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 99cbb6647effbe..dc556761b09e56 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -277,26 +277,27 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
// -----
+#accessA = affine_map<(batch, m, n, k) -> (batch, m, k)>
+#accessB = affine_map<(batch, m, n, k) -> (batch, k, n)>
+#accessC = affine_map<(batch, m, n, k) -> (batch, m, n)>
func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
%ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
{
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%c3: memref<?x?x?xf32>)
- linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
- affine_map<(batch, m, n, k) -> (batch, k, n)>,
- affine_map<(batch, m, n, k) -> (batch, m, n)>]
- ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%c3: memref<?x?x?xf32>)
+ linalg.contract
+ indexing_maps = [#accessA, #accessB, #accessC]
+ ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%c3: memref<?x?x?xf32>)
%res1 = linalg.batch_matmul
ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
- %res2 = linalg.contract indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
- affine_map<(batch, m, n, k) -> (batch, k, n)>,
- affine_map<(batch, m, n, k) -> (batch, m, n)>]
- ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %res2 = linalg.contract
+ indexing_maps = [#accessA, #accessB, #accessC]
+ ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%tc3: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// CHECK-LABEL: func @named_ops
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 1de1863d6deb13..557233d8aa3ec4 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -40,13 +40,16 @@ module attributes {transform.with_named_sequence} {
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-NEXT: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-NEXT: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+#access_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>]
// CHECK-LABEL: func @matmul_as_contract_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @matmul_as_contract_tensors(
- %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+ %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
-> tensor<?x?xf32> {
// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
@@ -54,19 +57,17 @@ func.func @matmul_as_contract_tensors(
// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[sTD:.*]] = linalg.contract indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[sTD:.*]] = linalg.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
- %0 = linalg.contract indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ]
- ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2: tensor<?x?xf32>)
+ %0 = linalg.contract indexing_maps = #access_maps
+ ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
-> tensor<?x?xf32>
// CHECK: return %[[TD0]] : tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 2d30d62039642a..0d59dbba8940d4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -54,39 +54,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @vectorize_matmul_as_contract
-// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
-// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
-// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
-func.func @vectorize_matmul_as_contract(%arg0: tensor<24x12xf32>,
- %arg1: tensor<12x25xf32>,
- %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
- // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
- // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
- // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
- // CHECK: vector.transfer_write %[[vR]], %[[C]]
- %0 = linalg.contract indexing_maps = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
- ]
- ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>)
- outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
- func.return %0 : tensor<24x25xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: @vectorize_copy_memref
// CHECK-SAME: %[[A:.*]]: memref<100x100xf32>,
// CHECK-SAME: %[[B:.*]]: memref<100x100xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index b688a677500c22..3e10b8402af4e0 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -82,6 +82,38 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @matmul_as_contract
+// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32>
+func.func @matmul_as_contract(%A: tensor<24x12xf32>,
+ %B: tensor<12x25xf32>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
+ // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
+ // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xf32>, tensor<12x25xf32>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
#matmul_trait = {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>,
>From f469ffb46c84c3c4351b978dc537115eead694c6 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 27 Jan 2025 14:41:19 -0800
Subject: [PATCH 08/10] Add todo as pointed out by @banach-space
---
mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 3e10b8402af4e0..5ae3f893c2e739 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -108,6 +108,7 @@ module attributes {transform.with_named_sequence} {
%0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ // TODO: also test the other available vectorization strategies
transform.yield
}
}
>From 9f74a38ba1cd0f49af49ab2146362ae066b091b2 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 28 Jan 2025 05:47:46 -0800
Subject: [PATCH 09/10] Further fixes per @banach-space's comments
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 17 ++++++++++-------
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 13 +++++++++++--
.../Dialect/Linalg/generalize-named-ops.mlir | 18 +++++++++---------
.../generalize-named-polymorphic-ops.mlir | 2 +-
mlir/test/Dialect/Linalg/invalid.mlir | 3 ++-
mlir/test/Dialect/Linalg/named-ops.mlir | 3 ---
6 files changed, 33 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5bc04ac4ceb3eb..0f1dde270fc7f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -698,9 +698,9 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
dimension identifiers (meant to range over valid indices), corresponding to
- the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
- and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
- the dimensions in the set `dims`.
+ the co-domains of the mandatory (projected permutation) `indexing_maps` of
+ `A`, `B` and `C`, respectively. `SUM_{dims}` means reduce over all valid
+ indices for the dimensions in the set `dims`.
The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
domain of each of the `affine_map`s. Like for einsums, the iteration type of
@@ -719,21 +719,24 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
`n` and `b` are of parallel iteration-type) and gets represented as:
```
- %0 = linalg.contract
+ %D = linalg.contract
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
affine_map<(batch, m, n, k) -> (batch, k, n)>,
affine_map<(batch, m, n, k) -> (batch, m, n)>]
- ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ ins(%A, %B: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
```
Note that by permuting the dims in the co-domains of the `affine_map`s, we
can apply arbitrary transposes to the inputs and output. Similarly,
arbitrary broadcasts can be achieved through leaving out dims on either
- input operand.
+ input operand; the inferred iteration type of these dims will be parallel.
Numeric casting is performed on the operands to the inner multiplication,
promoting them to the same data type as the accumulator/output.
+
+ TODO: Allow control over the combining/accumulating op and possibly the
+ multiplication op.
}];
let arguments = (ins
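To make the casting rule concrete, here is a sketch of a mixed-precision contraction (hypothetical shapes; the promotion behavior follows the description above): both f16 inputs are extended to the accumulator's f32 element type before the inner multiply-accumulate.

```
%D = linalg.contract
       indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                        affine_map<(m, n, k) -> (k, n)>,
                        affine_map<(m, n, k) -> (m, n)>]
       ins(%A, %B : tensor<4x8xf16>, tensor<8x16xf16>)
       outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
```

On lowering to linalg.generic, the body would contain arith.extf on both inputs followed by arith.mulf and arith.addf in f32, as exercised by the generalize-named-polymorphic-ops tests in this PR.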
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c6ea207323019d..b80267c3726ec3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3690,10 +3690,18 @@ void ContractOp::print(OpAsmPrinter &p) {
LogicalResult ContractOp::verify() {
int iterationSpaceDims = -1;
- // Maps iter space dim (as index) to num of occurrences in inputs and output.
+ // Map iter space dims to #occurrences in inputs' and output's affine_maps:
+ // e.g., inOccurrences[0] holds the number of times dim 0 is used to access
+ // an input operand (so the count can be at most 2), and outOccurrences[1]
+ // indicates whether dim 1 occurs in the output, etc.
SmallVector<size_t> inOccurrences;
SmallVector<size_t> outOccurrences;
+ // For each operand's affine_map and type, check that the rank of the
+ // affine_map's domain matches that of the maps seen before, that the
+ // affine_map's co-domain rank matches the rank of the corresponding type,
+ // and that the affine_map is a projected permutation; finally, update the
+ // input and output occurrence counts for the dims in the co-domains.
auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
bool isInput) -> LogicalResult {
if (iterationSpaceDims == -1) {
@@ -3734,7 +3742,7 @@ LogicalResult ContractOp::verify() {
llvm::zip(getIndexingMapsArray(), getOperandTypes(),
SmallVector<bool>{true, true, false})) {
if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
- return failure(); // NOTE: checking lambda will emit error.
+ return failure(); // NB: checkAffineMapAndType will emit relevant error.
}
bool hasContractingDim = false;
@@ -3742,6 +3750,7 @@ LogicalResult ContractOp::verify() {
size_t inOccCount = inOccurrences[dimIndex];
size_t outOccCount = outOccurrences[dimIndex];
+ // We have a contracting dim if and only if ...
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
if (inOccCount == 0 && outOccCount == 0)
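As a reading aid for the occurrence counts, a sketch (assuming the shared affine_map domain `(d0, d1, d2)`): `d2` below occurs in both inputs and not in the output, so `inOccurrences[2] == 2` and `outOccurrences[2] == 0`, making it the contracting dim; a dim occurring in no co-domain would instead trip the "not used to access any operand" diagnostic.

```
%D = linalg.contract
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d2, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1)>]
       ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
       outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
```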
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7ee25ac25ac6c9..a3611b8e4ec62e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1016,13 +1016,13 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @contract_matmul(%A: memref<3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
linalg.contract
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>]
- ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
- outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
@@ -1046,13 +1046,13 @@ func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+func.func @contract_matmul_transpose_a_b(%A: memref<5x3xf32>, %B: memref<7x5xf32>, %C: memref<3x7xf32>) {
linalg.contract
indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>]
- ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
- outs(%arg2: memref<3x7xf32>)
+ ins(%A, %B : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%C: memref<3x7xf32>)
return
}
@@ -1075,13 +1075,13 @@ func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
-func.func @contract_batch_matmul(%arg0: memref<9x3x5xf32>, %arg1: memref<9x5x7xf32>, %arg2: memref<9x3x7xf32>) {
+func.func @contract_batch_matmul(%A: memref<9x3x5xf32>, %B: memref<9x5x7xf32>, %C: memref<9x3x7xf32>) {
linalg.contract
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
- ins(%arg0, %arg1 : memref<9x3x5xf32>, memref<9x5x7xf32>)
- outs(%arg2: memref<9x3x7xf32>)
+ ins(%A, %B : memref<9x3x5xf32>, memref<9x5x7xf32>)
+ outs(%C: memref<9x3x7xf32>)
return
}
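The CHECK lines above pin down the shape of the generalized form. For the
plain contract_matmul case, generalization should produce a linalg.generic
along the following lines (a sketch consistent with the mulf/addf/yield body
the tests check for; d2, the only dim absent from the output map, is the
reduction dim):

```
#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
linalg.generic
    {indexing_maps = [#mapA, #mapB, #mapC],
     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : memref<3x5xf32>, memref<5x7xf32>)
    outs(%C : memref<3x7xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  %0 = arith.mulf %a, %b : f32
  %1 = arith.addf %c, %0 : f32
  linalg.yield %1 : f32
}
```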
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index a192a2f72b286a..bbd6e0fc8e2ccb 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -121,7 +121,7 @@ func.func @generalize_matmul_unsigned_tensor_f16f64i32(%A : tensor<16x8xf16>, %B
// -----
func.func @generalize_matmul_as_contraction_tensor_f16f64f32(
- %A : tensor<16x8xf16>,
+ %A: tensor<16x8xf16>,
%B: tensor<8x32xf64>,
%C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.contract
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f6fd60766d1574..09d3076e5b778e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -531,7 +531,8 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t
func.func @invalid_indexing_maps_placement_contraction(
%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
- // expected-error @+2 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+ // expected-error @+3 {{custom op 'linalg.contract' expected 'indexing_maps' attribute}}
+ // NB: indexing_maps should be provided before ins and outs
linalg.contract
ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
outs(%init : tensor<4x64xf32>)
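For contrast, a well-formed version of the same op places the indexing_maps
attribute before ins and outs (a sketch; the maps here encode a plain matmul
over the given shapes):

```
%0 = linalg.contract
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d2, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1)>]
       ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
       outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
```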
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 06e0bb801b5444..578d24a550b08c 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1568,7 +1568,6 @@ func.func @contract_matmul_bcast_b(%A: memref<3x5xf32>, %B: memref<5xf32>, %C: m
// -----
-
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @contract_matmul_bcast_a_b
@@ -1608,7 +1607,6 @@ func.func @contract_matmul_bcast_a_transpose_b(
return
}
-
// -----
// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
@@ -1629,7 +1627,6 @@ func.func @contract_matmul_bcast_b_transpose_a(%A: memref<5x3xf32>, %B: memref<5
return
}
-
// -----
// CHECK-LABEL: func @mmt4d
>From 1dc42f04a1050f9da53bb08299578ec90bf59b64 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 28 Jan 2025 14:35:41 -0800
Subject: [PATCH 10/10] Expand docs on transpose and broadcast
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 0f1dde270fc7f7..6cb6138721e176 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -727,10 +727,21 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
outs(%C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
```
- Note that by permuting the dims in the co-domains of the `affine_map`s, we
- can apply arbitrary transposes to the inputs and output. Similarly,
- arbitrary broadcasts can be achieved through leaving out dims on either
- input operand - these dims' inferred iter type will be parallel.
+    Note that by permuting dims in the co-domains of the `affine_map`s,
+    arbitrary transposes can be applied to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved by leaving out dims on either input
+    operand (these dims' inferred iter type will be parallel). For example,
+    the following is a variant of batch-matmul where a transposition is
+    applied to `A` while matrix `B` is broadcast along the batch dimension:
+
+ ```
+ linalg.contract
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
+ affine_map<(batch, m, n, k) -> (k, n)>,
+ affine_map<(batch, m, n, k) -> (batch, m, n)>]
+ ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ ```
Numeric casting is performed on the operands to the inner multiplication,
promoting them to the same data type as the accumulator/output.
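For the transposed/broadcast variant above, spelling out the summation
semantics may help: reading off the three co-domains, the op computes

```
D[batch, m, n] = (SUM_k A[batch, k, m] * B[k, n]) + C[batch, m, n]
```

i.e. `k` is the single reduction dim, while `batch`, `m` and `n` are parallel
(with `batch` left out of `B`'s map, so `B` is re-used across that dim).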