[Mlir-commits] [mlir] [mlir][Linalg] Fix non-matmul linalg structured ops (PR #116412)

Kunwar Grover llvmlistbot at llvm.org
Fri Nov 15 09:53:06 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/116412

From 93cc6ba1f8959be56ab11a678274ec13a137abb4 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 15 Nov 2024 17:26:05 +0000
Subject: [PATCH 1/2] [mlir][Linalg] Fix non-matmul linalg structured ops

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |   6 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 123 ++++++++++--------
 .../mlir-linalg-ods-yaml-gen.cpp              |   3 +-
 3 files changed, 71 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e578f4b956ef5e..a90777c82bf63a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -621,7 +621,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       (ins "ValueRange":$inputs, "ValueRange":$outputs,
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
-        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+        buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
           attributes, MatmulOp::getRegionBuilder());
       }]>,
       OpBuilder<
@@ -629,7 +629,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
             "ValueRange":$outputs,
             CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
-        buildStructuredOp($_builder, $_state, resultTensorTypes,
+        buildMatmulOp($_builder, $_state, resultTensorTypes,
           inputs, outputs, attributes, MatmulOp::getRegionBuilder());
       }]>,
       OpBuilder<
@@ -647,7 +647,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
        "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
       [{
         $_state.addAttribute("cast", cast);
-        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+        buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
           attributes, MatmulOp::getRegionBuilder());
       }]>
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c909d13e4314b4..dee8a4e27e6b26 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -169,7 +169,8 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
 }
 
 /// Wrapper to return the typical indexing map array attribute for MatmulOp.
-static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+static SmallVector<Attribute>
+getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
   return llvm::map_to_vector(
       getDefaultIndexingMapsForMatmul(context),
       [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
@@ -179,12 +180,11 @@ static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(
-    OpBuilder &b, OperationState &state,
-    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
-    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
-    RegionBuilderFn regionBuilder,
-    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
+static void buildStructuredOp(OpBuilder &b, OperationState &state,
+                              std::optional<TypeRange> resultTensorTypes,
+                              ValueRange inputs, ValueRange outputs,
+                              ArrayRef<NamedAttribute> attributes,
+                              RegionBuilderFn regionBuilder) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -196,6 +196,24 @@ static void buildStructuredOp(
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
 
+  state.addAttributes(attributes);
+  state.addAttribute(
+      "operandSegmentSizes",
+      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
+                              static_cast<int32_t>(outputs.size())}));
+
+  // Create and fill the region of the structured operation.
+  Region &region = *state.addRegion();
+  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
+                         state.attributes.getAttrs(), regionBuilder);
+}
+
+static void
+buildMatmulOp(OpBuilder &b, OperationState &state,
+              std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+              ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+              RegionBuilderFn regionBuilder,
+              std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Initialize indexingMaps, for MatmulOp.
   SmallVector<Attribute, 3> indexingMapsAttrVal;
   if (indexingMaps.has_value()) {
@@ -205,20 +223,11 @@ static void buildStructuredOp(
     }
     state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   } else {
-    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
     state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
   }
-
-  state.addAttributes(attributes);
-  state.addAttribute(
-      "operandSegmentSizes",
-      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
-                              static_cast<int32_t>(outputs.size())}));
-
-  // Create and fill the region of the structured operation.
-  Region &region = *state.addRegion();
-  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
-                         state.attributes.getAttrs(), regionBuilder);
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
 }
 
 /// Common parsing used for both named structured ops created by ods-gen and by
@@ -340,39 +349,6 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
-
-  SmallVector<Attribute, 3> indexingMapsAttr;
-  Attribute mapAttr;
-  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
-    if (parser.parseEqual())
-      return failure();
-
-    if (parser.parseLSquare())
-      return failure();
-
-    do {
-      if (parser.parseAttribute(mapAttr))
-        return failure();
-      if (!isa<AffineMapAttr>(mapAttr)) {
-        return parser.emitError(parser.getCurrentLocation(),
-                                "expected affine map attribute");
-      }
-      indexingMapsAttr.push_back(mapAttr);
-
-      if (parser.parseOptionalComma())
-        break;
-    } while (true);
-
-    if (parser.parseRSquare())
-      return failure();
-  }
-  // Initialize indexingMaps, if not supplied explicitly.
-  if (indexingMapsAttr.empty()) {
-    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
-  }
-  result.addAttribute("indexing_maps",
-                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
-
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
@@ -3503,9 +3479,11 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
 
 namespace mlir {
 namespace linalg {
+
 //===----------------------------------------------------------------------===//
 // MatMulOp
 //===----------------------------------------------------------------------===//
+
 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
   return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                           utils::IteratorType::parallel,
@@ -3520,8 +3498,8 @@ std::string MatmulOp::getLibraryCallName() {
 
 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
 
-/// Check if the op has broadcast and/or transpose semantic. Returns true if the
-/// user defined indexing maps are not equal to default map.
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
 bool MatmulOp::hasUserDefinedMaps() {
   SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
   SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
@@ -3557,7 +3535,8 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   helper.yieldOutputs(yields);
 }
 
-/// Returns a list of AffineMap with the typical matmul indexing charactristic.
+/// Returns a list of AffineMap with the typical matmul indexing
+/// characteristic.
 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
   MLIRContext *context = this->getContext();
   return getDefaultIndexingMapsForMatmul(context);
@@ -3572,6 +3551,38 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
 }
 
 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
@@ -3582,7 +3593,7 @@ void MatmulOp::print(OpAsmPrinter &p) {
                          elidedAttrs);
 
   SmallVector<Attribute, 3> indexingMaps =
-      getDefaultIndexingMapAttr(getContext());
+      getDefaultMatmulIndexingMapAttr(getContext());
   if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
     p << " indexing_maps = [";
     llvm::interleaveComma(getIndexingMaps(), p,
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 6be7d4320c6562..80d979864921d6 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -679,8 +679,7 @@ ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
 }
 void {0}::print(OpAsmPrinter &p) {{
   SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
-                                           "linalg.memoized_indexing_maps",
-                                           "indexing_maps"};
+                                           "linalg.memoized_indexing_maps"};
   ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                            elidedAttrs);
 }
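
For context (not part of the patch): with the builder/parser split above, a
non-matmul named op such as linalg.matvec should no longer pick up the
matmul-specific `indexing_maps` attribute when built or round-tripped. A
minimal sketch, with illustrative shapes borrowed from the rank-reduce test
in the second patch:

  func.func @matvec_roundtrip(%A: tensor<128x512xf32>, %x: tensor<512xf32>,
                              %y: tensor<128xf32>) -> tensor<128xf32> {
    // Expected to print back without any `indexing_maps` attribute.
    %0 = linalg.matvec ins(%A, %x : tensor<128x512xf32>, tensor<512xf32>)
                       outs(%y : tensor<128xf32>) -> tensor<128xf32>
    return %0 : tensor<128xf32>
  }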

From 442e3f5a864bae75c3b9bcb6cb0a1b7ff3e828fc Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 15 Nov 2024 17:51:02 +0000
Subject: [PATCH 2/2] Fix tests

---
 .../Dialect/Linalg/rank-reduce-contraction-ops.mlir  | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index c086d0fd7e6332..ebdbe70ff46eb7 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -43,7 +43,8 @@ func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1], [2]]
   //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec
+  //  CHECK-SAME:   ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<128x512xf32>, tensor<512xf32>) outs(%[[COLLAPSED_INIT]] : tensor<128xf32>)
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, 128]
   //  CHECK-NEXT:   return %[[RES]]
   %1 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<1x128x512xf32>, tensor<1x512xf32>)
@@ -62,7 +63,8 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat 
+  //  CHECK-SAME:   ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
   //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
   //  CHECK-NEXT:   return %[[RES]]
@@ -113,7 +115,8 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
   //  CHECK-DAG:    %[[C0:.*]] = arith.constant 0
   //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec 
+  //  CHECK-SAME:   ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
   //  CHECK-NEXT:   %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
   //  CHECK-NEXT:   return %[[RES]]
@@ -140,7 +143,8 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
   //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
-  //  CHECK-NEXT:   %[[RESULT:.*]] = linalg.vecmat ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
+  //  CHECK-NEXT:   %[[RESULT:.*]] = linalg.vecmat 
+  //  CHECK-SAME:   ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
   //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
   //  CHECK-NEXT:   return %[[RES]]
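
A companion sketch (also not part of the patch, shapes illustrative): the
matmul-only path kept in MatmulOp::parse still accepts explicit maps, and
MatmulOp::print elides the attribute only when it equals the defaults. For
example, a matmul with a transposed LHS:

  // Explicit maps are parsed by MatmulOp::parse; defaults are elided on print.
  %0 = linalg.matmul indexing_maps = [
         affine_map<(d0, d1, d2) -> (d2, d0)>,  // LHS transposed (K x M)
         affine_map<(d0, d1, d2) -> (d2, d1)>,
         affine_map<(d0, d1, d2) -> (d0, d1)>
       ]
       ins(%A, %B : tensor<512x128xf32>, tensor<512x256xf32>)
       outs(%C : tensor<128x256xf32>) -> tensor<128x256xf32>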


