[Mlir-commits] [mlir] [NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related C++ code. (PR #116377)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Fri Nov 15 04:04:10 PST 2024


https://github.com/shahidact created https://github.com/llvm/llvm-project/pull/116377



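This commit refactors part of the code in preparation for the migration of other *matmul* variants from OpDSL to ODS.
It moves the getDefaultIndexingMaps() helper into the MatmulOp class as a static method, and the MatmulOp builders now pass the default indexing maps to buildStructuredOp() explicitly instead of relying on its fallback.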

>From 55b2bf4b2d258d38f335b180edbe4afc4490d568 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Fri, 15 Nov 2024 03:49:27 -0800
Subject: [PATCH] [NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and
 related C++ code.

This commit refactors part of the code in preparation for the migration
of other *matmul* variants from OpDSL to ODS. Moves the
getDefaultIndexingMaps() helper into the MatmulOp class as a static method.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 19 +++++++++---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 29 +++----------------
 2 files changed, 19 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e578f4b956ef5e..aed28c72ea5bde 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -622,7 +622,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
-          attributes, MatmulOp::getRegionBuilder());
+          attributes, MatmulOp::getRegionBuilder(),
+          MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -630,7 +631,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         buildStructuredOp($_builder, $_state, resultTensorTypes,
-          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
+          MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -648,7 +650,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       [{
         $_state.addAttribute("cast", cast);
         buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
-          attributes, MatmulOp::getRegionBuilder());
+          attributes, MatmulOp::getRegionBuilder(),
+          MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
       }]>
 
     ];
@@ -664,7 +667,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
                                 Block &block, ArrayRef<NamedAttribute> attrs);
 
       /// Returns a list of AffineMap with the typical matmul indexing charactristic.
-      SmallVector<AffineMap> getDefaultIndexingMaps();
+      static SmallVector<AffineMap>
+      getDefaultIndexingMaps(MLIRContext *context) {
+        // Dims (d0, d1, d2) = (m, n, k): LHS (m, k), RHS (k, n), out (m, n).
+        AffineExpr d0, d1, d2;
+        bindDims(context, d0, d1, d2);
+        return {AffineMap::get(3, 0, {d0, d2}, context),
+                AffineMap::get(3, 0, {d2, d1}, context),
+                AffineMap::get(3, 0, {d0, d1}, context)};
+      }
 
       /// Returns true if the given broadcast map \p bcastMap is valid for this op.
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c909d13e4314b4..259e060933b436 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -155,23 +155,10 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
-/// Helper to create a typical indexing map for MatmulOp. Returns a list of
-/// AffineMap.
-static SmallVector<AffineMap, 3>
-getDefaultIndexingMapsForMatmul(MLIRContext *context) {
-  AffineExpr d0, d1, d2;
-  SmallVector<AffineMap, 3> indexingMaps;
-  bindDims(context, d0, d1, d2);
-  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
-  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
-  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
-  return indexingMaps;
-}
-
 /// Wrapper to return the typical indexing map array attribute for MatmulOp.
 static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
   return llvm::map_to_vector(
-      getDefaultIndexingMapsForMatmul(context),
+      MatmulOp::getDefaultIndexingMaps(context),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 }
 
@@ -204,9 +191,6 @@ static void buildStructuredOp(
       indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
     }
     state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
-  } else {
-    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
-    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   }
 
   state.addAttributes(attributes);
@@ -3481,7 +3465,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
                                                   unsigned opIndex) {
   SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
   SmallVector<AffineMap, 3> defaultIndexingMaps =
-      matmulOp.getDefaultIndexingMaps();
+      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
 
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3523,7 +3507,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
 /// Check if the op has broadcast and/or transpose semantic. Returns true if the
 /// user defined indexing maps are not equal to default map.
 bool MatmulOp::hasUserDefinedMaps() {
-  SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+  SmallVector<AffineMap, 3> defaultMaps =
+      MatmulOp::getDefaultIndexingMaps(this->getContext());
   SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
   return defaultMaps != explicitMaps;
 }
@@ -3557,12 +3542,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   helper.yieldOutputs(yields);
 }
 
-/// Returns a list of AffineMap with the typical matmul indexing charactristic.
-SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
-  MLIRContext *context = this->getContext();
-  return getDefaultIndexingMapsForMatmul(context);
-}
-
 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");



More information about the Mlir-commits mailing list