[Mlir-commits] [mlir] [NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related C++ code. (PR #116377)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Sun Nov 17 20:20:05 PST 2024
https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/116377
From 0b1cbee389a72e8bcf3da04e1efb94b86da9a50c Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Sun, 17 Nov 2024 20:19:09 -0800
Subject: [PATCH] Resolve conflict.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 19 ++++++++---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 32 ++++++-------------
2 files changed, 24 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a90777c82bf63a..ab3349bd3e8436 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -622,7 +622,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
- attributes, MatmulOp::getRegionBuilder());
+ attributes, MatmulOp::getRegionBuilder(),
+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -630,7 +631,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildMatmulOp($_builder, $_state, resultTensorTypes,
- inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+ inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -648,7 +650,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
[{
$_state.addAttribute("cast", cast);
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
- attributes, MatmulOp::getRegionBuilder());
+ attributes, MatmulOp::getRegionBuilder(),
+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>
];
@@ -664,7 +667,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
Block &block, ArrayRef<NamedAttribute> attrs);
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
- SmallVector<AffineMap> getDefaultIndexingMaps();
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context) {
+ AffineExpr d0, d1, d2; // Maps: LHS (d0, d2), RHS (d2, d1), Out (d0, d1).
+ SmallVector<AffineMap> indexingMaps; // Matches return type.
+ bindDims(context, d0, d1, d2);
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+ return indexingMaps;
+ }
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index dee8a4e27e6b26..011ce0182deb6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -168,14 +168,6 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
return indexingMaps;
}
-/// Wrapper to return the typical indexing map array attribute for MatmulOp.
-static SmallVector<Attribute>
-getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
- return llvm::map_to_vector(
- getDefaultIndexingMapsForMatmul(context),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
-}
-
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
@@ -222,9 +214,6 @@ buildMatmulOp(OpBuilder &b, OperationState &state,
indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
}
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
- } else {
- indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
- state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
}
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
@@ -3457,7 +3446,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- matmulOp.getDefaultIndexingMaps();
+ matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3501,7 +3490,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
/// Check if the op has broadcast and/or transpose semantic. Returns true if
/// the user defined indexing maps are not equal to default map.
bool MatmulOp::hasUserDefinedMaps() {
- SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(this->getContext());
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
return defaultMaps != explicitMaps;
}
@@ -3535,13 +3525,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
helper.yieldOutputs(yields);
}
-/// Returns a list of AffineMap with the typical matmul indexing
-/// charactristic.
-SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
- MLIRContext *context = this->getContext();
- return getDefaultIndexingMapsForMatmul(context);
-}
-
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
@@ -3578,7 +3561,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
}
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
- indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
+ indexingMapsAttr = llvm::map_to_vector(
+ MatmulOp::getDefaultIndexingMaps(parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
}
result.addAttribute("indexing_maps",
parser.getBuilder().getArrayAttr(indexingMapsAttr));
@@ -3592,8 +3577,9 @@ void MatmulOp::print(OpAsmPrinter &p) {
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
- SmallVector<Attribute, 3> indexingMaps =
- getDefaultMatmulIndexingMapAttr(getContext());
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ MatmulOp::getDefaultIndexingMaps(getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
p << " indexing_maps = [";
llvm::interleaveComma(getIndexingMaps(), p,
More information about the Mlir-commits
mailing list