[Mlir-commits] [mlir] [NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related C++ code. (PR #116377)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 15 22:31:48 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Md Asghar Ahmad Shahid (shahidact)
<details>
<summary>Changes</summary>
This commit refactors part of the code in preparation for the migration of other *matmul* variants from OpDSL to ODS.
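Moves the getDefaultIndexingMaps() helper into the MatmulOp class as a static method taking an MLIRContext *, so the op builders can pass the default indexing maps explicitly to buildStructuredOp.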
---
Full diff: https://github.com/llvm/llvm-project/pull/116377.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+15-4)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+4-25)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e578f4b956ef5e..aed28c72ea5bde 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -622,7 +622,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
- attributes, MatmulOp::getRegionBuilder());
+ attributes, MatmulOp::getRegionBuilder(),
+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -630,7 +631,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
- inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+ inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -648,7 +650,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
[{
$_state.addAttribute("cast", cast);
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
- attributes, MatmulOp::getRegionBuilder());
+ attributes, MatmulOp::getRegionBuilder(),
+ MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>
];
@@ -664,7 +667,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
Block &block, ArrayRef<NamedAttribute> attrs);
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
- SmallVector<AffineMap> getDefaultIndexingMaps();
+ static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context){
+ AffineExpr d0, d1, d2;
+ SmallVector<AffineMap, 3> indexingMaps;
+ bindDims(context, d0, d1, d2);
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+ indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+ return indexingMaps;
+ }
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c909d13e4314b4..259e060933b436 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -155,23 +155,10 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
// iterator_types is an auto-generated method.
}
-/// Helper to create a typical indexing map for MatmulOp. Returns a list of
-/// AffineMap.
-static SmallVector<AffineMap, 3>
-getDefaultIndexingMapsForMatmul(MLIRContext *context) {
- AffineExpr d0, d1, d2;
- SmallVector<AffineMap, 3> indexingMaps;
- bindDims(context, d0, d1, d2);
- indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
- indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
- indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
- return indexingMaps;
-}
-
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
return llvm::map_to_vector(
- getDefaultIndexingMapsForMatmul(context),
+ MatmulOp::getDefaultIndexingMaps(context),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
}
@@ -204,9 +191,6 @@ static void buildStructuredOp(
indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
}
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
- } else {
- indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
- state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
}
state.addAttributes(attributes);
@@ -3481,7 +3465,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
unsigned opIndex) {
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
SmallVector<AffineMap, 3> defaultIndexingMaps =
- matmulOp.getDefaultIndexingMaps();
+ matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3523,7 +3507,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
/// user defined indexing maps are not equal to default map.
bool MatmulOp::hasUserDefinedMaps() {
- SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+ SmallVector<AffineMap, 3> defaultMaps =
+ MatmulOp::getDefaultIndexingMaps(this->getContext());
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
return defaultMaps != explicitMaps;
}
@@ -3557,12 +3542,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
helper.yieldOutputs(yields);
}
-/// Returns a list of AffineMap with the typical matmul indexing charactristic.
-SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
- MLIRContext *context = this->getContext();
- return getDefaultIndexingMapsForMatmul(context);
-}
-
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
``````````
</details>
https://github.com/llvm/llvm-project/pull/116377
More information about the Mlir-commits mailing list