[Mlir-commits] [mlir] [NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related C++ code. (PR #116377)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 15 22:31:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Md Asghar Ahmad Shahid (shahidact)

<details>
<summary>Changes</summary>



This commit refactors part of the linalg.matmul code in preparation for migrating the other *matmul* variants from OpDSL to ODS.
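Moves the getDefaultIndexingMaps() helper into the MatmulOp class as a static method taking an MLIRContext. The MatmulOp builders now pass these default indexing maps to buildStructuredOp explicitly. buildStructuredOp therefore no longer falls back to adding default matmul indexing maps when none are given.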

---
Full diff: https://github.com/llvm/llvm-project/pull/116377.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+15-4) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+4-25) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e578f4b956ef5e..aed28c72ea5bde 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -622,7 +622,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
-          attributes, MatmulOp::getRegionBuilder());
+          attributes, MatmulOp::getRegionBuilder(),
+          MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -630,7 +631,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         buildStructuredOp($_builder, $_state, resultTensorTypes,
-          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
+          MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -648,7 +650,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       [{
         $_state.addAttribute("cast", cast);
         buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
-          attributes, MatmulOp::getRegionBuilder());
+          attributes, MatmulOp::getRegionBuilder(),
+          MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>
 
     ];
@@ -664,7 +667,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
                                 Block &block, ArrayRef<NamedAttribute> attrs);
 
       /// Returns a list of AffineMap with the typical matmul indexing charactristic.
-      SmallVector<AffineMap> getDefaultIndexingMaps();
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context){
+        AffineExpr d0, d1, d2;
+        SmallVector<AffineMap, 3> indexingMaps;
+        bindDims(context, d0, d1, d2);
+        indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+        indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+        indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+        return indexingMaps;
+      }
 
       /// Returns true if the given broadcast map \p bcastMap is valid for this op.
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c909d13e4314b4..259e060933b436 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -155,23 +155,10 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
-/// Helper to create a typical indexing map for MatmulOp. Returns a list of
-/// AffineMap.
-static SmallVector<AffineMap, 3>
-getDefaultIndexingMapsForMatmul(MLIRContext *context) {
-  AffineExpr d0, d1, d2;
-  SmallVector<AffineMap, 3> indexingMaps;
-  bindDims(context, d0, d1, d2);
-  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
-  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
-  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
-  return indexingMaps;
-}
-
 /// Wrapper to return the typical indexing map array attribute for MatmulOp.
 static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
   return llvm::map_to_vector(
-      getDefaultIndexingMapsForMatmul(context),
+      MatmulOp::getDefaultIndexingMaps(context),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 }
 
@@ -204,9 +191,6 @@ static void buildStructuredOp(
       indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
     }
     state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
-  } else {
-    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
-    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   }
 
   state.addAttributes(attributes);
@@ -3481,7 +3465,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                   unsigned opIndex) {
   SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
   SmallVector<AffineMap, 3> defaultIndexingMaps =
-      matmulOp.getDefaultIndexingMaps();
+      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
 
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3523,7 +3507,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
 /// Check if the op has broadcast and/or transpose semantic. Returns true if the
 /// user defined indexing maps are not equal to default map.
 bool MatmulOp::hasUserDefinedMaps() {
-  SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+  SmallVector<AffineMap, 3> defaultMaps =
+      MatmulOp::getDefaultIndexingMaps(this->getContext());
   SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
   return defaultMaps != explicitMaps;
 }
@@ -3557,12 +3542,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   helper.yieldOutputs(yields);
 }
 
-/// Returns a list of AffineMap with the typical matmul indexing charactristic.
-SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
-  MLIRContext *context = this->getContext();
-  return getDefaultIndexingMapsForMatmul(context);
-}
-
 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");

``````````

</details>


https://github.com/llvm/llvm-project/pull/116377


More information about the Mlir-commits mailing list