[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Thu Oct 10 05:19:52 PDT 2024
https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/104783
>From c1caf7474560c38afca17960b6135048eea30776 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 21 Aug 2024 05:34:27 -0700
Subject: [PATCH] [mlir][linalg] Extends 'linalg.matmul' named op to define
broadcast and transpose semantics.
Goals:
1. To add syntax to matmul without changing any of the existing syntax
expectations for current usage. matmul is still just matmul.
2. To expose broadcast and transpose semantics on the three matmul
variations: matmul, batch_matmul and batch_reduce_matmul.
Scope of this patch:
To expose broadcast and transpose semantics on 'linalg.matmul'.
The broadcast and transpose semantics are as follows:
By default, 'linalg.matmul' behavior remains as is. Broadcast and
transpose semantics can be applied by specifying the explicit attribute
'indexing_maps' as shown below. This is a list attribute, so the list
must include all the maps if specified.
Example Transpose:
linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
]
ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
Example Broadcast:
linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
]
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
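For comparison, the default form keeps its existing syntax; when no
'indexing_maps' attribute is given, the canonical matmul maps are assumed
(shapes illustrative):
Example default matmul:
linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)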
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 10 +
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 72 -----
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 134 +++++++++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 18 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 263 +++++++++++++++++-
.../Linalg/Transforms/TransposeMatmul.cpp | 7 +
.../Linalg/Transforms/Vectorization.cpp | 5 +
.../NVGPU/TransformOps/NVGPUTransformOps.cpp | 6 +
.../linalg/opdsl/ops/core_named_ops.py | 17 --
.../Dialect/Linalg/generalize-named-ops.mlir | 111 ++++++++
mlir/test/Dialect/Linalg/invalid.mlir | 159 +++++++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 243 ++++++++++++++++
mlir/test/python/dialects/linalg/ops.py | 75 -----
.../mlir-linalg-ods-yaml-gen.cpp | 6 +-
14 files changed, 944 insertions(+), 182 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..e80dbb2afb9ef7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -684,6 +684,16 @@ def LinalgStructuredInterface
return;
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+      Return true if the user has supplied explicit indexing maps for this op.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasUserDefinedMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return false; }]
+ >,
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..97b90333e2b200 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul
- cpp_class_name: MatmulOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 31f29139247267..61d4fc9734c6de 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -535,6 +535,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{
+ Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+ }];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply,
+ promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit
+    attribute 'indexing_maps' as shown below. This is a list attribute, so the
+    list must include all the maps if specified.
+
+ Example Transpose:
+ ```
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+    ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and transpose:
+ ```
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+ affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    ```
+  }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion();
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+    /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
+ SmallVector<AffineMap> getDefaultIndexingMaps();
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps();
+    /// Checks if the op has broadcast and/or transpose semantics. Returns true
+    /// if the user-defined indexing maps are not equal to the default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
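For builder users nothing changes; the op can still be created through the
generated builders, which fill in the default indexing maps. A minimal
sketch (builder, loc, resultType, lhs, rhs, and init are assumed to exist):

  // Creates a matmul with the default (non-transposed, non-broadcast) maps.
  auto matmul = builder.create<linalg::MatmulOp>(
      loc, TypeRange{resultType}, ValueRange{lhs, rhs}, ValueRange{init});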
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 40795879c3026d..97f1076002ede6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,13 +15,20 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
#include <algorithm>
+#include <optional>
using namespace mlir;
using namespace mlir::linalg;
@@ -1142,7 +1149,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
-
// Mixed tensor/buffer operands are not allowed.
if (!linalgOp.hasPureTensorSemantics() &&
!linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1162,6 +1168,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< ") to be equal to the number of input/output operands ("
<< linalgOp->getNumOperands() << ")";
+  // Set this flag if this op has user-defined maps. This is required to guard
+  // the error conditions below, which assume the default indexing maps.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
@@ -1178,13 +1186,14 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< " dim(s) to match the number of loops";
int64_t rank = linalgOp.getRank(&opOperand);
+
if (indexingMap.getNumResults() != rank)
return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
<< opOperand.getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
- }
+ }
SmallVector<unsigned> redDims;
linalgOp.getReductionDims(redDims);
@@ -1194,9 +1203,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
// Check if given shapes match to inferred shapes.
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
- // Verify only static cases since we can't get exact dimension sizes and loop
- // ranges for dynamic cases in this stage.
+ // Verify only static cases since we can't get exact dimension sizes and
+ // loop ranges for dynamic cases in this stage.
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
for (int64_t &range : endLoopRangeValues)
range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..77a1db080cd40e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include <cassert>
#include <optional>
using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
// iterator_types is an auto-generated method.
}
+/// Helper to create the default indexing maps for MatmulOp. Returns a list of
+/// AffineMaps.
+static SmallVector<AffineMap>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+ AffineExpr d0, d1, d2;
+ SmallVector<AffineMap, 3> indexingMaps;
+ bindDims(context, d0, d1, d2);
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+ return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+ return llvm::map_to_vector(
+ getDefaultIndexingMapsForMatmul(context),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
- std::optional<TypeRange> resultTensorTypes,
- ValueRange inputs, ValueRange outputs,
- ArrayRef<NamedAttribute> attributes,
- RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+ OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
// Derive the result types if needed.
SmallVector<Type> derivedResultTypes =
resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
state.addOperands(inputs);
state.addOperands(outputs);
state.addTypes(derivedResultTypes);
+
+  // Initialize the indexing maps attribute for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    // Convert each AffineMap to an AffineMapAttr.
+    for (mlir::AffineMap map : *indexingMaps)
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+  }
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+
state.addAttributes(attributes);
state.addAttribute(
"operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result,
unsigned numRegionArgs,
RegionBuilderFn regionBuilder) {
+
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+  // Fall back to the default indexing maps if none were supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
// TODO: Enable when ods-gen supports captures.
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
// TODO: consider merging results parsing into region parsing.
// Need to wait for declarative assembly resolution to decide.
SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
}
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
- ValueRange inputs, ValueRange outputs) {
- p.printOptionalAttrDict(
- op->getAttrs(),
- /*elidedAttrs=*/{"operandSegmentSizes",
- // See generated code in
- // LinalgNamedStructuredOps.yamlgen.cpp.inc
- "linalg.memoized_indexing_maps"});
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<StringRef> elidedAttrs = {}) {
+ p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
// Printing is shared with generic ops, except for the region and
// attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+
+/// Returns true if the result AffineExprs of \p explicitMap are the same as
+/// those of \p defaultMap.
+static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
+  auto explicitRange = explicitMap.getResults();
+ auto defaultRange = defaultMap.getResults();
+ DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+ DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+ llvm::set_union(explicitSet, defaultSet);
+ return explicitSet == defaultSet;
+}
+
+/// Returns true if \p explicitMap is broadcast with respect to
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
+  return explicitMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantics specified by the explicit
+/// indexing map for the MatmulOp \p matmulOp for the operand specified by
+/// \p opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+ unsigned opIndex) {
+ SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
+ SmallVector<AffineMap, 3> defaultIndexingMaps =
+ matmulOp.getDefaultIndexingMaps();
+
+ auto opIndexingMap = opIndexingMaps[opIndex];
+ auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+ // Check general validity of indexing map results.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return matmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap) &&
+      !matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
+    return matmulOp->emitOpError()
+           << "Invalid broadcast requested, should be (d2).";
+  }
+  return success();
+}
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+ utils::IteratorType::parallel,
+ utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Checks if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool MatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
+ Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
+SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
+ MLIRContext *context = this->getContext();
+ return getDefaultIndexingMapsForMatmul(context);
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
+ assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
+ AffineExpr exp = bcastMap.getResult(0);
+  // The map is invalid if the common (reduction) dimension of the matmul is
+  // not present.
+ return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+ MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+
+ SmallVector<Attribute, 3> indexingMaps =
+ getDefaultIndexingMapAttr(getContext());
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult MatmulOp::verify() {
+ // Verification of pure matmul is handled by verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+ if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
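A note on round-tripping (illustrative, shapes assumed): because the printer
elides 'indexing_maps' when it matches the default, a plain matmul prints and
re-parses unchanged:

  linalg.matmul ins(%a, %b : memref<3x5xf32>, memref<5x7xf32>)
                outs(%c : memref<3x7xf32>)

Non-default maps are printed back explicitly after the outs clause, as
exercised by the named-ops.mlir tests below.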
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index aa0052ce47fa7b..6b934f7e8157d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -31,6 +31,13 @@ using namespace mlir::linalg;
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
+  // Bail out on matmuls with extended (broadcast/transpose) semantics; this
+  // transform only supports the default indexing maps.
+ if (matmulOp.hasUserDefinedMaps()) {
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only matmul ops with non-extended semantics are supported");
+ }
+
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 09c6b2683b4388..e3f010d9cfb20b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2071,6 +2071,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
+  // Bail out on matmuls with extended (broadcast/transpose) semantics; this
+  // transform only supports the default indexing maps.
+ if (linalgOp.hasUserDefinedMaps())
+ return failure();
+
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 0c2275bbc4b224..3c508ed6e324b2 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -821,6 +821,12 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
bool fail = true;
// TODO: more robust detection of matmulOp, with transposes etc.
if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
+    // Bail out on matmuls with extended (broadcast/transpose) semantics; this
+    // transform only supports the default indexing maps.
+ if (linalgOp.hasUserDefinedMaps()) {
+ return emitSilenceableError()
+ << "only matmul ops with non-extended semantics are supported";
+ }
Location loc = linalgOp.getLoc();
// TODO: more robust computation of laneId, for now assume a single warp.
Value laneId = rewriter.create<gpu::ThreadIdOp>(
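Any other downstream transform that assumes the default matmul maps needs the
same guard; a minimal sketch of the pattern used in the three call sites above
(matmulOp and rewriter assumed from a rewrite-pattern context):

  // Bail out on matmul ops carrying user-defined indexing maps.
  if (matmulOp.hasUserDefinedMaps())
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with non-extended semantics are supported");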
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index e4a6ec7487bb2f..d5e79b4d3cb6dd 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
- at linalg_structured_op
-def matmul(
- A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
@linalg_structured_op
def quantized_matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 1e8f1435ca0fa5..aba26c35931fd3 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -29,6 +29,34 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>,
// -----
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_bcast_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// -----
+
func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -891,3 +919,86 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c481a723c5623c..b2869893b8042d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -361,6 +361,165 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
// -----
+func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.matmul indexing_maps = [
+ ,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+ outs(%arg2 :memref<2x4xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_dim_a(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_dim_b(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}}
+ %0 = linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}}
+ %0 = linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #0 (1)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #1 (1)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op Unexpected dim expression in map result.}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}}
+ linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>)
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ return
+}
+
+// -----
+
func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
// expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..65c18de8424771 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,249 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
// -----
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a_dim1
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_a_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b_dim1
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @dynamic_matmul_bcast_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_a_transpose_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_b_transpose_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
// CHECK-LABEL: func @matmul_transpose_b
// CHECK: linalg.matmul_transpose_b
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 3bfbcf7d7f7c81..72045a07b2da80 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
print(module)
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
- at run
-def testNamedStructuredOpGenericForm():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
- )
- def named_form(lhs, rhs):
- init_result = tensor.empty([4, 8], f32)
- # CHECK: "linalg.matmul"(%{{.*}})
- # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
- # CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
- # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
- # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
- # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
- # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
- # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
- return linalg.matmul(lhs, rhs, outs=[init_result])
-
- module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
- at run
-def testNamedStructuredAsGenericOp():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
- )
- def generic_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- # CHECK: linalg.generic
- return linalg.matmul(
- lhs, rhs, outs=[init_result.result], emit_generic=True
- )
-
- print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
- at run
-def testOpResultFromOtherOp():
- with Context(), Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
- )
- def pass_an_op_directly(arg0, arg1):
- one = arith.ConstantOp(F32Type.get(), 1.0)
- # CHECK: %[[LHS:.*]] = linalg.fill
- lhs = linalg.fill(one, outs=[arg0])
- # CHECK: %[[RHS:.*]] = linalg.fill
- rhs = linalg.fill(one, outs=[arg1])
- # CHECK: %[[INIT:.*]] = tensor.empty
- init = tensor.EmptyOp([4, 8], f32)
- # CHECK: linalg.matmul
- # CHECK: ins(%[[LHS]], %[[RHS]]
- # CHECK: outs(%[[INIT]]
- return linalg.matmul(lhs, rhs, outs=init)
-
- print(module)
-
-
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index aa5a52a21f1251..f820cb7ee8c3c4 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -681,7 +681,11 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
{0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
- ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+ SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
+ "linalg.memoized_indexing_maps",
+ "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
}
)FMT";