[Mlir-commits] [mlir] f2bca9e - [MLIR][Linalg] Introduce broadcast/transpose semantic to batch_matmul (#122275)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 6 11:08:53 PST 2025


Author: Md Asghar Ahmad Shahid
Date: 2025-02-06T19:08:50Z
New Revision: f2bca9e3856bd070320473a016f706bf84a5dd5e

URL: https://github.com/llvm/llvm-project/commit/f2bca9e3856bd070320473a016f706bf84a5dd5e
DIFF: https://github.com/llvm/llvm-project/commit/f2bca9e3856bd070320473a016f706bf84a5dd5e.diff

LOG: [MLIR][Linalg] Introduce broadcast/transpose semantic to batch_matmul (#122275)

Goals:
1. Add syntax and semantics to 'batch_matmul' without changing any of
the existing syntax expectations for current usage. batch_matmul is
still just batch_matmul.

2. Move the definition of batch_matmul from the Linalg OpDSL to the
TableGen ODS infrastructure.

Scope of this patch:
Expose broadcast and transpose semantics on 'batch_matmul'.

The broadcast and transpose semantics are as follows:

By default, the behavior of 'linalg.batch_matmul' remains as is. Broadcast
and transpose semantics can be applied by specifying the explicit
attribute 'indexing_maps' as shown below. This is a list attribute, so
the list must include the maps for all operands if specified.

    Example Transpose:
    ```
    linalg.batch_matmul indexing_maps = [
                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                   ]
                   ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
                   outs(%arg2 : memref<2x3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.batch_matmul indexing_maps = [
                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                     ]
                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
                     outs(%arg2 : memref<2x3x7xf32>)
    ```

    Example Broadcast and Transpose:
    ```
    linalg.batch_matmul indexing_maps = [
                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                     ]
                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
                     outs(%arg2 : memref<2x3x7xf32>)
    ```

RFCs and related PR:

https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
https://github.com/llvm/llvm-project/pull/115319

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 244db23925ab3c7..98a5fd278a99771 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -710,7 +710,10 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return true if the user has supplied an explicit indexing maps for this op.
+        Returns true if the user has supplied explicit indexing maps that are
+        different from default indexing maps for this op. Returns `false` otherwise.
+        Note, if the user define maps that are identical to the default maps,
+        this method returns `false`.
       }],
       /*retTy=*/"bool",
       /*methodName=*/"hasUserDefinedMaps",

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f769558163..496a323249e8522 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul
-  cpp_class_name: BatchMatmulOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul_transpose_a
   cpp_class_name: BatchMatmulTransposeAOp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e3d122189f8b779..110ed7d2fc00e2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -674,8 +674,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       static unsigned getNumRegionArgs();
       std::string getLibraryCallName();
       bool hasDynamicIndexingMaps();
-      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
-      /// user defined indexing maps are not equal to default map.
+      /// Returns true if the user defined indexing maps are not equal to default maps.
       bool hasUserDefinedMaps();
     }];
 }
@@ -816,6 +815,129 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+    
+  let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be appiled by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so must include maps for all
+    arguments if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast and Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+}];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion(),
+        BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns a list with default AffineMap(s), i.e. without broadcasts and transpositions.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      bool hasDynamicIndexingMaps() { return true; }
+      std::string getLibraryCallName();
+      /// Returns true if the user defined indexing maps are not equal to default maps.
+      bool hasUserDefinedMaps();
+    }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2f1c22b10dd36f9..b50931f15826ce2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -200,6 +200,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+                               std::optional<TypeRange> resultTensorTypes,
+                               ValueRange inputs, ValueRange outputs,
+                               ArrayRef<NamedAttribute> attributes,
+                               RegionBuilderFn regionBuilder,
+                               ArrayRef<AffineMap> indexingMaps) {
+  // Initialize indexingMaps attribute, for BatchMatmulOp.
+  SmallVector<Attribute, 4> indexingMapsAttrVal;
+  indexingMapsAttrVal =
+      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+        return AffineMapAttr::get(map);
+      });
+  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3409,19 +3426,22 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
-/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
-/// defaultMap.
-static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
-  auto explicitRange = explictMap.getResults();
-  auto defaultRange = defaultMap.getResults();
+// Returns true if the result expression of `subMap` are a subset of `fullMap`.
+static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
+  auto explicitRange = subMap.getResults();
+  auto defaultRange = fullMap.getResults();
   DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
   DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
   llvm::set_union(explicitSet, defaultSet);
   return explicitSet == defaultSet;
 }
 
-/// Returns true if the \p explictMap is broadcasted with respect to the
-/// \p defaultMap.
+/// Check if the user defined map is valid broadcast map. Here broadcast
+/// indexing maps are defined in context of corresponding default indexing maps
+/// for the given Op. This way the check becomes very simple i.e just check the
+/// number of result dims.
+/// Returns true if the explictMap is broadcasted with respect to the
+/// defaultMap.
 static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
   return explictMap.getNumResults() < defaultMap.getNumResults();
 }
@@ -3438,11 +3458,10 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
   // Check general validity of indexing map results.
-  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
     return matmulOp->emitOpError()
            << "Unexpected dim expression in map result.";
 
-  // Check if the requested broadcast is valid.
   if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
     if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
       return matmulOp->emitOpError()
@@ -3453,6 +3472,87 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap,
+                                     AffineMap defaultIndexingMap, bool isLHS) {
+  // Check the result dims are valid.
+  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected result dim expression (outside the set of default "
+              "result dims).";
+
+  // Check for valid number of result dims of input maps.
+  if (opIndexingMap.getNumResults() > 3)
+    return batchMatmulOp->emitOpError()
+           << "no. of result dim expressions exceeds 3.";
+
+  auto hasValidBatchDim = [](AffineMap map) {
+    AffineExpr batchDim = map.getResult(0);
+    return batchDim.isFunctionOfDim(0);
+  };
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+  } else if (!hasValidBatchDim(opIndexingMap)) {
+    return batchMatmulOp->emitOpError()
+           << "Invalid batch dimension expression.";
+  }
+  return success();
+}
+
+/// This function checks if the given AffineMap for the output of a
+/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
+/// dimensions are valid.
+static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap) {
+  if (opIndexingMap.getNumResults() != 3)
+    return batchMatmulOp->emitOpError()
+           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+
+  auto areValidOutputResultDim = [](AffineMap outputMap) {
+    return outputMap.getResult(0).isFunctionOfDim(0) &&
+           outputMap.getResult(1).isFunctionOfDim(1) &&
+           outputMap.getResult(2).isFunctionOfDim(2);
+  };
+
+  if (!areValidOutputResultDim(opIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Invalid output map result dimension.";
+
+  return success();
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
+/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opIndexingMaps =
+      batchMatmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap, 3> defaultIndexingMaps =
+      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
+
+  if (opIndexingMaps.size() != 3)
+    return batchMatmulOp->emitOpError()
+           << "Indexing_map attribute must have 3 affine maps.";
+
+  auto opIndexingMap = opIndexingMaps[opIndex];
+  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+
+  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+    return failure();
+
+  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
+                             opIndex == 0)))
+    return failure();
+
+  return success();
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -3798,5 +3898,166 @@ Speculation::Speculatability ContractOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  AffineExpr d0, d1, d2, d3;
+  SmallVector<AffineMap> indexingMaps;
+  bindDims(context, d0, d1, d2, d3);
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
+  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
+  return indexingMaps;
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{
+      utils::IteratorType::parallel, utils::IteratorType::parallel,
+      utils::IteratorType::parallel, utils::IteratorType::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if
+/// the user defined indexing maps are not equal to default map.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Returns true if the given broadcast map bcastMap is valid for this op.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  bool isValid = false;
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1) {
+    AffineExpr exp = bcastMap.getResult(0);
+    isValid = exp.isFunctionOfDim(kPos);
+  } else if (bcastMap.getNumResults() == 2) {
+    AffineExpr exp0 = bcastMap.getResult(0);
+    AffineExpr exp1 = bcastMap.getResult(1);
+    isValid = isLHS
+                  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+                  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+  }
+  return isValid;
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  auto toType = block.getArgument(2).getType();
+  Value castValA =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
+  Value castValB =
+      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value addVal =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  yields.push_back(addVal);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = llvm::map_to_vector(
+        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchMatmulOp::getNumRegionArgs(),
+                                  BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+  // Verification of pure batch_matmul is handled by
+  // verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
+    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+      return failure();
+  }
+  return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void BatchMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index ed1685a9cb9e69c..7f9a0f7a6ca43b4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -138,6 +138,16 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
 FailureOr<PackResult>
 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                         const ControlBlockPackMatmulFn &controlPackMatmul) {
+  // Check to not let go the batch_matmul with extended semantic, through this
+  // transform.
+  if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
+    if (batchMatmulOp->hasUserDefinedMaps()) {
+      return rewriter.notifyMatchFailure(
+          *batchMatmulOp,
+          "only batch_matmul ops with non-extended semantics are supported");
+    }
+  }
+
   if (linalgOp.hasPureBufferSemantics())
     return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfdda..bd4ffabfbb929ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -906,6 +906,10 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
 
   LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                 PatternRewriter &rewriter) const override {
+    if (contractionOp.hasUserDefinedMaps()) {
+      return rewriter.notifyMatchFailure(
+          contractionOp, "ops with user-defined maps are not supported");
+    }
 
     auto loc = contractionOp.getLoc();
     auto inputs = contractionOp.getDpsInputs();
@@ -935,7 +939,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
         ValueRange{collapsedInit});
     for (auto attr : contractionOp->getAttrs()) {
-      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+          attr.getName() == "indexing_maps")
         continue;
       collapsedOp->setAttr(attr.getName(), attr.getValue());
     }

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 6b934f7e8157d47..e624f589917d1ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -88,6 +88,11 @@ FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
                                    bool transposeLHS) {
+  if (batchMatmulOp.hasUserDefinedMaps()) {
+    return rewriter.notifyMatchFailure(
+        batchMatmulOp, "ops with user-defined maps are not supported");
+  }
+
   if (!bufferization::hasTensorSemantics(batchMatmulOp))
     return rewriter.notifyMatchFailure(
         batchMatmulOp, "only matmul ops with tensors are supported");

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca3..040663c882a086b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
     ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
 
 
-@linalg_structured_op
-def batch_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
 @linalg_structured_op
 def batch_matmul_transpose_a(
     A=TensorDef(T1, Batch, S.K, S.M),

diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index a3611b8e4ec62e7..0ec71c35497b17a 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -999,6 +999,31 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul(
+// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:                            %[[VAL_2:.*]]: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
+// CHECK:           %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<2x3x7xf32>) {
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
+  %0 = linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+    ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+    outs(%arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
+  return %0 : tensor<2x3x7xf32>
+}
+
+// -----
+
 // CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 0ea805fef5361ff..cff741e75077e49 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1258,3 +1258,204 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
   %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
   return %0 : tensor<2x12x11x2xf32>
 }
+
+// -----
+
+func.func @indexing_map_size_mismatch_batch_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.batch_matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                      ]
+                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                      outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op expects 3 dims, but got (2).}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid output map result dimension.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?x?xf32>)
+    return
+}

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 578d24a550b08c7..ed8683522c74a1b 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1485,6 +1485,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_dim_A(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME:                                                    %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                    %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_dim_B(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_A
+//       CHECK:   linalg.batch_matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_B
+//       CHECK:   linalg.batch_matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_B(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_A_transpose_B(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME:                                                %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @batchmatmul_transpose_a
 //       CHECK:   linalg.batch_matmul_transpose_a
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)

diff  --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index ebdbe70ff46eb7e..c68a6362f52c5b6 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -35,6 +35,23 @@ func.func @singleton_batch_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memr
 
 // -----
 
+func.func @negative_singleton_batch_matmul_to_matmul_memref(%arg0 : memref<1x?x?xf32>, %arg1 : memref<1x?x?xf32>, %arg2: memref<1x?x?xf32>) {
+  // CHECK-LABEL: @negative_singleton_batch_matmul_to_matmul_memref
+  // CHECK-NOT:   collapse_shape
+  // CHECK-NOT:   linalg.matmul
+  // CHECK-NOT:   expand_shape
+  linalg.batch_matmul indexing_maps = [
+                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                          affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                          ]
+      ins(%arg0, %arg1 : memref<1x?x?xf32>, memref<1x?x?xf32>)
+      outs(%arg2 : memref<1x?x?xf32>)
+  return
+}
+
+// -----
+
 func.func @singleton_batch_matvec(%arg0 : tensor<1x128x512xf32>, %arg1 : tensor<1x512xf32>, %arg2: tensor<1x128xf32>) -> tensor<1x128xf32> {
   // CHECK-LABEL: @singleton_batch_matvec
   //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x128x512xf32>
@@ -135,6 +152,20 @@ func.func @matmul_to_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x1xf32>, %arg
 
 // -----
 
+func.func @negative_matmul_to_matvec(%arg0: memref<?xf32>, %arg1: memref<?x1xf32>, %arg2: memref<?x1xf32>) {
+  // CHECK-LABEL: @negative_matmul_to_matvec
+  // CHECK-NOT: linalg.matvec
+    linalg.matmul indexing_maps = [
+                          affine_map<(d0, d1, d2) -> (d2)>,
+                          affine_map<(d0, d1, d2) -> (d2, d1)>,
+                          affine_map<(d0, d1, d2) -> (d0, d1)>
+                          ]
+                          ins(%arg0, %arg1: memref<?xf32>, memref<?x1xf32>) outs(%arg2: memref<?x1xf32>)
+    return
+}
+
+// -----
+
 func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
   // CHECK-LABEL: @matmul_to_vecmat
   //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>


        


More information about the Mlir-commits mailing list