[Mlir-commits] [mlir] [MLIR][Linalg] Remove/update failing obsolete OpDSL tests for linalg.matmul. (PR #115319)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Thu Nov 7 05:49:20 PST 2024
https://github.com/shahidact created https://github.com/llvm/llvm-project/pull/115319
The earlier PR (https://github.com/llvm/llvm-project/pull/104783), which introduced
transpose and broadcast semantics to linalg.matmul, was reverted due to two failing
OpDSL tests for linalg.matmul.
Since linalg.matmul is now defined using TableGen ODS instead of the Python-based OpDSL,
these tests started failing and need to be removed or updated.
This commit removes/updates the failing obsolete tests in the files below. All other files
were part of the earlier PR and are simply cherry-picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"
From a9770d295fb554efe0004bd7f5b2dbeb0ca55b13 Mon Sep 17 00:00:00 2001
From: Md Asghar Ahmad Shahid <md.asghar.ahmad.shahid at intel.com>
Date: Thu, 10 Oct 2024 21:30:58 +0530
Subject: [PATCH 1/3] [mlir][linalg] Introduce transpose/broadcast semantic to
'linalg.matmul' ops. (#104783)
The main goal of this patch is to extend the semantics of the 'linalg.matmul'
named op to include per-operand transpose semantics, while also laying out
a way to move op definitions from OpDSL to TableGen. Hence, it is
implemented in TableGen. The transpose semantics are as follows.
By default, the behavior of 'linalg.matmul' remains as is.
Transpose/broadcast semantics can be applied explicitly by specifying
the optional 'indexing_maps' attribute. By default, no transpose/broadcast
is mandated.
Example:
```
linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
]
ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2)>, // broadcast
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
]
ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
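For comparison, the default form (no 'indexing_maps' attribute) keeps the
classic matmul semantics; a minimal sketch with placeholder operands:
```
linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
```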
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 10 +
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 72 -----
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 134 +++++++++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 17 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 263 +++++++++++++++++-
.../Linalg/Transforms/TransposeMatmul.cpp | 7 +
.../Linalg/Transforms/Vectorization.cpp | 5 +
.../NVGPU/TransformOps/NVGPUTransformOps.cpp | 6 +
.../linalg/opdsl/ops/core_named_ops.py | 17 --
.../Dialect/Linalg/generalize-named-ops.mlir | 111 ++++++++
mlir/test/Dialect/Linalg/invalid.mlir | 159 +++++++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 243 ++++++++++++++++
mlir/test/python/dialects/linalg/ops.py | 75 -----
.../mlir-linalg-ods-yaml-gen.cpp | 6 +-
14 files changed, 943 insertions(+), 182 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b81a4c9c8760cfa..c0eff99c8507524 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
return;
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if the user has supplied explicit indexing maps for this op.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasUserDefinedMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return false; }]
+ >,
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed1..ee88ca516de6ff1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul
- cpp_class_name: MatmulOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c2fee8ea55c960a..2b47414ff5e924b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+ AttrSizedOperandSegments,
+ LinalgContractionOpInterface]> {
+
+ let summary = [{
+ Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+ }];
+ let description = [{
+ Numeric casting is performed on the operands to the inner multiply,
+ promoting them to the same data type as the accumulator/output.
+
+ Broadcast and transpose semantics can be applied by specifying the explicit attribute
+ 'indexing_maps' as shown below. This is a list attribute, so the list must include all
+ the maps if specified.
+
+ Example Transpose:
+ ```
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast:
+ ```
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ ```
+
+ Example Broadcast and Transpose:
+ ```
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+ affine_map<(d0, d1, d2) -> (d2)>, // broadcast
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion();
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("cast", cast);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// Implements the block region builder.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
+ SmallVector<AffineMap> getDefaultIndexingMaps();
+
+ /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+ bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps();
+ /// Checks if the op has broadcast and/or transpose semantics. Returns true if
+ /// the user-defined indexing maps are not equal to the default maps.
+ bool hasUserDefinedMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bd77965194b27fd..0cffadf8fb64a0b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,14 +15,21 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <numeric>
+#include <optional>
using namespace mlir;
using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
-
// Mixed tensor/buffer operands are not allowed.
if (!linalgOp.hasPureTensorSemantics() &&
!linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< ") to be equal to the number of input/output operands ("
<< linalgOp->getNumOperands() << ")";
+ // Check whether this op has user-defined maps. This is required to guard
+ // the error conditions below, which assume default indexing maps.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< " dim(s) to match the number of loops";
int64_t rank = linalgOp.getRank(&opOperand);
+
if (indexingMap.getNumResults() != rank)
return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
<< opOperand.getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
}
-
SmallVector<unsigned> redDims;
linalgOp.getReductionDims(redDims);
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
// Check if given shapes match to inferred shapes.
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
- // Verify only static cases since we can't get exact dimension sizes and loop
- // ranges for dynamic cases in this stage.
+ // Verify only static cases since we can't get exact dimension sizes and
+ // loop ranges for dynamic cases in this stage.
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
for (int64_t &range : endLoopRangeValues)
range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef8..4f350ea236da848 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
+#include <cassert>
#include <optional>
using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
// iterator_types is an auto-generated method.
}
+/// Helper to create the typical indexing maps for MatmulOp. Returns a list
+/// of AffineMaps.
+static SmallVector<AffineMap>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+ AffineExpr d0, d1, d2;
+ SmallVector<AffineMap, 3> indexingMaps;
+ bindDims(context, d0, d1, d2);
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+ return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+ return llvm::map_to_vector(
+ getDefaultIndexingMapsForMatmul(context),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
- std::optional<TypeRange> resultTensorTypes,
- ValueRange inputs, ValueRange outputs,
- ArrayRef<NamedAttribute> attributes,
- RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+ OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
// Derive the result types if needed.
SmallVector<Type> derivedResultTypes =
resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
state.addOperands(inputs);
state.addOperands(outputs);
state.addTypes(derivedResultTypes);
+
+ // Initialize indexingMaps for MatmulOp.
+ SmallVector<Attribute, 3> indexingMapsAttrVal;
+ if (indexingMaps.has_value()) {
+ for (mlir::AffineMap map : *indexingMaps) {
+ // Convert each AffineMap to an AffineMapAttr
+ indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+ }
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ } else {
+ indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+ state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+ }
+
state.addAttributes(attributes);
state.addAttribute(
"operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result,
unsigned numRegionArgs,
RegionBuilderFn regionBuilder) {
+
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
// TODO: Enable when ods-gen supports captures.
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
// TODO: consider merging results parsing into region parsing.
// Need to wait for declarative assembly resolution to decide.
SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
}
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
- ValueRange inputs, ValueRange outputs) {
- p.printOptionalAttrDict(
- op->getAttrs(),
- /*elidedAttrs=*/{"operandSegmentSizes",
- // See generated code in
- // LinalgNamedStructuredOps.yamlgen.cpp.inc
- "linalg.memoized_indexing_maps"});
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<StringRef> elidedAttrs = {}) {
+ p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
// Printing is shared with generic ops, except for the region and
// attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+
+/// Returns true if the result AffineExprs of \p explicitMap are the same as
+/// those of \p defaultMap.
+static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
+ auto explicitRange = explicitMap.getResults();
+ auto defaultRange = defaultMap.getResults();
+ DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+ DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+ llvm::set_union(explicitSet, defaultSet);
+ return explicitSet == defaultSet;
+}
+
+/// Returns true if \p explicitMap is broadcast with respect to
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
+ return explicitMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantics specified by the explicit
+/// indexing map for the MatmulOp \p matmulOp, for each operand specified by
+/// \p opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+ unsigned opIndex) {
+ SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
+ SmallVector<AffineMap, 3> defaultIndexingMaps =
+ matmulOp.getDefaultIndexingMaps();
+
+ auto opIndexingMap = opIndexingMaps[opIndex];
+ auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+ // Check general validity of indexing map results.
+ if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+ return matmulOp->emitOpError()
+ << "Unexpected dim expression in map result.";
+
+ // Check if the requested broadcast is valid.
+ if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+ if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
+ return matmulOp->emitOpError()
+ << "Invalid broadcast requested, should be (d2).";
+ }
+ return success();
+ }
+ return success();
+}
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+ utils::IteratorType::parallel,
+ utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Checks if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool MatmulOp::hasUserDefinedMaps() {
+ SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+}
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "MatmulOp regionBuilder expects 3 args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
+ Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
+SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
+ MLIRContext *context = this->getContext();
+ return getDefaultIndexingMapsForMatmul(context);
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
+ assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
+ AffineExpr exp = bcastMap.getResult(0);
+ // The map is invalid if the common dimension of the matmul is not found.
+ return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+ MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+ SmallVector<StringRef, 3> elidedAttrs = {
+ "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+
+ SmallVector<Attribute, 3> indexingMaps =
+ getDefaultIndexingMapAttr(getContext());
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+}
+
+/// Verify the user-defined indexing maps.
+LogicalResult MatmulOp::verify() {
+ // Verification of pure matmul is handled by verifyStructuredOpInterface().
+ if (!hasUserDefinedMaps())
+ return success();
+
+ for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+ if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
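As an illustration of what the new verifier rejects, a broadcast map must be
exactly (d2), the shared reduction dimension; the following sketch, mirroring
the invalid.mlir tests added below, fails verification:
```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d0)>, // invalid broadcast, should be (d2)
                affine_map<(d0, d1, d2) -> (d1, d2)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
// error: 'linalg.matmul' op Invalid broadcast requested, should be (d2).
```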
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index aa0052ce47fa7b2..6b934f7e8157d47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -31,6 +31,13 @@ using namespace mlir::linalg;
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
+ // Do not let a matmul with extended semantics through this transform.
+ if (matmulOp.hasUserDefinedMaps()) {
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only matmul ops with non-extended semantics are supported");
+ }
+
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 090e0b46768d7e9..757701dc024dfeb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2090,6 +2090,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
+ // Do not let a matmul with extended semantics through this transform.
+ if (linalgOp.hasUserDefinedMaps())
+ return failure();
+
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 0c2275bbc4b224a..3c508ed6e324b2b 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -821,6 +821,12 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
bool fail = true;
// TODO: more robust detection of matmulOp, with transposes etc.
if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
+ // Do not let a matmul with extended semantics through this transform.
+ if (linalgOp.hasUserDefinedMaps()) {
+ return emitSilenceableError()
+ << "only matmul ops with non-extended semantics are supported";
+ }
Location loc = linalgOp.getLoc();
// TODO: more robust computation of laneId, for now assume a single warp.
Value laneId = rewriter.create<gpu::ThreadIdOp>(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b45fecd0ee14575..5c1c984b136058e 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
-@linalg_structured_op
-def matmul(
- A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
@linalg_structured_op
def quantized_matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 1e8f1435ca0fa5c..aba26c35931fd38 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -29,6 +29,34 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>,
// -----
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_bcast_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// -----
+
func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -891,3 +919,86 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4b5a66f8fb5b922..a59472377a732c4 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -370,6 +370,165 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
// -----
+func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.matmul indexing_maps = [
+ ,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+ outs(%arg2 :memref<2x4xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_dim_a(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_dim_b(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}}
+ %0 = linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}}
+ %0 = linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+ outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+ return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #0 (1)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #1 (1)}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op Unexpected dim expression in map result.}}
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+ // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}}
+ linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>)
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ return
+}
+
+// -----
+
func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
// expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b56..65c18de8424771a 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,249 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
// -----
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// -----
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a_dim1
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_a_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b_dim1
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @dynamic_matmul_bcast_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_a_transpose_b(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_bcast_b_transpose_a(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK: linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+
+// -----
+
// CHECK-LABEL: func @matmul_transpose_b
// CHECK: linalg.matmul_transpose_b
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 3bfbcf7d7f7c81d..72045a07b2da800 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
print(module)
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
-@run
-def testNamedStructuredOpGenericForm():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
- )
- def named_form(lhs, rhs):
- init_result = tensor.empty([4, 8], f32)
- # CHECK: "linalg.matmul"(%{{.*}})
- # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
- # CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
- # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
- # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
- # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
- # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
- # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
- return linalg.matmul(lhs, rhs, outs=[init_result])
-
- module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
-@run
-def testNamedStructuredAsGenericOp():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
- )
- def generic_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- # CHECK: linalg.generic
- return linalg.matmul(
- lhs, rhs, outs=[init_result.result], emit_generic=True
- )
-
- print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
-@run
-def testOpResultFromOtherOp():
- with Context(), Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
- )
- def pass_an_op_directly(arg0, arg1):
- one = arith.ConstantOp(F32Type.get(), 1.0)
- # CHECK: %[[LHS:.*]] = linalg.fill
- lhs = linalg.fill(one, outs=[arg0])
- # CHECK: %[[RHS:.*]] = linalg.fill
- rhs = linalg.fill(one, outs=[arg1])
- # CHECK: %[[INIT:.*]] = tensor.empty
- init = tensor.EmptyOp([4, 8], f32)
- # CHECK: linalg.matmul
- # CHECK: ins(%[[LHS]], %[[RHS]]
- # CHECK: outs(%[[INIT]]
- return linalg.matmul(lhs, rhs, outs=init)
-
- print(module)
-
-
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 5f86c0cd7470772..6be7d4320c65626 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -678,7 +678,11 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
{0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
- ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+ SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
+ "linalg.memoized_indexing_maps",
+ "indexing_maps"};
+ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
}
)FMT";
From 4bfe1e3c660bac34e89d6b66bb6d4d5d9339d681 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Thu, 10 Oct 2024 18:52:20 +0100
Subject: [PATCH 2/3] Fix GCC build problem with 03483737a7a2
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4f350ea236da848..c909d13e4314b48 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -157,7 +157,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
/// Helper to create the typical indexing maps for MatmulOp. Returns a list
/// of AffineMaps.
-static SmallVector<AffineMap>
+static SmallVector<AffineMap, 3>
getDefaultIndexingMapsForMatmul(MLIRContext *context) {
AffineExpr d0, d1, d2;
SmallVector<AffineMap, 3> indexingMaps;
From 4809882fbbd1d49d1e4b94de7c293741d469cd73 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Thu, 7 Nov 2024 05:28:56 -0800
Subject: [PATCH 3/3] [MLIR][Linalg] Remove/update failing obsolete OpDSL tests
for linalg.matmul.
The earlier PR (https://github.com/llvm/llvm-project/pull/104783) was reverted
due to two failing OpDSL tests for linalg.matmul.
Since linalg.matmul is now defined using TableGen ODS instead of the Python-based
OpDSL, these tests started failing and need to be removed or updated.
This commit removes/updates the failing obsolete tests in the files below.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"
---
.../integration/dialects/linalg/opsrun.py | 115 ------------------
.../python/integration/dialects/transform.py | 26 ++--
2 files changed, 14 insertions(+), 127 deletions(-)
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index f6519fb17a6b98f..f77900bc277736c 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -50,37 +50,6 @@ def log(*args):
}
"""
-matmul_boiler = """
-func.func @main() -> f32 attributes {llvm.emit_c_interface} {
- %v0 = arith.constant 0.0 : f32
- %v1 = arith.constant -1 : i8
- %v2 = arith.constant 2.0 : f32
-
- %A = memref.alloc() : memref<4x16xi8>
- %B = memref.alloc() : memref<16x8xf32>
- %C0 = memref.alloc() : memref<4x8xf32>
- %C1 = memref.alloc() : memref<4x8xf32>
- linalg.fill ins(%v1 : i8) outs(%A : memref<4x16xi8>)
- linalg.fill ins(%v2 : f32) outs(%B : memref<16x8xf32>)
- linalg.fill ins(%v0 : f32) outs(%C0 : memref<4x8xf32>)
- linalg.fill ins(%v0 : f32) outs(%C1 : memref<4x8xf32>)
-
- call @matmul_signed_on_buffers(%A, %B, %C0) :
- (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
- call @matmul_unsigned_on_buffers(%A, %B, %C1) :
- (memref<4x16xi8>, memref<16x8xf32>, memref<4x8xf32>) -> ()
-
- %c0 = arith.constant 0 : index
- %res0 = memref.load %C0[%c0, %c0] : memref<4x8xf32>
- %res1 = memref.load %C1[%c0, %c0] : memref<4x8xf32>
-
- %0 = arith.addf %res0, %res1 : f32
-
- // TODO: FFI-based solution to allow testing and printing with python code.
- return %0 : f32
-}
-"""
-
fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
%O0 = memref.alloc() : memref<i32>
@@ -296,90 +265,6 @@ def elemwise_log_mul_on_buffers(lhs, rhs, out):
test_elemwise_generic()
-def test_matmul_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8),
- MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def matmul_signed_on_buffers(lhs, rhs, out):
- linalg.matmul(lhs, rhs, outs=[out])
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8),
- MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def matmul_unsigned_on_buffers(lhs, rhs, out):
- linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
-
- execution_engine = ExecutionEngine(transform(module, matmul_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.0)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
- # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
- # CHECK: RESULT: 8128
-
-
-test_matmul_builtin()
-
-
-def test_matmul_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8),
- MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def matmul_signed_on_buffers(lhs, rhs, out):
- linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
-
- @func.FuncOp.from_py_func(
- MemRefType.get((4, 16), i8),
- MemRefType.get((16, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def matmul_unsigned_on_buffers(lhs, rhs, out):
- linalg.matmul(
- lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True
- )
-
- execution_engine = ExecutionEngine(transform(module, matmul_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.0)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
- # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
- # CHECK: RESULT: 8128
-
-
-test_matmul_generic()
-
-
def test_fill_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
diff --git a/mlir/test/python/integration/dialects/transform.py b/mlir/test/python/integration/dialects/transform.py
index bc88a61314d0d8a..3895eed58684313 100644
--- a/mlir/test/python/integration/dialects/transform.py
+++ b/mlir/test/python/integration/dialects/transform.py
@@ -99,26 +99,28 @@ def basic(target: any_op_t()):
# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns(module_):
- M, N, K = 3, 5, 3
+ b, M, N, K = 1, 3, 5, 3
- # CHECK-LABEL: func.func @matmul(
- # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ # CHECK-LABEL: func.func @batch_reduce_matmul(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>,
+ # CHECK-SAME: %[[VAL_1:.*]]: tensor<1x5x3xf32>,
+ # CHECK-SAME: %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
# CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
- # CHECK: %[[VAL_5:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+ # CHECK: %[[VAL_5:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
# CHECK: return %[[VAL_5]] : tensor<3x3xf32>
# CHECK: }
@func.func(
- T.tensor(M, N, T.f32()), T.tensor(N, K, T.f32()), T.tensor(M, K, T.f32())
+ T.tensor(b, M, N, T.f32()), T.tensor(b, N, K, T.f32()), T.tensor(M, K, T.f32())
)
- def matmul(A, B, C):
+ def batch_reduce_matmul(A, B, C):
i = arith.constant(T.i32(), 1)
v = arith.addi(i, i)
- return linalg.matmul(A, B, outs=[C])
+ return linalg.batch_reduce_matmul(A, B, outs=[C])
# CHECK-LABEL: module attributes {transform.with_named_sequence} {
# CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
- # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
# CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
# CHECK: transform.apply_patterns to %[[VAL_2]] {
# CHECK: transform.apply_patterns.canonicalization
@@ -132,7 +134,7 @@ def matmul(A, B, C):
def mod():
@named_sequence("__transform_main", [any_op_t()], [])
def basic(variant_op: any_op_t()):
- matmul = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
+ matmul = structured_match(any_op_t(), variant_op, ops=["linalg.batch_reduce_matmul"])
top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")
@apply_patterns(top_func)
@@ -147,9 +149,9 @@ def pats():
pm = PassManager.parse("builtin.module(transform-interpreter)")
pm.run(module_.operation)
- # CHECK-LABEL: func.func @matmul(
- # CHECK-SAME: %[[VAL_0:.*]]: tensor<3x5xf32>, %[[VAL_1:.*]]: tensor<5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
- # CHECK: %[[VAL_3:.*]] = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : tensor<3x5xf32>, tensor<5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
+ # CHECK-LABEL: func.func @batch_reduce_matmul(
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>, %[[VAL_1:.*]]: tensor<1x5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ # CHECK: %[[VAL_3:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
# CHECK: return %[[VAL_3]] : tensor<3x3xf32>
# CHECK: }
print(module_)
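One more orientation note on the rewritten test: linalg.batch_reduce_matmul
reduces over both the batch and the contraction dimension into a 2-D
accumulator, which is why the inputs gain a unit batch dimension (b = 1) while
the output stays tensor<3x3xf32>. Roughly, in terms of its iteration space
(sketch, not part of the diff):
```
// d0 = batch, d1 = m, d2 = n, d3 = k; d0 and d3 are reduced.
// A indexes (d0, d1, d3), B indexes (d0, d3, d2), C indexes (d1, d2).
%0 = linalg.batch_reduce_matmul
       ins(%A, %B : tensor<1x3x5xf32>, tensor<1x5x3xf32>)
       outs(%C : tensor<3x3xf32>) -> tensor<3x3xf32>
```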