[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Wed Mar 12 04:07:28 PDT 2025


https://github.com/shahidact created https://github.com/llvm/llvm-project/pull/130944

…_reduce_matmul.

This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. It is the last of the three matmul variants to gain these semantics, following 'matmul' and 'batch_matmul'.

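Broadcast and transpose semantics can be applied by specifying the explicit attribute 'indexing_maps', as shown below. In these maps, d0 is the batch dimension, d1 and d2 are the parallel M and N dimensions of the 2D output, and d3 is the reduction dimension K. This is a list attribute, so if it is specified it must include a map for every operand (both inputs and the output).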

    Example Transpose:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```
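The transpose example can also be read as a plain loop nest. The following is a minimal C++ sketch, not the Linalg implementation or its lowering: it assumes dense row-major buffers, and the function names `brgemmDefault` and `brgemmTransposeA` are hypothetical. Loop variables follow the default maps (d0 = batch, d1 = M, d2 = N, d3 = K); the transposed LHS map `(d0, d3, d1)` changes only how `A` is addressed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Buf = std::vector<float>;

// Default maps: A (d0, d1, d3) -> A[b][m][k], B (d0, d3, d2) -> B[b][k][n],
// C (d1, d2) -> C[m][n]. d0 and d3 are reductions, d1 and d2 are parallel.
// C is accumulated into, like the `outs` operand of the op.
void brgemmDefault(const Buf &A, const Buf &B, Buf &C, std::size_t batch,
                   std::size_t M, std::size_t N, std::size_t K) {
  for (std::size_t b = 0; b < batch; ++b)   // d0 (reduction)
    for (std::size_t m = 0; m < M; ++m)     // d1 (parallel)
      for (std::size_t n = 0; n < N; ++n)   // d2 (parallel)
        for (std::size_t k = 0; k < K; ++k) // d3 (reduction)
          C[m * N + n] += A[(b * M + m) * K + k] * B[(b * K + k) * N + n];
}

// Transposed LHS map (d0, d3, d1): A is stored as At[b][k][m],
// e.g. memref<2x5x3xf32> in the transpose example.
void brgemmTransposeA(const Buf &At, const Buf &B, Buf &C, std::size_t batch,
                      std::size_t M, std::size_t N, std::size_t K) {
  for (std::size_t b = 0; b < batch; ++b)
    for (std::size_t m = 0; m < M; ++m)
      for (std::size_t n = 0; n < N; ++n)
        for (std::size_t k = 0; k < K; ++k)
          C[m * N + n] += At[(b * K + k) * M + m] * B[(b * K + k) * N + n];
}
```

Both functions visit the products in the same order, so filling `At` with the physical transpose of `A` yields an identical `C`.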

    Example Broadcast:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```
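In the broadcast example, the LHS map `(d3)` indexes the 1-D `A` only by the reduction dimension K, so the same vector is reused for every batch and every output row. A minimal C++ sketch of that loop nest, assuming dense row-major buffers; the function name `brgemmBroadcastA` is hypothetical and this is not the Linalg lowering:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Broadcast LHS map (d3): A has shape [K] (memref<5xf32> in the example) and
// is indexed only by the reduction dim d3, so it is reused for every batch b
// (d0) and every row m (d1). B and C keep their default maps,
// (d0, d3, d2) -> B[b][k][n] and (d1, d2) -> C[m][n].
void brgemmBroadcastA(const std::vector<float> &A, const std::vector<float> &B,
                      std::vector<float> &C, std::size_t batch, std::size_t M,
                      std::size_t N, std::size_t K) {
  for (std::size_t b = 0; b < batch; ++b)   // d0 (reduction)
    for (std::size_t m = 0; m < M; ++m)     // d1 (parallel)
      for (std::size_t n = 0; n < N; ++n)   // d2 (parallel)
        for (std::size_t k = 0; k < K; ++k) // d3 (reduction)
          C[m * N + n] += A[k] * B[(b * K + k) * N + n];
}
```

Since `A` does not depend on `m`, every row of `C` receives the same contribution from the reduction.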

    Example Broadcast and Transpose:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```
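The combined example drops the batch dimension from `A` (map `(d1, d3)`, a 2-D `A` shared across batches) and stores `B` as `[batch][N][K]` (map `(d0, d2, d3)`). A minimal C++ sketch of the resulting loop nest, assuming dense row-major buffers; the function name `brgemmBroadcastATransposeB` is hypothetical and this is not the Linalg lowering:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// LHS map (d1, d3): A has shape [M][K] (memref<3x5xf32> in the example) with
// no batch dim, so it is broadcast across d0.
// RHS map (d0, d2, d3): B is stored as [batch][N][K] (memref<2x7x5xf32>),
// i.e. K and N are transposed relative to the default (d0, d3, d2).
// C keeps the default map (d1, d2) -> C[m][n].
void brgemmBroadcastATransposeB(const std::vector<float> &A,
                                const std::vector<float> &B,
                                std::vector<float> &C, std::size_t batch,
                                std::size_t M, std::size_t N, std::size_t K) {
  for (std::size_t b = 0; b < batch; ++b)   // d0 (reduction)
    for (std::size_t m = 0; m < M; ++m)     // d1 (parallel)
      for (std::size_t n = 0; n < N; ++n)   // d2 (parallel)
        for (std::size_t k = 0; k < K; ++k) // d3 (reduction)
          C[m * N + n] += A[m * K + k] * B[(b * N + n) * K + k];
}
```

Compared with the default layout, only the two address expressions change; the loop nest and its iterator types stay the same.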

RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275

From b5468fcf0c2149d006f62f536bb24f053c6778df Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 17 Feb 2025 03:05:48 -0800
Subject: [PATCH] [MLIR][Linalg] Introduce transpose/broadcast semantic to
 linalg.batch_reduce_matmul.

This patch exposes broadcast and transpose semantics on
'batch_reduce_matmul'. It is the last of the three matmul variants
to gain these semantics, following 'matmul' and 'batch_matmul'.

Broadcast and transpose semantics can be applied by specifying the
explicit attribute 'indexing_maps', as shown below. This is a list
attribute, so if it is specified it must include a map for every
operand (both inputs and the output).

    Example Transpose:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast and Transpose:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```

RFCs and related PRs:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  70 -----
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 131 +++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 262 ++++++++++++++++--
 .../Dialect/Linalg/generalize-named-ops.mlir  |  30 ++
 mlir/test/Dialect/Linalg/invalid.mlir         | 202 ++++++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 165 +++++++++++
 6 files changed, 762 insertions(+), 98 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_reduce_matmul
-  cpp_class_name: BatchReduceMatmulOp
-  doc: |-
-    Performs a batch-reduce matrix multiplication of two 3D inputs.
-    The partial multiplication results are reduced into a 2D output.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
-  iterator_types:
-  - reduction
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
   cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..5191a658bbf26 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1054,6 +1054,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 }
 
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+                          AttrSizedOperandSegments,
+                          LinalgContractionOpInterface]> {
+    
+  let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so it must include maps for all
+    arguments if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<3x7xf32>)
+    ```
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical batch_reduce_matmul indexing characteristic.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps() { return true; };
+      /// Returns true if the user defined indexing maps are not equal to default maps.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..d46fbf988d762 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -218,6 +218,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+                                     std::optional<TypeRange> resultTensorTypes,
+                                     ValueRange inputs, ValueRange outputs,
+                                     ArrayRef<NamedAttribute> attributes,
+                                     RegionBuilderFn regionBuilder,
+                                     ArrayRef<AffineMap> indexingMaps) {
+  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
+  SmallVector<Attribute, 4> indexingMapsAttrVal;
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3464,19 +3481,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
                                      AffineMap opIndexingMap,
                                      AffineMap defaultIndexingMap, bool isLHS) {
+  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+         "Expected BatchMatmulOp or BatchReduceMatmulOp");
   // Check the result dims are valid.
   if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Unexpected result dim expression (outside the set of default "
               "result dims).";
 
   // Check for valid number of result dims of input maps.
   if (opIndexingMap.getNumResults() > 3)
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "no. of result dim expressions exceeds 3.";
 
   auto hasValidBatchDim = [](AffineMap map) {
@@ -3486,60 +3508,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
 
   // Check if the requested broadcast is valid.
   if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
-    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
-      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchVariantMatmulOp->emitOpError()
+             << "Invalid broadcast requested.";
   } else if (!hasValidBatchDim(opIndexingMap)) {
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Invalid batch dimension expression.";
   }
   return success();
 }
 
 /// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
+/// dimensions and if the output map result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
                                      AffineMap opIndexingMap) {
-  if (opIndexingMap.getNumResults() != 3)
-    return batchMatmulOp->emitOpError()
+  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+         "Expected BatchMatmulOp or BatchReduceMatmulOp");
+  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+      opIndexingMap.getNumResults() != 3) {
+
+    return batchVariantMatmulOp->emitOpError()
            << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
            << ").";
+  }
+  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+      opIndexingMap.getNumResults() != 2) {
+    return batchVariantMatmulOp->emitOpError()
+           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+  }
 
-  auto areValidOutputResultDim = [](AffineMap outputMap) {
-    return outputMap.getResult(0).isFunctionOfDim(0) &&
-           outputMap.getResult(1).isFunctionOfDim(1) &&
-           outputMap.getResult(2).isFunctionOfDim(2);
+  auto areValidOutputResultDim = [&](AffineMap outputMap) {
+    return isa<BatchMatmulOp>(batchVariantMatmulOp)
+               ? outputMap.getResult(0).isFunctionOfDim(0) &&
+                     outputMap.getResult(1).isFunctionOfDim(1) &&
+                     outputMap.getResult(2).isFunctionOfDim(2)
+               : outputMap.getResult(0).isFunctionOfDim(1) &&
+                     outputMap.getResult(1).isFunctionOfDim(2);
   };
 
-  if (!areValidOutputResultDim(opIndexingMap))
-    return batchMatmulOp->emitOpError()
+  if (!areValidOutputResultDim(opIndexingMap)) {
+    return batchVariantMatmulOp->emitOpError()
            << "Invalid output map result dimension.";
+  }
 
   return success();
 }
 
 /// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
 static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
-                                  unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+                                         unsigned opIndex) {
   SmallVector<AffineMap, 3> opIndexingMaps =
-      batchMatmulOp.getIndexingMapsArray();
+      batchVariantMatmulOp.getIndexingMapsArray();
   SmallVector<AffineMap, 3> defaultIndexingMaps =
-      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+      batchVariantMatmulOp.getDefaultIndexingMaps(
+          batchVariantMatmulOp->getContext());
 
   if (opIndexingMaps.size() != 3)
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Indexing_map attribute must have 3 affine maps.";
 
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
 
-  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+  if (opIndex == 2 &&
+      failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
     return failure();
 
-  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
-                             opIndex == 0)))
+  if (opIndex != 2 &&
+      failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+                             defaultIndexingMap, opIndex == 0)))
     return failure();
 
   return success();
@@ -4035,7 +4080,7 @@ LogicalResult BatchMatmulOp::verify() {
     return success();
 
   for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
-    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
       return failure();
   }
   return success();
@@ -5340,6 +5385,167 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::reduction, utils::IteratorType::parallel,
+      utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr d0, d1, d2, d3;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2, d3);
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+  return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+                                                    bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  bool isValid = false;
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1) {
+    AffineExpr exp = bcastMap.getResult(0);
+    isValid = exp.isFunctionOfDim(kPos);
+  } else if (bcastMap.getNumResults() == 2) {
+    AffineExpr exp0 = bcastMap.getResult(0);
+    AffineExpr exp1 = bcastMap.getResult(1);
+    isValid = isLHS
+                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+  }
+  return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                        ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  auto toType = block.getArgument(2).getType();
+  Value castValA =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+  Value castValB =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value addVal =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  yields.push_back(addVal);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = llvm::map_to_vector(
+        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchReduceMatmulOp::getNumRegionArgs(),
+                                  BatchReduceMatmulOp::getRegionBuilder());
+}
+
+void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchReduceMatmulOp::verify() {
+  // Verification of pure batch_reduce_matmul is handled by
+  // verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
+    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
+      return failure();
+  }
+  return success();
+}
+LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
+                                        SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void BatchReduceMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 0ec71c35497b1..feb627e84beca 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1024,6 +1024,36 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul(
+// CHECK-SAME:                                   %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:                                   %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:                                   %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK:           %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK:             %[[VAL_7:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK:             %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_8]] : f32
+// CHECK:           } -> tensor<3x7xf32>
+// CHECK:           return %[[VAL_3]] : tensor<3x7xf32>
+// CHECK:         }
+
+func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
+  %0 = linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+    outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
+  return %0 : tensor<3x7xf32>
+}
+
+// -----
+
 // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f2283db8f89b2..978df2acc8a10 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1462,6 +1462,208 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 }
 
 
+// -----
+
+func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_reduce_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_reduce_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                      ]
+                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                      outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
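The expected-error strings in the invalid.mlir cases above pin down which maps are rejected. Read together, they suggest these rules: A and B may drop dimensions only to broadcast, in which case a 1-D map must keep k and a 2-D map must be (m, k) for A or (k, n) for B; a full 3-D map may permute dimensions as long as the batch dimension stays first; and C must be exactly (m, n). The Python model below is a rough sketch inferred only from these test cases, using the same hypothetical tuple encoding of maps; the real checks live in the op's C++ verifier and may be ordered or worded differently.

```python
# Loop dimension positions of batch_reduce_matmul.
BATCH, M, N, K = range(4)

# Default maps for A, B and C: (d0, d1, d3), (d0, d3, d2), (d1, d2).
DEFAULT_MAPS = ((BATCH, M, K), (BATCH, K, N), (M, N))

# Broadcast maps accepted by the tests, keyed by operand (0 = A, 1 = B).
# Assumption: inferred from the valid and invalid cases in this patch.
VALID_BROADCASTS = {
    0: {(K,), (M, K)},
    1: {(K,), (K, N)},
}


def check_maps(maps):
    """Return None if `maps` passes, else the matching expected-error text."""
    if len(maps) != 3:
        return "Indexing_map attribute must have 3 affine maps"
    # Assumption: inputs A and B are checked before the output map C.
    for op_index in (0, 1):
        op_map, default = maps[op_index], DEFAULT_MAPS[op_index]
        if len(op_map) > 3:
            return "no. of result dim expressions exceeds 3."
        if any(d not in default for d in op_map):
            return ("Unexpected result dim expression "
                    "(outside the set of default result dims)")
        if len(op_map) < 3:
            # Fewer results than the default map means a broadcast.
            if op_map not in VALID_BROADCASTS[op_index]:
                return "Invalid broadcast requested"
        elif op_map[0] != BATCH:
            # A full-rank (possibly transposed) map keeps the batch dim first.
            return "Invalid batch dimension expression"
    out_map = maps[2]
    if len(out_map) != 2:
        return f"expects 2 dims, but got ({len(out_map)})."
    if out_map != (M, N):
        return "Invalid output map result dimension."
    return None
```

Each returned string is a substring of the corresponding `expected-error` in the tests, so the model can be checked against them case by case.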
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1bd9c8825b05e..93ef4f0a4a956 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1637,6 +1637,171 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME:                                                                  %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME:                                                                  %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                                                  %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_dim_A(
+// CHECK-SAME:                                                     %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME:                                                     %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                                     %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME:                                                           %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                           %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME:                                                           %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_dim_B(
+// CHECK-SAME:                                                     %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                     %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5x7xf32>,
+// CHECK-SAME:                                                     %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_explicit_transpose_A(
+// CHECK-SAME:                                                        %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x3xf32>,
+// CHECK-SAME:                                                        %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                                        %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_explicit_transpose_B(
+// CHECK-SAME:                                                        %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                        %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME:                                                        %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_A_transpose_B(
+// CHECK-SAME:                                                       %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME:                                                       %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME:                                                       %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+func.func @batch_reduce_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @batchmatmul_transpose_a
 //       CHECK:   linalg.batch_matmul_transpose_a
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)


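The positive tests in named-ops.mlir use static shapes in which every loop dimension is fixed by at least one operand. In `@batch_reduce_matmul_bcast_batch_and_n_dim_B`, for instance, B is `memref<5xf32>` under `(d3)`, so n = 7 can only come from C and batch = 2 only from A. A small, hypothetical Python helper makes that bookkeeping explicit by reading each operand's static shape through its map and checking that the sizes agree. It is an illustration under the same tuple encoding of maps, not an MLIR API.

```python
# Loop dimension positions: d0 = batch, d1 = m, d2 = n, d3 = k.
# Each map is a tuple of dimension positions, e.g. (3,) for
# affine_map<(d0, d1, d2, d3) -> (d3)>.


def infer_loop_bounds(maps, shapes):
    """Return [batch, m, n, k] implied by static operand shapes.

    Raises ValueError if a map's rank differs from its operand's rank or if
    two operands disagree on the size of a dimension. A dimension that no
    operand mentions stays None.
    """
    bounds = [None] * 4
    for op_map, shape in zip(maps, shapes):
        if len(op_map) != len(shape):
            raise ValueError(f"map {op_map} does not match shape {shape}")
        for dim, size in zip(op_map, shape):
            if bounds[dim] is None:
                bounds[dim] = size
            elif bounds[dim] != size:
                raise ValueError(
                    f"d{dim} has size {size}, expected {bounds[dim]}")
    return bounds
```

For example, the maps and shapes of `@batch_reduce_matmul_explicit_transpose_A` (A `2x5x3` under `(0, 3, 1)`) give batch = 2, m = 3, n = 7, k = 5, the same bounds as the untransposed layout.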
