[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Sat May 3 06:51:31 PDT 2025


https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/130944

>From 139edbcf9b88ea93ed184dc7be34c167ee3b9e48 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 17 Feb 2025 03:05:48 -0800
Subject: [PATCH 1/7] [MLIR][Linalg] Introduce transpose/broadcast semantic to
 linalg.batch_reduce_matmul.

This patch exposes broadcast and transpose semantics on
'batch_reduce_matmul'. It is the last in a series that added the same
semantics to the other two matmul variants ('matmul' and 'batch_matmul').

The broadcast and transpose semantics are as follows:

Broadcast and transpose semantics can be applied by specifying the
explicit 'indexing_maps' attribute, as shown below. Because this is a
list attribute, it must include maps for all three operands (A, B and
C) whenever it is specified.
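As a rough mental model (a hedged sketch, not how Linalg actually lowers the op), the op iterates over (d0, d1, d2, d3) = (batch, m, n, k), with d0 and d3 as reduction dimensions and d1, d2 as parallel ones, and accumulates into C at (d1, d2). An input's indexing map only decides which element of A or B is read at a given loop point, so a transpose or broadcast changes the addressing, not the reduction. The hypothetical C++ reference below makes that explicit by passing each input's map as an access function over a row-major buffer; `batchReduceMatmulRef`, `LoopPoint` and `AccessFn` are illustrative names, not MLIR APIs.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical reference model of linalg.batch_reduce_matmul (not MLIR code).
// Loop dims: d0 = batch, d1 = m, d2 = n, d3 = k. d0 and d3 are reductions,
// d1 and d2 are parallel, matching the op's iterator types.
using LoopPoint = std::array<int, 4>;

// Stand-in for an affine indexing map: converts a loop point into a flat
// (row-major) offset into an operand buffer.
using AccessFn = std::function<std::size_t(const LoopPoint &)>;

// Accumulates C[d1][d2] += A[accessA(d)] * B[accessB(d)] over all loop points.
// The output access is fixed to (d1, d2), the only layout the verifier accepts.
void batchReduceMatmulRef(const std::vector<float> &a, const AccessFn &accessA,
                          const std::vector<float> &b, const AccessFn &accessB,
                          std::vector<float> &c, int batch, int m, int n,
                          int k) {
  for (int d0 = 0; d0 < batch; ++d0)
    for (int d1 = 0; d1 < m; ++d1)
      for (int d2 = 0; d2 < n; ++d2)
        for (int d3 = 0; d3 < k; ++d3) {
          const LoopPoint d{d0, d1, d2, d3};
          c[static_cast<std::size_t>(d1) * n + d2] +=
              a[accessA(d)] * b[accessB(d)];
        }
}
```

With the default maps, `accessA` would compute `(d0 * m + d1) * k + d3` and `accessB` would compute `(d0 * k + d3) * n + d2`. The transpose example corresponds to `accessA = (d0 * k + d3) * m + d1` on a `[batch][k][m]` buffer, and the broadcast example to `accessA = d3` on a 1-D buffer of length k.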

    Example Transpose:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast and Transpose:
    ```
    linalg.batch_reduce_matmul indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
       ]
          ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
          outs(%arg2: memref<3x7xf32>)
    ```
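The broadcast maps in these examples are not arbitrary. In this patch, `BatchReduceMatmulOp::isValidLhsRhsBroadcastMap` accepts a 1-D input map only if its result is `d3` (the k dimension), and a 2-D map only if it is `(d1, d3)` for the LHS or `(d3, d2)` for the RHS; before that, every input map must use only dimensions that appear in the corresponding default map. The C++ sketch below models those two checks with plain vectors of dimension positions. The vector encoding and `isAcceptedInputMap` are hypothetical: the real code inspects `AffineExpr`s with `isFunctionOfDim`, and it also applies a batch-dimension check to non-broadcast (3-D) maps, which this sketch omits.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical model (not the MLIR API): an indexing map is represented only
// by the loop-dimension positions of its results, e.g. (d0, d1, d3) -> {0, 1, 3}.
using DimList = std::vector<unsigned>;

// Same positions as the `Indices` enum used in isValidLhsRhsBroadcastMap.
enum LoopDim : unsigned { batchPos = 0, mPos = 1, nPos = 2, kPos = 3 };

// Broadcast rule: a 1-result map must be (k); a 2-result map must be (m, k)
// for the LHS or (k, n) for the RHS.
bool isValidBroadcast(const DimList &map, bool isLHS) {
  if (map.size() == 1)
    return map[0] == kPos;
  if (map.size() == 2)
    return isLHS ? (map[0] == mPos && map[1] == kPos)
                 : (map[0] == kPos && map[1] == nPos);
  return false;
}

// True if every result dim of `map` also appears in `defaultMap` (the analogue
// of the "outside the set of default result dims" check).
bool isSubsetOfDefault(const DimList &map, const DimList &defaultMap) {
  return std::all_of(map.begin(), map.end(), [&](unsigned d) {
    return std::find(defaultMap.begin(), defaultMap.end(), d) !=
           defaultMap.end();
  });
}

// Simplified input-map check: subset test first, then the broadcast rule for
// maps with fewer results than the 3-D default. 3-D maps (transposes) pass;
// the batch-dimension check of the real verifier is not modeled.
bool isAcceptedInputMap(const DimList &map, bool isLHS) {
  const DimList defaultMap = isLHS ? DimList{batchPos, mPos, kPos}
                                   : DimList{batchPos, kPos, nPos};
  if (map.size() > 3 || !isSubsetOfDefault(map, defaultMap))
    return false;
  if (map.size() < 3)
    return isValidBroadcast(map, isLHS);
  return true;
}
```

For instance, the LHS map `(d0, d3)` from the invalid tests passes the subset check but fails the broadcast rule, while `(d1, d3)` from the broadcast-and-transpose example is accepted.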

RFCs and related PR:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319
https://github.com/llvm/llvm-project/pull/122275
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  70 -----
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 131 +++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 261 ++++++++++++++++--
 .../Dialect/Linalg/generalize-named-ops.mlir  |  30 ++
 mlir/test/Dialect/Linalg/invalid.mlir         | 202 ++++++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 165 +++++++++++
 6 files changed, 761 insertions(+), 98 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b44af2defc3e4..6344861c53ac5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1717,76 +1717,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_reduce_matmul
-  cpp_class_name: BatchReduceMatmulOp
-  doc: |-
-    Performs a batch-reduce matrix multiplication of two 3D inputs.
-    The partial multiplication results are reduced into a 2D output.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
-  iterator_types:
-  - reduction
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
   cpp_class_name: MatvecOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3dbeb274deda..41e37fbba6afa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1065,6 +1065,137 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 }
 
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
+                          AttrSizedOperandSegments,
+                          LinalgContractionOpInterface]> {
+
+  let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+The partial multiplication results are reduced into a 2D output.}];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so it must include maps for
+    all arguments if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<3x7xf32>)
+    ```
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<
+        AffineMapArrayAttr,
+        "BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext())"
+      >:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildBatchReduceMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, BatchReduceMatmulOp::getRegionBuilder(),
+          BatchReduceMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMaps with the typical batch_reduce_matmul indexing characteristics.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps() { return true; };
+      /// Returns true if the user-defined indexing maps are not equal to the default maps.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 089ccc6680e48..404db532454c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -220,6 +220,23 @@ static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
+                                     std::optional<TypeRange> resultTensorTypes,
+                                     ValueRange inputs, ValueRange outputs,
+                                     ArrayRef<NamedAttribute> attributes,
+                                     RegionBuilderFn regionBuilder,
+                                     ArrayRef<AffineMap> indexingMaps) {
+  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
+  SmallVector<Attribute, 4> indexingMapsAttrVal;
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3485,19 +3502,24 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
-// Check general validity of input indexing map.
-static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+// Check general validity of input indexing map of
+// BatchMatmulOp/BatchReduceMatmulOp.
+template <typename OpTy>
+static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
                                      AffineMap opIndexingMap,
                                      AffineMap defaultIndexingMap, bool isLHS) {
+  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+         "Expected BatchMatmulOp or BatchReduceMatmulOp");
   // Check the result dims are valid.
   if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Unexpected result dim expression (outside the set of default "
               "result dims).";
 
   // Check for valid number of result dims of input maps.
   if (opIndexingMap.getNumResults() > 3)
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "no. of result dim expressions exceeds 3.";
 
   auto hasValidBatchDim = [](AffineMap map) {
@@ -3507,60 +3529,83 @@ static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
 
   // Check if the requested broadcast is valid.
   if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
-    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
-      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchVariantMatmulOp->emitOpError()
+             << "Invalid broadcast requested.";
   } else if (!hasValidBatchDim(opIndexingMap)) {
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Invalid batch dimension expression.";
   }
   return success();
 }
 
 /// This function checks if the given AffineMap for the output of a
-/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
-/// dimensions are valid.
-static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+/// BatchMatmulOp/BatchReduceMatmulOp has exactly the expected number of result
+/// dimensions (3 or 2, respectively) and if those result dimensions are valid.
+template <typename OpTy>
+static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
                                      AffineMap opIndexingMap) {
-  if (opIndexingMap.getNumResults() != 3)
-    return batchMatmulOp->emitOpError()
+  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
+          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
+         "Expected BatchMatmulOp or BatchReduceMatmulOp");
+  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
+      opIndexingMap.getNumResults() != 3) {
+
+    return batchVariantMatmulOp->emitOpError()
            << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
            << ").";
+  }
+  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
+      opIndexingMap.getNumResults() != 2) {
+    return batchVariantMatmulOp->emitOpError()
+           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+  }
 
-  auto areValidOutputResultDim = [](AffineMap outputMap) {
-    return outputMap.getResult(0).isFunctionOfDim(0) &&
-           outputMap.getResult(1).isFunctionOfDim(1) &&
-           outputMap.getResult(2).isFunctionOfDim(2);
+  auto areValidOutputResultDim = [&](AffineMap outputMap) {
+    return isa<BatchMatmulOp>(batchVariantMatmulOp)
+               ? outputMap.getResult(0).isFunctionOfDim(0) &&
+                     outputMap.getResult(1).isFunctionOfDim(1) &&
+                     outputMap.getResult(2).isFunctionOfDim(2)
+               : outputMap.getResult(0).isFunctionOfDim(1) &&
+                     outputMap.getResult(1).isFunctionOfDim(2);
   };
 
-  if (!areValidOutputResultDim(opIndexingMap))
-    return batchMatmulOp->emitOpError()
+  if (!areValidOutputResultDim(opIndexingMap)) {
+    return batchVariantMatmulOp->emitOpError()
            << "Invalid output map result dimension.";
+  }
 
   return success();
 }
 
 /// Verifies the broadcast and transpose semantic specified by the explicit
-/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
+/// specified by opIndex.
+template <typename OpTy>
 static LogicalResult
-verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
-                                  unsigned opIndex) {
+verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
+                                         unsigned opIndex) {
   SmallVector<AffineMap, 3> opIndexingMaps =
-      batchMatmulOp.getIndexingMapsArray();
+      batchVariantMatmulOp.getIndexingMapsArray();
   SmallVector<AffineMap, 3> defaultIndexingMaps =
-      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+      batchVariantMatmulOp.getDefaultIndexingMaps(
+          batchVariantMatmulOp->getContext());
 
   if (opIndexingMaps.size() != 3)
-    return batchMatmulOp->emitOpError()
+    return batchVariantMatmulOp->emitOpError()
            << "Indexing_map attribute must have 3 affine maps.";
 
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
 
-  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+  if (opIndex == 2 &&
+      failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
     return failure();
 
-  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
-                             opIndex == 0)))
+  if (opIndex != 2 &&
+      failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
+                             defaultIndexingMap, opIndex == 0)))
     return failure();
 
   return success();
@@ -4045,7 +4090,7 @@ LogicalResult BatchMatmulOp::verify() {
     return success();
 
   for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
-    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
       return failure();
   }
   return success();
@@ -5366,6 +5411,166 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// BatchReduceMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::reduction, utils::IteratorType::parallel,
+      utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+SmallVector<AffineMap>
+BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr d0, d1, d2, d3;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2, d3);
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
+  return indexingMaps;
+}
+
+unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchReduceMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantics. Returns true if
+/// the user-defined indexing maps are not equal to the default maps.
+bool BatchReduceMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
+                                                    bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected fewer than 3 result dim exprs.");
+  bool isValid = false;
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1) {
+    AffineExpr exp = bcastMap.getResult(0);
+    isValid = exp.isFunctionOfDim(kPos);
+  } else if (bcastMap.getNumResults() == 2) {
+    AffineExpr exp0 = bcastMap.getResult(0);
+    AffineExpr exp1 = bcastMap.getResult(1);
+    isValid = isLHS
+                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+  }
+  return isValid;
+}
+
+void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                        ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  auto toType = block.getArgument(2).getType();
+  Value castValA =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+  Value castValB =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value addVal =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  yields.push_back(addVal);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = llvm::map_to_vector(
+        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchReduceMatmulOp::getNumRegionArgs(),
+                                  BatchReduceMatmulOp::getRegionBuilder());
+}
+
+void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchReduceMatmulOp::verify() {
+  // Verification of pure batch_reduce_matmul is handled by
+  // verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
+    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
+      return failure();
+  }
+  return success();
+}
+LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
+                                        SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void BatchReduceMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 0ec71c35497b1..feb627e84beca 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1024,6 +1024,36 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul(
+// CHECK-SAME:                                   %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:                                   %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:                                   %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK:           %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK:             %[[VAL_7:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
+// CHECK:             %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_8]] : f32
+// CHECK:           } -> tensor<3x7xf32>
+// CHECK:           return %[[VAL_3]] : tensor<3x7xf32>
+// CHECK:         }
+
+func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
+  %0 = linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+    outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
+  return %0 : tensor<3x7xf32>
+}
+
+// -----
+
 // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 90ceadebbc1fa..ed51e5784a713 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1484,6 +1484,208 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 }
 
 
+// -----
+
+func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_reduce_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_reduce_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                      ]
+                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                      outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
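The invalid.mlir cases in this patch pin down which `indexing_maps` should be rejected, and with which diagnostic. What follows is a hypothetical Python checker reconstructed only from those expected diagnostics; it is not the C++ verifier.

- Maps are written as tuples of loop-dim names, with `(d0, d1, d2, d3)` read as `(batch, m, n, k)`.
- `check_maps`, `ALLOWED_DIMS` and `VALID_BCAST` are illustrative names.
- Two rules are assumptions that happen to agree with the valid and invalid cases shown: a two-result broadcast map must be `(m, k)` for A and `(k, n)` for B, and the checks run in operand order A, B, output.
- The parser error for an empty map slot (`expected attribute value`) is not modeled.

```python
# Loop dims (d0, d1, d2, d3) of batch_reduce_matmul, named for readability.
BATCH, M, N, K = "batch", "m", "n", "k"

# Dims each input may use: the results of its default map.
ALLOWED_DIMS = {"A": {BATCH, M, K}, "B": {BATCH, K, N}}

# Broadcast maps (fewer than 3 results) accepted by the tests in this patch,
# e.g. (d3) and (d1, d3) for A, (d3) and (d3, d2) for B. Assumed exhaustive.
VALID_BCAST = {"A": {(K,), (M, K)}, "B": {(K,), (K, N)}}


def check_maps(maps):
    """Return the first expected diagnostic for `maps`, or None if accepted.

    `maps` is a list of tuples of loop-dim names, one tuple per operand.
    Operands are checked in order A, B, then the output (an assumption).
    """
    if len(maps) != 3:
        return "Indexing_map attribute must have 3 affine maps"
    for name, dims in zip(("A", "B"), maps[:2]):
        if len(dims) > 3:
            return "no. of result dim expressions exceeds 3."
        if any(d not in ALLOWED_DIMS[name] for d in dims):
            return ("Unexpected result dim expression "
                    "(outside the set of default result dims)")
        if len(dims) == 3 and dims[0] != BATCH:
            return "Invalid batch dimension expression"
        if len(dims) < 3 and tuple(dims) not in VALID_BCAST[name]:
            return "Invalid broadcast requested"
    out = maps[2]
    if len(out) != 2:
        return f"expects 2 dims, but got ({len(out)})."
    if tuple(out) != (M, N):
        return "Invalid output map result dimension."
    return None


# Like invalid_bcast_batch_reduce_matmul_a: A map (d0) broadcasts only batch.
print(check_maps([(BATCH,), (BATCH, K, N), (M, N)]))
# Like batch_reduce_matmul_bcast_k_to_fill_missing_dims_A: accepted.
print(check_maps([(K,), (BATCH, K, N), (M, N)]))
```

Reading each operand map against a fixed set of allowed dims keeps the sketch table-driven, so a case from the test file can be checked by writing its maps as a single tuple list.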
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1bd9c8825b05e..93ef4f0a4a956 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1637,6 +1637,171 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME:                                                                  %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME:                                                                  %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                                                  %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_dim_A(
+// CHECK-SAME:                                                     %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME:                                                     %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                                     %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME:                                                           %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                           %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
+// CHECK-SAME:                                                           %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_dim_B(
+// CHECK-SAME:                                                     %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                     %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5x7xf32>,
+// CHECK-SAME:                                                     %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_explicit_transpose_A(
+// CHECK-SAME:                                                        %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x3xf32>,
+// CHECK-SAME:                                                        %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                                        %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_explicit_transpose_B(
+// CHECK-SAME:                                                        %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                        %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME:                                                        %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_A_transpose_B(
+// CHECK-SAME:                                                       %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
+// CHECK-SAME:                                                       %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
+// CHECK-SAME:                                                       %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK:           return
+// CHECK:         }
+func.func @batch_reduce_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.batch_reduce_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @batchmatmul_transpose_a
 //       CHECK:   linalg.batch_matmul_transpose_a
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)

>From 5e6e39008858cbf3eab5a6cb1f7a45816809ca9c Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 2 Apr 2025 06:19:29 -0700
Subject: [PATCH 2/7] 1. Improved indentation in op definition description and
 test cases.

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 130 ++++++------
 .../Dialect/Linalg/generalize-named-ops.mlir  |  26 ++-
 mlir/test/Dialect/Linalg/invalid.mlir         | 191 +++++++++---------
 3 files changed, 165 insertions(+), 182 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 41e37fbba6afa..8ba19a42b990f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -689,35 +689,33 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     the maps if specified.
 
     Example Transpose:
-    ```mlir
-    linalg.matmul indexing_maps = [
-                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
-                   affine_map<(d0, d1, d2) -> (d2, d1)>,
-                   affine_map<(d0, d1, d2) -> (d0, d1)>
-                   ]
-                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
-                   outs(%arg2: memref<3x7xf32>)
+    ```
+    linalg.matmul
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                         affine_map<(d0, d1, d2) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>]
+        ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
      ```
 
     Example Broadcast:
-    ```mlir
-    linalg.matmul indexing_maps = [
-                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
-                   affine_map<(d0, d1, d2) -> (d2, d1)>,
-                   affine_map<(d0, d1, d2) -> (d0, d1)>
-                  ]
-                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
-                  outs(%arg2: memref<3x7xf32>)
+     ```
+    linalg.matmul
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                         affine_map<(d0, d1, d2) -> (d2, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>]
+        ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
 
     Example Broadcast and transpose:
-    ```mlir
-    linalg.matmul indexing_maps = [
-                      affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
-                      affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
-                      affine_map<(d0, d1, d2) -> (d0, d1)>
-                    ]
-                    ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    ```
+    linalg.matmul
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                         affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                         affine_map<(d0, d1, d2) -> (d0, d1)>]
+        ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
     }];
 
@@ -953,36 +951,33 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     arguments if specified.
 
     Example Transpose:
-    ```mlir
-    linalg.batch_matmul indexing_maps = [
-                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
-                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                   ]
-                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
-                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+    linalg.batch_matmul
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+        ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+        outs(%arg2: memref<2x3x7xf32>)
     ```
 
     Example Broadcast:
-    ```mlir
-    linalg.batch_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
-                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+    linalg.batch_matmul
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+        ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<2x3x7xf32>)
     ```
 
     Example Broadcast and Transpose:
-    ```mlir
-    linalg.batch_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
-                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
-                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+    linalg.batch_matmul
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                         affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+        ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+        outs(%arg2: memref<2x3x7xf32>)
     ```
 }];
 
@@ -1074,7 +1069,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
                           LinalgContractionOpInterface]> {
     
   let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
-The partial multiplication results are reduced into a 2D output.}];
+    The partial multiplication results are reduced into a 2D output.}];
   let description = [{
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -1085,35 +1080,32 @@ The partial multiplication results are reduced into a 2D output.}];
 
     Example Transpose:
     ```
-    linalg.batch_reduce_matmul indexing_maps = [
-                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
-                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                   affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                   ]
-                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
-                   outs(%arg2: memref<3x7xf32>)
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                         affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+        ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
 
     Example Broadcast:
     ```
-    linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
-                     outs(%arg2: memref<3x7xf32>)
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                         affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+        ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
 
     Example Broadcast and Transpose:
     ```
-    linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
-                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
-                     outs(%arg2: memref<3x7xf32>)
+    linalg.batch_reduce_matmul
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                         affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                         affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+        ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+        outs(%arg2: memref<3x7xf32>)
     ```
     }];
 
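The op summary describes a batch-reduce matrix multiplication whose partial products are reduced into a 2D output. What follows is a minimal Python reference model of what the `indexing_maps` in these examples compute: `C[m, n] += A[...] * B[...]`, summed over both `batch` and `k`. It assumes plain nested lists and reads `(d0, d1, d2, d3)` as `(batch, m, n, k)`. The helpers `load` and `batch_reduce_matmul` are hypothetical and do not correspond to any MLIR API.

```python
from itertools import product

# Loop dims in affine-map order: (d0, d1, d2, d3) == (batch, m, n, k).
LOOP_DIMS = ("batch", "m", "n", "k")


def load(buf, indices):
    """Walk a nested list with a sequence of indices."""
    for i in indices:
        buf = buf[i]
    return buf


def batch_reduce_matmul(a, b, c, map_a, map_b, map_c, sizes):
    """Hypothetical reference model: c[map_c] += a[map_a] * b[map_b].

    Each map lists, in order, the loop dims that index its operand, like the
    results of affine_map<(batch, m, n, k) -> (...)>. batch and k are
    reduction dims because the output map (m, n) omits them. An input whose
    map omits a loop dim is broadcast along that dim: the same element is
    reused for every value of it.
    """
    assert tuple(map_c) == ("m", "n"), "this sketch only models the (m, n) output"
    for point in product(*(range(sizes[d]) for d in LOOP_DIMS)):
        env = dict(zip(LOOP_DIMS, point))
        lhs = load(a, [env[d] for d in map_a])
        rhs = load(b, [env[d] for d in map_b])
        c[env["m"]][env["n"]] += lhs * rhs
    return c


# Broadcast A along batch and m, mirroring affine_map<(d0, d1, d2, d3) -> (d3)>.
sizes = {"batch": 2, "m": 2, "n": 2, "k": 3}
a_vec = [1, 2, 3]                      # shape [k]
b = [[[1, 0], [0, 1], [1, 1]],         # shape [batch][k][n]
     [[2, 1], [1, 2], [0, 0]]]
c = [[0, 0], [0, 0]]                   # shape [m][n]
batch_reduce_matmul(a_vec, b, c, ("k",), ("batch", "k", "n"), ("m", "n"), sizes)
print(c)  # [[8, 10], [8, 10]]: rows match because A does not depend on m
```

Visiting every `(batch, m, n, k)` point and reading each operand through its map mirrors the shape of the generalized `linalg.generic`, whose iterator types are `["reduction", "parallel", "parallel", "reduction"]`. A transposed operand is just a different order in its map. For example, `("batch", "k", "m")` reads A stored as `[batch][k][m]`.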
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index feb627e84beca..e15dfde02b7cc 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1024,22 +1024,20 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
 
 // -----
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
 // CHECK-LABEL:   func.func @batch_reduce_matmul(
-// CHECK-SAME:                                   %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x3x5xf32>,
-// CHECK-SAME:                                   %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<2x5x7xf32>,
-// CHECK-SAME:                                   %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
-// CHECK:           %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<3x7xf32>) {
-// CHECK:           ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
-// CHECK:             %[[VAL_7:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
-// CHECK:             %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
-// CHECK:             linalg.yield %[[VAL_8]] : f32
-// CHECK:           } -> tensor<3x7xf32>
-// CHECK:           return %[[VAL_3]] : tensor<3x7xf32>
-// CHECK:         }
+// CHECK-SAME:        %[[ARG_A:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:        %[[ARG_B:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:        %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK:           linalg.generic
+// CHECK-SAME:          indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
+// CHECK-SAME:          iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+// CHECK:           linalg.yield
 
 func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
   %0 = linalg.batch_reduce_matmul indexing_maps = [
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index ed51e5784a713..a29bbb4c559b5 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1486,16 +1486,15 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 
 // -----
 
-func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
      %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
      // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
-     linalg.batch_reduce_matmul indexing_maps = [
-      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-    ]
-    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-    outs(%arg2: memref<?x?xf32>)
-    return
+     linalg.batch_reduce_matmul
+         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                          affine_map<(batch, m, n, k) -> (batch, n, k)>]
+         ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+         outs(%arg2: memref<?x?xf32>)
+     return
 }
 
 // -----
@@ -1503,12 +1502,11 @@ func.func @indexing_map_size_mismatch_batch_reduce_matmul(%arg0: memref<?x?x?xf3
 func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
      %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
      // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
-     linalg.batch_reduce_matmul indexing_maps = [
-      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-    ]
-    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-    outs(%arg2: memref<?x?xf32>)
-    return
+     linalg.batch_reduce_matmul
+         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>]
+         ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+         outs(%arg2: memref<?x?xf32>)
+     return
 
 }
 
@@ -1518,11 +1516,10 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %a
   // expected-error @+1 {{expected attribute value}}
   linalg.batch_reduce_matmul indexing_maps = [
                        ,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                      ]
-                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-                      outs(%arg2 :memref<?x?xf32>)
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
   return
 }
 
@@ -1530,12 +1527,12 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %a
 
 func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
   return
 }
 
@@ -1543,12 +1540,12 @@ func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg
 
 func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, m)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
   return
 }
 
@@ -1556,12 +1553,12 @@ func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg
 
 func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
   return
 }
 
@@ -1569,12 +1566,12 @@ func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memr
 
 func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
   return
 }
 
@@ -1582,12 +1579,12 @@ func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?x
 
 func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (k, batch)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
   return
 }
 
@@ -1595,12 +1592,12 @@ func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x
 
 func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (n)>,
+                       affine_map<(batch, m, n, k) -> (batch, m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>)
+      outs(%arg2: memref<?x?xf32>)
   return
 }
 
@@ -1608,12 +1605,12 @@ func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1:
 
 func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (m, batch, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
   return
 }
 
@@ -1621,12 +1618,12 @@ func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %ar
 
 func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                       affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
-                       affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                     ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?xf32>)
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (n, k, batch)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
   return
 }
 
@@ -1634,56 +1631,52 @@ func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %ar
 
 func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
-                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                           ]
-    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-    outs(%arg2: memref<?x?xf32>)
-    return
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
+  return
 }
 
 // -----
 
 func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
-                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
-                           ]
-    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-    outs(%arg2: memref<?x?xf32>)
-    return
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n, k)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
+  return
 }
 
 // -----
 
 func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                            affine_map<(d0, d1, d2, d3) -> (d1)>
-                           ]
-    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-    outs(%arg2: memref<?x?xf32>)
-    return
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m)>]
+      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
+  return
 }
 
 // -----
 
 func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
-  linalg.batch_reduce_matmul indexing_maps = [
-                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
-                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                            affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-                           ]
-    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-    outs(%arg2: memref<?x?xf32>)
-    return
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, k)>]
+      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
+  return
 }
 
 // -----

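For reference, the named dimensions used in the rewritten tests above (`batch, m, n, k`) are only parse-time identifiers. MLIR prints an affine map's dimensions positionally as `d0, d1, ...`, which is why CHECK lines elsewhere in this series still match forms such as `affine_map<(d0, d1, d2, d3) -> (d1, d2)>`. The sketch below is a hypothetical plain-Python helper, not part of the MLIR Python bindings. It renders a named-dimension map in that positional form, assuming the dimension order `(batch, m, n, k)`.

```python
# Hypothetical helper (not an MLIR API): render an indexing map written with
# named dimensions in the positional d0, d1, ... form that MLIR prints.
def to_positional(dim_names, results):
    """Return the printed form of affine_map<(dim_names) -> (results)>.

    Raises KeyError if a result names a dimension not in dim_names.
    """
    pos = {name: f"d{i}" for i, name in enumerate(dim_names)}
    dims = ", ".join(pos[name] for name in dim_names)
    res = ", ".join(pos[name] for name in results)
    return f"affine_map<({dims}) -> ({res})>"


DIMS = ("batch", "m", "n", "k")

# Default batch_reduce_matmul maps, spelled with named dimensions.
default_maps = [("batch", "m", "k"), ("batch", "k", "n"), ("m", "n")]
for results in default_maps:
    print(to_positional(DIMS, results))
```

The three printed lines correspond to the default maps that the original tests spelled directly as `(d0, d1, d3)`, `(d0, d3, d2)` and `(d1, d2)`.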
>From c08c962a989397262f4e19b6496d29c2a0a4130b Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Sat, 12 Apr 2025 22:06:26 -0700
Subject: [PATCH 3/7]  -Added python test for linalg.batch_reduce_matmul

---
 mlir/python/mlir/dialects/linalg/__init__.py |  13 +++
 mlir/test/python/dialects/linalg/ops.py      | 101 +++++++++++++++++++
 2 files changed, 114 insertions(+)

diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 63586a5bb8bbb..9a8a7b40e25e4 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -203,6 +203,19 @@ def batch_matmul(
     )
 
 
+def batch_reduce_matmul(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    cast: Optional[Union[TypeFn, Attribute]] = None,
+):
+    return _get_op_result_or_op_results(
+        _create_matmul_like_op(
+            BatchReduceMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
+        )
+    )
+
+
 def contract(
     *ins: Union[Operation, OpView, Value],
     outs: Sequence[Union[Operation, OpView, Value]],
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..79bb576ff6738 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
     print(module)
 
 
+# CHECK-LABEL: TEST: testBatchReduceMatmulOp
+@run
+def testBatchReduceMatmulOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_shape = (5, 4, 8)
+            b_shape = (5, 8, 12)
+            b_transposed_shape = (5, 12, 8)
+            c_shape = (4, 12)
+
+            dimBatch = ir.AffineDimExpr.get(0)
+            dimM = ir.AffineDimExpr.get(1)
+            dimN = ir.AffineDimExpr.get(2)
+            dimK = ir.AffineDimExpr.get(3)
+
+            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
+            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
+            c_map = ir.AffineMap.get(4, 0, [dimM, dimN])
+
+            # CHECK: func.func @batch_reduce_matmul_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[A:.*]]: tensor<5x4x8xf32>,
+                RankedTensorType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[Amem:.*]]: memref<5x4x8xf32>,
+                MemRefType.get(a_shape, f32),
+                # CHECK-SAME:                         %[[B:.*]]: tensor<5x8x12xf32>,
+                RankedTensorType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[Bmem:.*]]: memref<5x8x12xf32>,
+                MemRefType.get(b_shape, f32),
+                # CHECK-SAME:                         %[[BTrans:.*]]: tensor<5x12x8xf32>,
+                RankedTensorType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[BTransmem:.*]]: memref<5x12x8xf32>,
+                MemRefType.get(b_transposed_shape, f32),
+                # CHECK-SAME:                         %[[C:.*]]: tensor<4x12xf32>,
+                RankedTensorType.get(c_shape, f32),
+                # CHECK-SAME:                         %[[Cmem:.*]]: memref<4x12xf32>)
+                MemRefType.get(c_shape, f32),
+            )
+            def batch_reduce_matmul_op(
+                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
+            ):
+                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, B),
+                    outputs=(C,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.batch_reduce_matmul(A, B, outs=(C,))
+
+                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=(C.type,),
+                    inputs=(A, Btransposed),
+                    outputs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
+                res = linalg.batch_reduce_matmul(
+                    A,
+                    Btransposed,
+                    outs=(C,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+                # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Bmem),
+                    outputs=(Cmem,),
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.batch_reduce_matmul(Amem, Bmem, outs=(Cmem,))
+
+                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                res = linalg.BatchReduceMatmulOp(
+                    result_tensors=[],
+                    inputs=(Amem, Btransposedmem),
+                    outputs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+                linalg.fill_builtin_region(res.operation)
+                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
+                linalg.batch_reduce_matmul(
+                    Amem,
+                    Btransposedmem,
+                    outs=(Cmem,),
+                    indexing_maps=[a_map, b_transposed_map, c_map],
+                )
+
+        print(module)
+
+
 # CHECK-LABEL: TEST: testPackUnPackOp
 @run
 def testPackUnPackOp():

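To illustrate what the transposed indexing maps in this test mean, here is a minimal pure-Python model of the op's semantics. It assumes the iteration order `(batch, m, n, k)` used throughout this series, with C accumulated into and both batch and k reduced (matching the `["reduction", "parallel", "parallel", "reduction"]` iterator types checked in generalize-named-ops.mlir). Each map is written as a hypothetical tuple of loop positions rather than an `AffineMap`. The function and helper names are illustrative only; this is not the MLIR implementation.

```python
from itertools import product


def batch_reduce_matmul(A, B, C, a_map, b_map, c_map, sizes):
    """Accumulate A*B into C, reducing over both the batch and k loops.

    Each map is a tuple of loop positions (0=batch, 1=m, 2=n, 3=k) that
    selects which loop indices address the operand. `sizes` gives the
    extent of each loop. Operands are nested Python lists.
    """

    def load(t, idx):
        # Walk the nested lists one index at a time.
        for i in idx:
            t = t[i]
        return t

    for iv in product(*(range(s) for s in sizes)):
        a = load(A, [iv[p] for p in a_map])
        b = load(B, [iv[p] for p in b_map])
        i, j = (iv[p] for p in c_map)
        C[i][j] += a * b
    return C


BATCH, M, N, K = 2, 2, 3, 4
sizes = (BATCH, M, N, K)


def zeros():
    return [[0] * N for _ in range(M)]


A = [[[b + m + k for k in range(K)] for m in range(M)] for b in range(BATCH)]
B = [[[b * n - k for n in range(N)] for k in range(K)] for b in range(BATCH)]
# Transposed copy of B laid out as batch x N x K.
Bt = [[[B[b][k][n] for k in range(K)] for n in range(N)] for b in range(BATCH)]

# Default maps vs. transposed-B maps (cf. b_transposed_map above).
C_default = batch_reduce_matmul(A, B, zeros(), (0, 1, 3), (0, 3, 2), (1, 2), sizes)
C_transposed = batch_reduce_matmul(A, Bt, zeros(), (0, 1, 3), (0, 2, 3), (1, 2), sizes)
print(C_default == C_transposed)  # True

# Broadcast: a 1-D LHS addressed only by k, like the `(k)` doc example.
a_vec = [k + 1 for k in range(K)]
A_full = [[a_vec[:] for _ in range(M)] for _ in range(BATCH)]
C_bcast = batch_reduce_matmul(a_vec, B, zeros(), (3,), (0, 3, 2), (1, 2), sizes)
C_full = batch_reduce_matmul(A_full, B, zeros(), (0, 1, 3), (0, 3, 2), (1, 2), sizes)
print(C_bcast == C_full)  # True
```

Switching B's map from `(batch, k, n)` to `(batch, n, k)` while passing a transposed buffer leaves the result unchanged. That is the property the `b_transposed_map` cases above exercise.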
>From 9b018a7dffa0c01fb3bb4ece1d7c561fe6ff7c57 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Mon, 28 Apr 2025 10:29:16 -0700
Subject: [PATCH 4/7] -Fixed invalid broadcast test. -Added consistent variable
 and function naming in test cases. -Improved ops indexing_maps description.

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  58 ++++-----
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  18 ++-
 .../Dialect/Linalg/generalize-named-ops.mlir  |  12 +-
 mlir/test/Dialect/Linalg/invalid.mlir         | 112 +++++++++---------
 mlir/test/Dialect/Linalg/named-ops.mlir       | 102 ++++++++--------
 5 files changed, 158 insertions(+), 144 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8ba19a42b990f..2a9ee0516275e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -691,9 +691,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     Example Transpose:
     ```
     linalg.matmul
-        indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
-                         affine_map<(d0, d1, d2) -> (d2, d1)>,
-                         affine_map<(d0, d1, d2) -> (d0, d1)>]
+        indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>]
         ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
         outs(%arg2: memref<3x7xf32>)
      ```
@@ -701,9 +701,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     Example Broadcast:
      ```
     linalg.matmul
-        indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
-                         affine_map<(d0, d1, d2) -> (d2, d1)>,
-                         affine_map<(d0, d1, d2) -> (d0, d1)>]
+        indexing_maps = [affine_map<(m, n, k) -> (k)>,     // broadcast
+                         affine_map<(m, n, k) -> (k, n)>,
+                         affine_map<(m, n, k) -> (m, n)>]
         ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
         outs(%arg2: memref<3x7xf32>)
     ```
@@ -711,9 +711,9 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     Example Broadcast and transpose:
     ```
     linalg.matmul
-        indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
-                         affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
-                         affine_map<(d0, d1, d2) -> (d0, d1)>]
+        indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
+                         affine_map<(m, n, k) -> (k)>,    // broadcast
+                         affine_map<(m, n, k) -> (m, n)>]
         ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>)
         outs(%arg2: memref<3x7xf32>)
     ```
@@ -773,7 +773,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       static void regionBuilder(ImplicitLocOpBuilder &b,
                                 Block &block, ArrayRef<NamedAttribute> attrs);
 
-      /// Returns a list of AffineMap with the typical matmul indexing charactristic.
+      /// Returns a list of AffineMap with the default matmul indexing characteristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
 
       /// Returns true if the given broadcast map \p bcastMap is valid for this op.
@@ -953,9 +953,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     Example Transpose:
     ```
     linalg.batch_matmul
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
-                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
         ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
         outs(%arg2: memref<2x3x7xf32>)
     ```
@@ -963,9 +963,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     Example Broadcast:
     ```
     linalg.batch_matmul
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
-                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+        indexing_maps = [affine_map<(batch, m, n, k) -> (k)>,           // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
         ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
         outs(%arg2: memref<2x3x7xf32>)
     ```
@@ -973,9 +973,9 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     Example Broadcast and Transpose:
     ```
     linalg.batch_matmul
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
-                         affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
-                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+        indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>,        // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
         ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
         outs(%arg2: memref<2x3x7xf32>)
     ```
@@ -1081,9 +1081,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
     Example Transpose:
     ```
     linalg.batch_reduce_matmul
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
-                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                         affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (m, n)>]
         ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
         outs(%arg2: memref<3x7xf32>)
     ```
@@ -1091,9 +1091,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
     Example Broadcast:
     ```
     linalg.batch_reduce_matmul
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
-                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
-                         affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+        indexing_maps = [affine_map<(batch, m, n, k) -> (k)>,         // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (m, n)>]
         ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
         outs(%arg2: memref<3x7xf32>)
     ```
@@ -1101,9 +1101,9 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
     Example Broadcast and Transpose:
     ```
     linalg.batch_reduce_matmul
-        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
-                         affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
-                         affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
+        indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>,        // broadcast
+                         affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
+                         affine_map<(batch, m, n, k) -> (m, n)>]
         ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
         outs(%arg2: memref<3x7xf32>)
     ```
@@ -1163,7 +1163,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
       static void regionBuilder(ImplicitLocOpBuilder &b,
                                 Block &block, ArrayRef<NamedAttribute> attrs);
 
-      /// Returns a list of AffineMap with the typical batch_reducematmul indexing charactristic.
+      /// Returns a list of AffineMap with the default batch_reduce_matmul indexing characteristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
 
       /// Returns true if the given broadcast map \p bcastMap is valid for this op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 404db532454c7..9dd4575240f6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3996,9 +3996,12 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
   } else if (bcastMap.getNumResults() == 2) {
     AffineExpr exp0 = bcastMap.getResult(0);
     AffineExpr exp1 = bcastMap.getResult(1);
-    isValid = isLHS
-                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
-                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+    isValid =
+        isLHS
+            ? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
+               exp1.isFunctionOfDim(kPos))
+            : ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
+               (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
   }
   return isValid;
 }
@@ -5459,9 +5462,12 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
   } else if (bcastMap.getNumResults() == 2) {
     AffineExpr exp0 = bcastMap.getResult(0);
     AffineExpr exp1 = bcastMap.getResult(1);
-    isValid = isLHS
-                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
-                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+    isValid =
+        isLHS
+            ? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
+               exp1.isFunctionOfDim(kPos))
+            : ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
+               (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
   }
   return isValid;
 }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e15dfde02b7cc..ae07b1b82228c 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1029,9 +1029,9 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
 // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
 // CHECK-LABEL:   func.func @batch_reduce_matmul(
-// CHECK-SAME:        %[[ARG_A:.*]]: tensor<2x3x5xf32>,
-// CHECK-SAME:        %[[ARG_B:.*]]: tensor<2x5x7xf32>,
-// CHECK-SAME:        %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK-SAME:        %[[A:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:        %[[B:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:        %[[C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
 // CHECK:           linalg.generic
 // CHECK-SAME:          indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
 // CHECK-SAME:          iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
@@ -1039,14 +1039,14 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
 // CHECK:           arith.addf
 // CHECK:           linalg.yield
 
-func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
+func.func @batch_reduce_matmul(%A: tensor<2x3x5xf32>, %B: tensor<2x5x7xf32>, %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
   %0 = linalg.batch_reduce_matmul indexing_maps = [
                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                             affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                             affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                            ]
-    ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
-    outs(%arg2: tensor<3x7xf32>) -> tensor<3x7xf32>
+    ins(%A, %B: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+    outs(%C: tensor<3x7xf32>) -> tensor<3x7xf32>
   return %0 : tensor<3x7xf32>
 }
 
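The generalize test above says that `linalg.batch_reduce_matmul` becomes a `linalg.generic` with `iterator_types = ["reduction", "parallel", "parallel", "reduction"]`. The loop nest below is one way to read those CHECK lines. It is a hypothetical Python reference model, not code MLIR generates. It assumes row-major nested lists and uses the op's default maps `(d0, d1, d3)`, `(d0, d3, d2)` and `(d1, d2)`, with d0 = batch, d1 = m, d2 = n, d3 = k. Both reduction iterators, d0 and d3, accumulate into the same 2D output element.

```python
def batch_reduce_matmul_generic(A, B, C):
    """Hypothetical loop-nest model of the generalized op.

    A: [batch][m][k], B: [batch][k][n], C: [m][n] (accumulated in place).
    """
    batch, m, k = len(A), len(A[0]), len(A[0][0])
    n = len(B[0][0])
    for d0 in range(batch):          # "reduction": batch is summed away
        for d1 in range(m):          # "parallel"
            for d2 in range(n):      # "parallel"
                for d3 in range(k):  # "reduction": contraction over k
                    # Access maps: A -> (d0, d1, d3), B -> (d0, d3, d2),
                    # C -> (d1, d2); the body is a multiply-accumulate.
                    C[d1][d2] += A[d0][d1][d3] * B[d0][d3][d2]
    return C
```

Because C is only indexed by the parallel dims, every (batch, k) pair contributes to the same `C[m][n]`. This is what distinguishes `batch_reduce_matmul` from `batch_matmul`, which keeps a batch dimension in its output.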
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a29bbb4c559b5..04c59777d9d7a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1364,10 +1364,10 @@ func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x
 
 // -----
 
-func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+func.func @invalid_single_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
   linalg.batch_matmul indexing_maps = [
-                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                      ]
@@ -1377,14 +1377,14 @@ func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %
 
 // -----
 
-func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+func.func @invalid_single_dim_bcast_expr_batch_matmul_B(%A: memref<?x?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?x?xf32>) {
   // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
   linalg.batch_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+                     ins(%A, %B : memref<?x?x?xf32>, memref<?x?xf32>) outs(%C: memref<?x?x?xf32>)
   return
 }
 
@@ -1486,7 +1486,11 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 
 // -----
 
-func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+//===----------------------------------------------------------------------===//
+// linalg.batch_reduce_matmul
+//===----------------------------------------------------------------------===//
+
+func.func @missing_one_indexing_map(%arg0: memref<?x?x?xf32>,
      %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
      // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
      linalg.batch_reduce_matmul
@@ -1499,7 +1503,7 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
 
 // -----
 
-func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+func.func @missing_two_indexing_map(%arg0: memref<?x?x?xf32>,
      %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
      // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
      linalg.batch_reduce_matmul
@@ -1512,7 +1516,7 @@ func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
 
 // -----
 
-func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+func.func @missing_indexing_map(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
   // expected-error @+1 {{expected attribute value}}
   linalg.batch_reduce_matmul indexing_maps = [
                        ,
@@ -1525,157 +1529,157 @@ func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %a
 
 // -----
 
-func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+func.func @invalid_dim_expr_A(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
   // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2 :memref<?x?xf32>)
+      ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C :memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+func.func @invalid_dim_expr_B(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
   // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, m)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2 :memref<?x?xf32>)
+      ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C :memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_bcast_A(%A: memref<?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid broadcast requested}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B : memref<?xf32>, memref<?x?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_multi_dim_bcast_expr_A(%A: memref<?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid broadcast requested}}
   linalg.batch_reduce_matmul
-      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,
+      indexing_maps = [affine_map<(batch, m, n, k) -> (k, batch)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B : memref<?x?xf32>, memref<?x?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_multi_dim_bcast_expr_B(%A: memref<?x?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid broadcast requested}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (k, batch)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B : memref<?x?x?xf32>, memref<?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_bcast_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+func.func @invalid_bcast_B(%A: memref<?x?x?xf32>, %B: memref<?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid broadcast requested}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (n)>,
                        affine_map<(batch, m, n, k) -> (batch, m, n)>]
-      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B : memref<?x?x?xf32>, memref<?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_batch_dim_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+func.func @invalid_batch_dim_A(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid batch dimension expression}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (m, batch, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2 :memref<?x?xf32>)
+      ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C :memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_batch_dim_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid batch dimension expression}}
+func.func @invalid_batch_dim_B(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid batch dimension expression}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (n, k, batch)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2 :memref<?x?xf32>)
+      ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C :memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_A_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+func.func @invalid_A_map_result_num(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{no. of result dim expressions exceeds 3.}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_B_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op no. of result dim expressions exceeds 3.}}
+func.func @invalid_B_map_result_num(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{no. of result dim expressions exceeds 3.}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n, k)>,
                        affine_map<(batch, m, n, k) -> (m, n)>]
-      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_C_map_result_num_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op expects 2 dims, but got (1).}}
+func.func @invalid_C_map_result_num(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{expects 2 dims, but got (1).}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m)>]
-      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_C_map_result_dim_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid output map result dimension.}}
+func.func @invalid_C_map_result_dim(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Invalid output map result dimension.}}
   linalg.batch_reduce_matmul
       indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
                        affine_map<(batch, m, n, k) -> (batch, k, n)>,
                        affine_map<(batch, m, n, k) -> (m, k)>]
-      ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
-      outs(%arg2: memref<?x?xf32>)
+      ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%C: memref<?x?xf32>)
   return
 }
 
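The renamed `batch_reduce_matmul` cases in invalid.mlir encode a small set of indexing-map rules. The sketch below is an assumed approximation, written only to agree with the diagnostics those tests expect. It is not the C++ verifier in LinalgOps.cpp. The helper names (`check_input_map`, `check_output_map`, `verify`), the A-then-B-then-C check order, and the exact broadcast rule are hypothetical. Maps are modelled as tuples of dim names over (batch, m, n, k).

```python
# Hypothetical approximation of the map checks exercised by invalid.mlir;
# NOT the verifier in LinalgOps.cpp.
DEFAULT_A = ("batch", "m", "k")
DEFAULT_B = ("batch", "k", "n")
DEFAULT_C = ("m", "n")


def _is_ordered_subset(results, default):
    """True if `results` occur in `default` in the same relative order."""
    it = iter(default)
    return all(r in it for r in results)  # `in` consumes the iterator


def check_input_map(results, default):
    if len(results) > 3:
        return "no. of result dim expressions exceeds 3."
    if any(r not in default for r in results):
        return ("Unexpected result dim expression "
                "(outside the set of default result dims)")
    if len(results) == 3:
        # Transpose: dims may be permuted, but batch must stay first.
        if results[0] != "batch":
            return "Invalid batch dimension expression"
        return None
    # Broadcast (fewer than 3 results). Assumed rule: k is kept and the
    # remaining dims keep their default relative order.
    if "k" not in results or not _is_ordered_subset(results, default):
        return "Invalid broadcast requested"
    return None


def check_output_map(results):
    if len(results) != 2:
        return f"expects 2 dims, but got ({len(results)})."
    if tuple(results) != DEFAULT_C:
        return "Invalid output map result dimension."
    return None


def verify(a_map, b_map, c_map):
    """Return the first diagnostic found (assumed order: A, B, then C)."""
    for results, default in ((a_map, DEFAULT_A), (b_map, DEFAULT_B)):
        err = check_input_map(results, default)
        if err:
            return err
    return check_output_map(c_map)
```

For example, `verify(("k", "batch"), DEFAULT_B, DEFAULT_C)` reports "Invalid broadcast requested", as `@invalid_multi_dim_bcast_expr_A` expects. By contrast, `verify(("m", "k"), ("batch", "n", "k"), DEFAULT_C)` is accepted, which matches the positive `@bcast_A_transpose_B` case in named-ops.mlir.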
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 93ef4f0a4a956..470bc1c78640c 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1637,25 +1637,29 @@ func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memre
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// linalg.batch_reduce_matmul
+//===----------------------------------------------------------------------===//
+
 // CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(
-// CHECK-SAME:                                                                  %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
-// CHECK-SAME:                                                                  %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
-// CHECK-SAME:                                                                  %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME:      %[[A:.*]]: memref<5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
-func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_k_to_fill_missing_dims_A(%A: memref<5xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<5xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1665,21 +1669,21 @@ func.func @batch_reduce_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf3
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_dim_A(
-// CHECK-SAME:                                                     %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
-// CHECK-SAME:                                                     %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
-// CHECK-SAME:                                                     %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @bcast_batch_dim_A(
+// CHECK-SAME:      %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
-func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_batch_dim_A(%A: memref<3x5xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<3x5xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1689,21 +1693,21 @@ func.func @batch_reduce_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1:
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(
-// CHECK-SAME:                                                           %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
-// CHECK-SAME:                                                           %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5xf32>,
-// CHECK-SAME:                                                           %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @bcast_batch_and_n_dim_B(
+// CHECK-SAME:      %[[A:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
-func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_batch_and_n_dim_B(%A: memref<2x3x5xf32>, %B: memref<5xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<2x3x5xf32>, memref<5xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1713,21 +1717,21 @@ func.func @batch_reduce_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>,
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_batch_dim_B(
-// CHECK-SAME:                                                     %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
-// CHECK-SAME:                                                     %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<5x7xf32>,
-// CHECK-SAME:                                                     %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @bcast_batch_dim_B(
+// CHECK-SAME:      %[[A:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
 
-func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_batch_dim_B(%A: memref<2x3x5xf32>, %B: memref<5x7xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<2x3x5xf32>, memref<5x7xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1737,20 +1741,20 @@ func.func @batch_reduce_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_explicit_transpose_A(
-// CHECK-SAME:                                                        %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x3xf32>,
-// CHECK-SAME:                                                        %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x5x7xf32>,
-// CHECK-SAME:                                                        %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @explicit_transpose_A(
+// CHECK-SAME:      %[[A:.*]]: memref<2x5x3xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
-func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<3x7xf32>) {
+func.func @explicit_transpose_A(%A: memref<2x5x3xf32>, %B: memref<2x5x7xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1760,20 +1764,20 @@ func.func @batch_reduce_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %a
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_explicit_transpose_B(
-// CHECK-SAME:                                                        %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x3x5xf32>,
-// CHECK-SAME:                                                        %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
-// CHECK-SAME:                                                        %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @explicit_transpose_B(
+// CHECK-SAME:      %[[A:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
-func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+func.func @explicit_transpose_B(%A: memref<2x3x5xf32>, %B: memref<2x7x5xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 
@@ -1783,20 +1787,20 @@ func.func @batch_reduce_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %a
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-LABEL:   func.func @batch_reduce_matmul_bcast_A_transpose_B(
-// CHECK-SAME:                                                       %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x5xf32>,
-// CHECK-SAME:                                                       %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<2x7x5xf32>,
-// CHECK-SAME:                                                       %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<3x7xf32>) {
-// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
+// CHECK-LABEL:   func.func @bcast_A_transpose_B(
+// CHECK-SAME:      %[[A:.*]]: memref<3x5xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.batch_reduce_matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[A]], %[[B]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[C]] : memref<3x7xf32>)
 // CHECK:           return
 // CHECK:         }
-func.func @batch_reduce_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<3x7xf32>) {
+func.func @bcast_A_transpose_B(%A: memref<3x5xf32>, %B: memref<2x7x5xf32>, %C: memref<3x7xf32>) {
   linalg.batch_reduce_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>)
+                     ins(%A, %B : memref<3x5xf32>, memref<2x7x5xf32>) outs(%C: memref<3x7xf32>)
   return
 }
 

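The positive tests in named-ops.mlir all follow one rule. Each operand is read through its own indexing map, and `C[m, n]` accumulates `A * B` over both batch and k. The Python sketch below models that rule under stated assumptions: operands are nested lists, maps are tuples of positions into (d0, d1, d2, d3) = (batch, m, n, k), and the output map is fixed at (d1, d2) as in every test here. The function names are hypothetical. The model can check that a transpose or broadcast map gives the same result as first materializing the transposed or broadcast operand.

```python
from itertools import product


def load(buf, indices):
    """Index a nested-list buffer with a sequence of indices."""
    for i in indices:
        buf = buf[i]
    return buf


def batch_reduce_matmul(A, B, C, map_a, map_b, sizes):
    """Hypothetical reference semantics for linalg.batch_reduce_matmul.

    sizes = (batch, m, n, k) is passed explicitly because a broadcast
    operand no longer carries every loop bound in its shape. map_a / map_b
    give, for each operand dimension, the loop index (0=batch, 1=m, 2=n,
    3=k) that addresses it. The output map is assumed to be (d1, d2).
    """
    batch, m, n, k = sizes
    for d in product(range(batch), range(m), range(n), range(k)):
        a = load(A, [d[p] for p in map_a])
        b = load(B, [d[p] for p in map_b])
        C[d[1]][d[2]] += a * b  # reduce over both batch (d0) and k (d3)
    return C
```

With `map_a = (0, 3, 1)` the function reads `A[b][k][m]`. This is what `@explicit_transpose_A` expresses with `affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>`. With `map_a = (3,)`, a rank-1 A is reused for every batch and m index, as in `@bcast_k_to_fill_missing_dims_A`.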
From 938a6f7d3cf4aef62c562767ec64a2da6a473016 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Tue, 29 Apr 2025 23:50:27 -0700
Subject: [PATCH 5/7] -Formats example IR for the linalg.*matmul ops for syntax
 highlighting.

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td   | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2a9ee0516275e..1aefd7e7f428d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -689,7 +689,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     the maps if specified.
 
     Example Transpose:
-    ```
+    ```mlir
     linalg.matmul
         indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
                          affine_map<(m, n, k) -> (k, n)>,
@@ -699,7 +699,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
      ```
 
     Example Broadcast:
-     ```
+     ```mlir
     linalg.matmul
         indexing_maps = [affine_map<(m, n, k) -> (k)>,     // broadcast
                          affine_map<(m, n, k) -> (k, n)>,
@@ -709,7 +709,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     ```
 
     Example Broadcast and transpose:
-    ```
+    ```mlir
     linalg.matmul
         indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose
                          affine_map<(m, n, k) -> (k)>,    // broadcast
@@ -951,7 +951,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     arguments if specified.
 
     Example Transpose:
-    ```
+    ```mlir
     linalg.batch_matmul
         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
                          affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -961,7 +961,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     ```
 
     Example Broadcast:
-    ```
+    ```mlir
     linalg.batch_matmul
         indexing_maps = [affine_map<(batch, m, n, k) -> (k)>,           // broadcast
                          affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -971,7 +971,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
     ```
 
     Example Broadcast and Transpose:
-    ```
+    ```mlir
     linalg.batch_matmul
         indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>,        // broadcast
                          affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose
@@ -1079,7 +1079,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
     arguments if specified.
 
     Example Transpose:
-    ```
+    ```mlir
     linalg.batch_reduce_matmul
         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>, // transpose
                          affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -1089,7 +1089,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
     ```
 
     Example Broadcast:
-    ```
+    ```mlir
     linalg.batch_reduce_matmul
         indexing_maps = [affine_map<(batch, m, n, k) -> (k)>,         // broadcast
                          affine_map<(batch, m, n, k) -> (batch, k, n)>,
@@ -1099,7 +1099,7 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
     ```
 
     Example Broadcast and Transpose:
-    ```
+    ```mlir
     linalg.batch_reduce_matmul
         indexing_maps = [affine_map<(batch, m, n, k) -> (m, k)>,        // broadcast
                          affine_map<(batch, m, n, k) -> (batch, n, k)>, // transpose

>From 44427267c01a1361a73ddff5cfd960db6002d593 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Thu, 1 May 2025 10:02:31 -0700
Subject: [PATCH 6/7] -Updates formatting, comments and variable names.

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  8 +--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 51 ++++++++++---------
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1aefd7e7f428d..61783812920bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -1068,13 +1068,13 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
                           AttrSizedOperandSegments,
                           LinalgContractionOpInterface]> {
     
-  let summary = [{Performs a batch-reduce matrix multiplication of two 3D inputs.
+  let summary = [{Performs a batch-reduce matrix multiplication on two inputs.
     The partial multiplication results are reduced into a 2D output.}];
   let description = [{
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
 
-    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
     'indexing_maps' as shown below. This is a list attribute, so must include maps for all
     arguments if specified.
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9dd4575240f6e..cd764e4582d08 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3681,12 +3681,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   helper.yieldOutputs(yields);
 }
 
-/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+/// Returns true if the given bcastMap map is a valid broadcast map. A valid
+/// broadcast map must include K dimension.
 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
-  AffineExpr exp = bcastMap.getResult(0);
+  AffineExpr expr = bcastMap.getResult(0);
   // Invalid map if the common dimension of matmul not found.
-  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
+  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
 }
 
 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
@@ -3984,24 +3985,26 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
   return defaultMaps != explicitMaps;
 }
 
-/// Returns true if the given broadcast map bcastMap is valid for this op.
+/// Returns true if the given bcastMap map is a valid broadcast map. A valid
+/// broadcast map must include K dimension.
 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
   assert(bcastMap.getNumResults() < 3 &&
          "Expected less than 3 result dim expr.");
   bool isValid = false;
   enum Indices { batchPos, mPos, nPos, kPos };
   if (bcastMap.getNumResults() == 1) {
-    AffineExpr exp = bcastMap.getResult(0);
-    isValid = exp.isFunctionOfDim(kPos);
+    AffineExpr expr = bcastMap.getResult(0);
+    isValid = expr.isFunctionOfDim(kPos);
   } else if (bcastMap.getNumResults() == 2) {
-    AffineExpr exp0 = bcastMap.getResult(0);
-    AffineExpr exp1 = bcastMap.getResult(1);
+    AffineExpr expr0 = bcastMap.getResult(0);
+    AffineExpr expr1 = bcastMap.getResult(1);
     isValid =
-        isLHS
-            ? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
-               exp1.isFunctionOfDim(kPos))
-            : ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
-               (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
+        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
+                  expr0.isFunctionOfDim(mPos)) &&
+                 expr1.isFunctionOfDim(kPos))
+              : ((expr0.isFunctionOfDim(batchPos) &&
+                  expr1.isFunctionOfDim(kPos)) ||
+                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
   }
   return isValid;
 }
@@ -5449,7 +5452,8 @@ bool BatchReduceMatmulOp::hasUserDefinedMaps() {
   return defaultMaps != explicitMaps;
 }
 
-/// Returns true if the given broadcast map bcastMap is valid for this op.
+/// Returns true if the given bcastMap map is a valid broadcast map. A valid
+/// broadcast map must include K dimension.
 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                     bool isLHS) {
   assert(bcastMap.getNumResults() < 3 &&
@@ -5457,17 +5461,18 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
   bool isValid = false;
   enum Indices { batchPos, mPos, nPos, kPos };
   if (bcastMap.getNumResults() == 1) {
-    AffineExpr exp = bcastMap.getResult(0);
-    isValid = exp.isFunctionOfDim(kPos);
+    AffineExpr expr = bcastMap.getResult(0);
+    isValid = expr.isFunctionOfDim(kPos);
   } else if (bcastMap.getNumResults() == 2) {
-    AffineExpr exp0 = bcastMap.getResult(0);
-    AffineExpr exp1 = bcastMap.getResult(1);
+    AffineExpr expr0 = bcastMap.getResult(0);
+    AffineExpr expr1 = bcastMap.getResult(1);
     isValid =
-        isLHS
-            ? ((exp0.isFunctionOfDim(batchPos) || exp0.isFunctionOfDim(mPos)) &&
-               exp1.isFunctionOfDim(kPos))
-            : ((exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(kPos)) ||
-               (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos)));
+        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
+                  expr0.isFunctionOfDim(mPos)) &&
+                 expr1.isFunctionOfDim(kPos))
+              : ((expr0.isFunctionOfDim(batchPos) &&
+                  expr1.isFunctionOfDim(kPos)) ||
+                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
   }
   return isValid;
 }
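For readers reviewing the `isValidLhsRhsBroadcastMap` rules above, here is a dependency-free sketch that models them outside MLIR. It assumes that every result of the broadcast map is a plain dimension expression, so `isFunctionOfDim(pos)` reduces to "refers to dim `pos`". The real code calls `AffineExpr::isFunctionOfDim`, which also accepts compound expressions. The names `DimList`, `isValidMatmulBroadcast` and `isValidBatchBroadcast` are hypothetical and exist only in this illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for an AffineMap whose results are plain dimension
// expressions: each entry is the loop-dimension index that result refers to.
using DimList = std::vector<unsigned>;

// Loop order used by batch_matmul / batch_reduce_matmul: (batch, m, n, k).
enum Indices : unsigned { batchPos, mPos, nPos, kPos };

// Models MatmulOp::isValidLhsRhsBroadcastMap: the single result must be the
// reduction dimension K, which is the last of the loop dims (m, n, k).
bool isValidMatmulBroadcast(const DimList &results, unsigned numDims) {
  assert(results.size() == 1 && "Expected single result dim expr.");
  return results[0] == numDims - 1;
}

// Models the BatchMatmulOp / BatchReduceMatmulOp variants.
bool isValidBatchBroadcast(const DimList &results, bool isLHS) {
  assert(results.size() < 3 && "Expected less than 3 result dim expr.");
  if (results.size() == 1)
    return results[0] == kPos; // 1-D operand: must be indexed by K.
  if (results.size() == 2) {
    unsigned d0 = results[0], d1 = results[1];
    // LHS: (batch, k) or (m, k).
    // RHS: (batch, k) or (k, n).
    return isLHS ? ((d0 == batchPos || d0 == mPos) && d1 == kPos)
                 : ((d0 == batchPos && d1 == kPos) ||
                    (d0 == kPos && d1 == nPos));
  }
  return false; // Zero results: never a valid broadcast map.
}
```

Under these assumptions, the LHS map `(batch, m, n, k) -> (m, k)` from the "Broadcast and Transpose" example is accepted. `(k, m)` is rejected for the LHS because the first result must be `batch` or `m`. A map that drops K, such as `(m)`, is always rejected. This is the strict requirement that the TODO in PATCH 7/7 proposes to relax.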

>From 0042622aaee5519f1041fa87a3cd76057fa211d3 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Sat, 3 May 2025 06:50:56 -0700
Subject: [PATCH 7/7] -Adds TODO comment to support inference of K dimension
 for broadcast map from the existing one.

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cd764e4582d08..cee51730bd743 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3683,6 +3683,11 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
 
 /// Returns true if the given bcastMap map is a valid broadcast map. A valid
 /// broadcast map must include K dimension.
+/// TODO: Strict inclusion of K dimension in the broadcast map is not
+/// necessary for both input matrices simultaneously. We can relax this
+/// condition to have K dimension for one input matrix map and infer the K
+/// dimension for other input matrix map from the one already having K
+/// dimension.
 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
   AffineExpr expr = bcastMap.getResult(0);
@@ -3987,6 +3992,11 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
 
 /// Returns true if the given bcastMap map is a valid broadcast map. A valid
 /// broadcast map must include K dimension.
+/// TODO: Strict inclusion of K dimension in the broadcast map is not
+/// necessary for both input matrices simultaneously. We can relax this
+/// condition to have K dimension for one input matrix map and infer the K
+/// dimension for other input matrix map from the one already having K
+/// dimension.
 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
   assert(bcastMap.getNumResults() < 3 &&
          "Expected less than 3 result dim expr.");
@@ -5454,6 +5464,11 @@ bool BatchReduceMatmulOp::hasUserDefinedMaps() {
 
 /// Returns true if the given bcastMap map is a valid broadcast map. A valid
 /// broadcast map must include K dimension.
+/// TODO: Strict inclusion of K dimension in the broadcast map is not
+/// necessary for both input matrices simultaneously. We can relax this
+/// condition to have K dimension for one input matrix map and infer the K
+/// dimension for other input matrix map from the one already having K
+/// dimension.
 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                     bool isLHS) {
   assert(bcastMap.getNumResults() < 3 &&
