[Mlir-commits] [mlir] [mlir][Linalg] Fix non-matmul linalg structured ops (PR #116412)
Kunwar Grover
llvmlistbot at llvm.org
Fri Nov 15 09:28:23 PST 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/116412
https://github.com/llvm/llvm-project/commit/3ad0148020ca91cc288bffd8ad36e25f7555a3bb broke linalg structured ops other than MatmulOp.
The patch:
- Changes the printer to hide an additional attribute that wasn't hidden before: `indexing_maps`.
- Changes the builder of every linalg structured op to attach the matmul indexing maps.
Combined, these changes hide the problem until you print the operation in its generic form.
Reproducer:
```mlir
func.func public @bug(%arg0 : tensor<5x10x20xf32>, %arg1 : tensor<5x20x40xf32>, %arg3 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32> {
%out = linalg.batch_matmul ins(%arg0, %arg1 : tensor<5x10x20xf32>, tensor<5x20x40xf32>)
outs(%arg3 : tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
func.return %out : tensor<5x10x40xf32>
}
```
This prints fine with `mlir-opt <file>`, but if you run `mlir-opt --mlir-print-op-generic <file>`:
```
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
"builtin.module"() ({
"func.func"() <{function_type = (tensor<5x10x20xf32>, tensor<5x20x40xf32>, tensor<5x10x40xf32>) -> tensor<5x10x40xf32>, sym_name = "bug", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<5x10x20xf32>, %arg1: tensor<5x20x40xf32>, %arg2: tensor<5x10x40xf32>):
%0 = "linalg.batch_matmul"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> ({
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%1 = "arith.mulf"(%arg3, %arg4) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%2 = "arith.addf"(%arg5, %1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
"linalg.yield"(%2) : (f32) -> ()
}) {indexing_maps = [#map, #map1, #map2], linalg.memoized_indexing_maps = [#map3, #map4, #map5]} : (tensor<5x10x20xf32>, tensor<5x20x40xf32>, tensor<5x10x40xf32>) -> tensor<5x10x40xf32>
"func.return"(%0) : (tensor<5x10x40xf32>) -> ()
}) : () -> ()
}) : () -> ()
```
The batch_matmul operation's builder now always inserts an `indexing_maps` attribute that is unrelated to the operation itself. This was caught when a transformation from one LinalgStructuredOp to another tried to pass its attributes to the other op's builder, and the result contained multiple `indexing_maps` attributes.
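This patch fixes the problem by moving the indexing-map handling into MatmulOp-specific code (a dedicated `buildMatmulOp` builder and `MatmulOp::parse`). Other structured ops therefore no longer receive an `indexing_maps` attribute.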
>From 93cc6ba1f8959be56ab11a678274ec13a137abb4 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 15 Nov 2024 17:26:05 +0000
Subject: [PATCH] [mlir][Linalg] Fix non-matmul linalg structured ops
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 6 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 123 ++++++++++--------
.../mlir-linalg-ods-yaml-gen.cpp | 3 +-
3 files changed, 71 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e578f4b956ef5e..a90777c82bf63a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -621,7 +621,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, MatmulOp::getRegionBuilder());
}]>,
OpBuilder<
@@ -629,7 +629,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- buildStructuredOp($_builder, $_state, resultTensorTypes,
+ buildMatmulOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
}]>,
OpBuilder<
@@ -647,7 +647,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
"Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addAttribute("cast", cast);
- buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
attributes, MatmulOp::getRegionBuilder());
}]>
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c909d13e4314b4..dee8a4e27e6b26 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
}
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
-static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+static SmallVector<Attribute>
+getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
return llvm::map_to_vector(
getDefaultIndexingMapsForMatmul(context),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
-static void buildStructuredOp(
- OpBuilder &b, OperationState &state,
- std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
- ValueRange outputs, ArrayRef<NamedAttribute> attributes,
- RegionBuilderFn regionBuilder,
- std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
+static void buildStructuredOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder) {
// Derive the result types if needed.
SmallVector<Type> derivedResultTypes =
resultTensorTypes.value_or(TypeRange());
@@ -196,6 +196,24 @@ static void buildStructuredOp(
state.addOperands(outputs);
state.addTypes(derivedResultTypes);
+ state.addAttributes(attributes);
+ state.addAttribute(
+ "operandSegmentSizes",
+ b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
+ static_cast<int32_t>(outputs.size())}));
+
+ // Create and fill the region of the structured operation.
+ Region ®ion = *state.addRegion();
+ fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
+ state.attributes.getAttrs(), regionBuilder);
+}
+
+static void
+buildMatmulOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder,
+ std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
// Initialize indexingMaps, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
if (indexingMaps.has_value()) {
@@ -205,20 +223,11 @@ static void buildStructuredOp(
}
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
} else {
- indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+ indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
}
-
- state.addAttributes(attributes);
- state.addAttribute(
- "operandSegmentSizes",
- b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
- static_cast<int32_t>(outputs.size())}));
-
- // Create and fill the region of the structured operation.
- Region ®ion = *state.addRegion();
- fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
- state.attributes.getAttrs(), regionBuilder);
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
}
/// Common parsing used for both named structured ops created by ods-gen and by
@@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result,
unsigned numRegionArgs,
RegionBuilderFn regionBuilder) {
-
- SmallVector<Attribute, 3> indexingMapsAttr;
- Attribute mapAttr;
- if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
- if (parser.parseEqual())
- return failure();
-
- if (parser.parseLSquare())
- return failure();
-
- do {
- if (parser.parseAttribute(mapAttr))
- return failure();
- if (!isa<AffineMapAttr>(mapAttr)) {
- return parser.emitError(parser.getCurrentLocation(),
- "expected affine map attribute");
- }
- indexingMapsAttr.push_back(mapAttr);
-
- if (parser.parseOptionalComma())
- break;
- } while (true);
-
- if (parser.parseRSquare())
- return failure();
- }
- // Initialize indexingMaps, if not supplied explicitly.
- if (indexingMapsAttr.empty()) {
- indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
- }
- result.addAttribute("indexing_maps",
- parser.getBuilder().getArrayAttr(indexingMapsAttr));
-
// TODO: Enable when ods-gen supports captures.
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
@@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
namespace mlir {
namespace linalg {
+
//===----------------------------------------------------------------------===//
// MatMulOp
//===----------------------------------------------------------------------===//
+
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
@@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
-/// Check if the op has broadcast and/or transpose semantic. Returns true if the
-/// user defined indexing maps are not equal to default map.
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
bool MatmulOp::hasUserDefinedMaps() {
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
@@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
helper.yieldOutputs(yields);
}
-/// Returns a list of AffineMap with the typical matmul indexing charactristic.
+/// Returns a list of AffineMap with the typical matmul indexing
+/// charactristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
MLIRContext *context = this->getContext();
return getDefaultIndexingMapsForMatmul(context);
@@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
}
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+
+ if (parser.parseLSquare())
+ return failure();
+
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr)) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ }
+ indexingMapsAttr.push_back(mapAttr);
+
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
MatmulOp::getRegionBuilder());
}
@@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
elidedAttrs);
SmallVector<Attribute, 3> indexingMaps =
- getDefaultIndexingMapAttr(getContext());
+ getDefaultMatmulIndexingMapAttr(getContext());
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
p << " indexing_maps = [";
llvm::interleaveComma(getIndexingMaps(), p,
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 6be7d4320c6562..80d979864921d6 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -679,8 +679,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
}
void {0}::print(OpAsmPrinter &p) {{
SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
- "linalg.memoized_indexing_maps",
- "indexing_maps"};
+ "linalg.memoized_indexing_maps"};
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}
More information about the Mlir-commits
mailing list