[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Wed Jan 15 03:20:38 PST 2025


https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/122275

>From 0f0aa7df9029db9efea13ab0b0b817d63418b6f9 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 13 Nov 2024 01:08:25 -0800
Subject: [PATCH 1/4] [MLIR][Linalg] Introduce broadcast/transpose semantic to
 'linalg.batch_matmul' operation.

Goals:
1. To add syntax and semantic to 'batch_matmul' without changing any of the
   existing syntax expectations for current usage. batch_matmul is still
   just batch_matmul.

2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS
   infra.

Scope of this patch:
To expose broadcast and transpose semantics on the 'batch_matmul'.

The broadcast and transpose semantic is as follows:

By default 'linalg.batch_matmul' behavior will remain as is.
Broadcast and Transpose semantics can be applied by specifying the
explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list
must include all the maps if specified.

    Example Transpose:
    ```
    linalg.batch_matmul indexing_maps = [
                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose
                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                   ]
                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
                   outs(%arg2: memref<2x3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.batch_matmul indexing_maps = [
                       affine_map<(d0, d1, d2, d3) -> (d3)>,     //broadcast
                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                     ]
                     ins(%arg0, %arg1 : memref<5xf32>,memref<2x5x7xf32>)
                     outs(%arg2: memref<2x3x7xf32>)
    ```

    Example Broadcast and transpose:
    ```
    linalg.batch_matmul indexing_maps = [
                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     //broadcast
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                     ]
                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
                     outs(%arg2: memref<2x3x7xf32>)
    ```
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  69 ------
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 124 ++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 217 ++++++++++++++++++
 .../Linalg/Transforms/DropUnitDims.cpp        |   3 +-
 .../linalg/opdsl/ops/core_named_ops.py        |  18 --
 .../Dialect/Linalg/generalize-named-ops.mlir  |  23 ++
 mlir/test/Dialect/Linalg/invalid.mlir         | 118 ++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 148 ++++++++++++
 8 files changed, 632 insertions(+), 88 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index b0ea1f76955816..496a323249e852 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul
-  cpp_class_name: BatchMatmulOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul_transpose_a
   cpp_class_name: BatchMatmulTransposeAOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..47b871aa322309 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
+  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+    
+  let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
+                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                   ]
+                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
+                   outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+
+    Example Broadcast and transpose:
+    ```
+    linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
+                     outs(%arg2: memref<2x3x7xf32>)
+    ```
+}];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
+          BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        // No default 'indexing_maps' is added; pass it via 'attributes'.
+        (void)$_state.addRegion();
+      }]>
+      
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns a list of AffineMap with the typical batch_matmul indexing characteristic.
+      static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      bool hasDynamicIndexingMaps() { return true; }
+      std::string getLibraryCallName();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..868892d1e5f5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,23 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
                            attributes, regionBuilder);
 }
 
+static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
+                               std::optional<TypeRange> resultTensorTypes,
+                               ValueRange inputs, ValueRange outputs,
+                               ArrayRef<NamedAttribute> attributes,
+                               RegionBuilderFn regionBuilder,
+                               ArrayRef<AffineMap> indexingMaps) {
+  // Attach the maps as the op's 'indexing_maps' array attribute.
+  auto toMapAttr = [](AffineMap map) -> Attribute {
+    return AffineMapAttr::get(map);
+  };
+  SmallVector<Attribute, 4> mapAttrs =
+      llvm::map_to_vector(indexingMaps, toMapAttr);
+  state.addAttribute("indexing_maps", b.getArrayAttr(mapAttrs));
+  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+                           attributes, regionBuilder);
+}
+
 /// Common parsing used for both named structured ops created by ods-gen and by
 /// manually defined C++ ops. Does not handle regions.
 static ParseResult
@@ -3450,6 +3467,46 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
   return success();
 }
 
+/// Checks the batch dimension of \p bcastMap, which must have exactly three
+/// results: returns true if the first result expression is a function of
+/// the first map dimension (d0).
+static bool isValidBatchDim(AffineMap bcastMap) {
+  assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
+  constexpr unsigned batchPos = 0;
+  return bcastMap.getResult(0).isFunctionOfDim(batchPos);
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
+/// indexing map of operand \p opIndex of the BatchMatmulOp \p batchMatmulOp.
+/// Emits an op error and returns failure on the first violated rule.
+static LogicalResult
+verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
+                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> userMaps = batchMatmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap, 3> referenceMaps =
+      BatchMatmulOp::getDefaultIndexingMaps(batchMatmulOp->getContext());
+  AffineMap userMap = userMaps[opIndex];
+  AffineMap referenceMap = referenceMaps[opIndex];
+  // Check general validity of indexing map results.
+  if (!isValidResultDimExprs(userMap, referenceMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected dim expression in map result.";
+
+  // A broadcast map is checked against the broadcast rule only; the batch
+  // dimension check below applies to non-broadcast maps.
+  if (isBroadcasted(userMap, referenceMap)) {
+    if (batchMatmulOp.isValidLhsRhsBroadcastMap(userMap, opIndex == 0))
+      return success();
+    return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+  }
+
+  // Non-broadcast maps must keep a valid batch dimension expression.
+  if (isValidBatchDim(userMap))
+    return success();
+  return batchMatmulOp->emitOpError()
+         << "Invalid batch dimension expression.";
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -3611,5 +3668,165 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// Implementation of BatchMatmulOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap>
+BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
+  // All three maps are defined over the four dimensions (d0, d1, d2, d3) and
+  // use no symbols; they are returned in the order built below.
+  AffineExpr d0, d1, d2, d3;
+  bindDims(context, d0, d1, d2, d3);
+  return {AffineMap::get(4, 0, {d0, d1, d3}, context),
+          AffineMap::get(4, 0, {d0, d3, d2}, context),
+          AffineMap::get(4, 0, {d0, d1, d2}, context)};
+}
+
+SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
+  // Three parallel dimensions followed by a single reduction dimension.
+  using IT = utils::IteratorType;
+  return {IT::parallel, IT::parallel, IT::parallel, IT::reduction};
+}
+
+unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
+
+std::string BatchMatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+/// Reports whether the op carries broadcast and/or transpose semantic, i.e.
+/// whether its indexing maps differ from the default batch_matmul maps.
+bool BatchMatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  SmallVector<AffineMap, 3> defaultMaps =
+      getDefaultIndexingMaps(this->getContext());
+  return explicitMaps != defaultMaps;
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+/// A one-result map must be a function of the K dimension; a two-result map
+/// must be (M, K) when \p isLHS is true and (K, N) otherwise.
+bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
+  assert(bcastMap.getNumResults() < 3 &&
+         "Expected less than 3 result dim expr.");
+  // Positions of the iteration dimensions (d0, d1, d2, d3).
+  enum Indices { batchPos, mPos, nPos, kPos };
+  if (bcastMap.getNumResults() == 1)
+    return bcastMap.getResult(0).isFunctionOfDim(kPos);
+  if (bcastMap.getNumResults() != 2)
+    return false;
+  AffineExpr exp0 = bcastMap.getResult(0);
+  AffineExpr exp1 = bcastMap.getResult(1);
+  return isLHS ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
+               : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
+}
+
+void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "BatchMatmulOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  // Yields arg2 + cast(arg0) * cast(arg1); both inputs are first cast
+  // (cast_signed) to the type of block argument 2.
+  Type accTy = block.getArgument(2).getType();
+  Value lhs =
+      helper.buildTypeFn(TypeFn::cast_signed, accTy, block.getArgument(0));
+  Value rhs =
+      helper.buildTypeFn(TypeFn::cast_signed, accTy, block.getArgument(1));
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhs, rhs);
+  Value sum =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), product);
+  SmallVector<Value> yields = {sum};
+  helper.yieldOutputs(yields);
+}
+
+ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  // Optional leading clause: `indexing_maps = [ map (, map)* ]`.
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+
+    // Parse a non-empty, comma-separated list of affine map attributes.
+    do {
+      // Remember where the attribute starts so that a kind mismatch is
+      // reported at the attribute itself, not at the token following it.
+      llvm::SMLoc mapLoc = parser.getCurrentLocation();
+      Attribute mapAttr;
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(mapLoc, "expected affine map attribute");
+      indexingMapsAttr.push_back(mapAttr);
+    } while (succeeded(parser.parseOptionalComma()));
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = llvm::map_to_vector(
+        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
+  return ::parseNamedStructuredOp(parser, result,
+                                  BatchMatmulOp::getNumRegionArgs(),
+                                  BatchMatmulOp::getRegionBuilder());
+}
+
+void BatchMatmulOp::print(OpAsmPrinter &p) {
+  // 'indexing_maps' is printed with custom syntax, only if non-default.
+  SmallVector<StringRef, 3> elidedAttrs = {
+      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
+  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                           elidedAttrs);
+  SmallVector<Attribute, 3> defaultMapAttrs = llvm::map_to_vector(
+      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (llvm::equal(getIndexingMaps(), defaultMapAttrs))
+    return;
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+}
+
+/// Verify the user defined indexing maps.
+LogicalResult BatchMatmulOp::verify() {
+  // Verification of pure batch_matmul is handled by
+  // verifyStructuredOpInterface().
+  if (!hasUserDefinedMaps())
+    return success();
+
+  // Only the maps at operand indices 0 and 1 are checked here.
+  for (unsigned opIndex : {0u, 1u})
+    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
+      return failure();
+  return success();
+}
+
+LogicalResult BatchMatmulOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void BatchMatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // With pure tensor semantics no effect is appended to 'effects'.
+  if (!hasPureTensorSemantics())
+    getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9b97865990bfdd..a5d4c7fe9908c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -935,7 +935,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
         ValueRange{collapsedInit});
     for (auto attr : contractionOp->getAttrs()) {
-      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
+      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
+          attr.getName() == "indexing_maps")
         continue;
       collapsedOp->setAttr(attr.getName(), attr.getValue());
     }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..040663c882a086 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -484,24 +484,6 @@ def batch_mmt4d(
     ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
 
 
- at linalg_structured_op
-def batch_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
 @linalg_structured_op
 def batch_matmul_transpose_a(
     A=TensorDef(T1, Batch, S.K, S.M),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..638238b5c38a60 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1002,3 +1002,26 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul(
+// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<?x?x?xf32>, %[[VAL_1:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                            %[[VAL_2:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_2]] : tensor<?x?x?xf32>) {
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+    ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..e14124097fe2b6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1142,3 +1142,121 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
   %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
   return %0 : tensor<2x12x11x2xf32>
 }
+
+// -----
+
+func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.batch_matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                      ]
+                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                      outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+  // expected-error @+1 {{Unexpected dim expression in map result}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+  // expected-error @+1 {{Unexpected dim expression in map result}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 68aa5a85b5e0e6..75200fc2e5b1ee 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1487,6 +1487,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_and_m_dim_A(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_dim_A(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME:                                                    %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                    %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_dim_B(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_a
+//       CHECK:   linalg.batch_matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_b
+//       CHECK:   linalg.batch_matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_A_transpose_B(
+// CHECK-SAME:                                                %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                                %[[VAL_1:.*]]: memref<2x7x5xf32>,
+// CHECK-SAME:                                                %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x7x5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_A_transpose_B(%arg0: memref<3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @batchmatmul_transpose_a
 //       CHECK:   linalg.batch_matmul_transpose_a
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)

>From 46d7f36c87472c96b734557ef4946619cc28e7e2 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Fri, 10 Jan 2025 10:02:47 -0800
Subject: [PATCH 2/4] -Added output map verification and corresponding tests.
 -Replaced the assert on the number of dim expressions with proper error
 reporting and a new test case. -Fixed typos.

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 53 ++++++++++++++-----
 mlir/test/Dialect/Linalg/invalid.mlir    | 66 ++++++++++++++++++++++--
 2 files changed, 101 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 868892d1e5f5cc..db48c2baf009a8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3476,7 +3476,18 @@ static bool isValidBatchDim(AffineMap bcastMap) {
   return exp.isFunctionOfDim(0);
 }
 
-/// Verifies the broadcast and transpose semantic sepecified by the explicit
+/// Checks if the given AffineMap's result dimensions are valid output result
+/// dimensions.
+static bool isValidOutputResultDim(AffineMap outputMap) {
+  enum Indices { batchPos, mPos, nPos };
+  AffineExpr exp0 = outputMap.getResult(batchPos);
+  AffineExpr exp1 = outputMap.getResult(mPos);
+  AffineExpr exp2 = outputMap.getResult(nPos);
+  return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
+         exp2.isFunctionOfDim(nPos);
+}
+
+/// Verifies the broadcast and transpose semantic specified by the explicit
 /// indexing map for the BatchMatmulOp \p op for each operand specified by \p
 /// opIndex.
 static LogicalResult
@@ -3490,19 +3501,35 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
   // Check general validity of indexing map results.
-  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
-    return batchMatmulOp->emitOpError()
-           << "Unexpected dim expression in map result.";
-  // Check if the requested broadcast is valid.
-  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
-    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, opIndex == 0)) {
-      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+  if (opIndex < 2) {
+    if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+      return batchMatmulOp->emitOpError()
+             << "Unexpected dim expression in map result.";
+    // Check if the requested broadcast is valid.
+    if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+      if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
+                                                   opIndex == 0)) {
+        return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+      }
+    } else {
+      // Check for valid number of result dims of input maps.
+      if (opIndexingMap.getNumResults() != 3)
+        return batchMatmulOp->emitOpError()
+               << "no. of result dim expression cannot exceed 3.";
+
+      if (!isValidBatchDim(opIndexingMap))
+        return batchMatmulOp->emitOpError()
+               << "Invalid batch dimension expression.";
     }
   } else {
-    if (!isValidBatchDim(opIndexingMap)) {
+    // Check for valid number of result dims of output map.
+    if (opIndexingMap.getNumResults() != 3)
       return batchMatmulOp->emitOpError()
-             << "Invalid batch dimension expression.";
-    }
+             << "no. of result dim expression cannot exceed 3.";
+
+    if (!isValidOutputResultDim(opIndexingMap))
+      return batchMatmulOp->emitOpError()
+             << "Invalid output map result dimension.";
   }
   return success();
 }
@@ -3724,7 +3751,7 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
 
 void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                   ArrayRef<NamedAttribute> attrs) {
-  assert(3 > 0 && block.getNumArguments() == 3 &&
+  assert(block.getNumArguments() == 3 &&
          "BatchMatmulOp regionBuilder expects 3 (>=0) args");
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
@@ -3806,7 +3833,7 @@ LogicalResult BatchMatmulOp::verify() {
   if (!hasUserDefinedMaps())
     return success();
 
-  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
+  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
     if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
       return failure();
   }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e14124097fe2b6..5faba3a815b8b9 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1145,7 +1145,7 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
 
 // -----
 
-func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // expected-error @+1 {{expected attribute value}}
   linalg.batch_matmul indexing_maps = [
                        ,
@@ -1159,27 +1159,27 @@ func.func @missing_indexing_map_batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: te
 
 // -----
 
-func.func @invalid_dim_expr_batch_matmul_a(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // expected-error @+1 {{Unexpected dim expression in map result}}
   linalg.batch_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
   return
 }
 
 // -----
 
-func.func @invalid_dim_expr_batch_matmul_b(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) {
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // expected-error @+1 {{Unexpected dim expression in map result}}
   linalg.batch_matmul indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                      ]
-                     ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%arg2 :tensor<?x?x?xf32>)
+                     ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
   return
 }
 
@@ -1260,3 +1260,59 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: mem
                      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
   return
 }
+
+// -----
+
+func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_matmul' op Invalid output map result dimension.}}
+  linalg.batch_matmul indexing_maps = [
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+                           ]
+    ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?x?xf32>)
+    return
+}

>From 5670406f4b6673634110589ab5a87295a6511e42 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 15 Jan 2025 01:05:07 -0800
Subject: [PATCH 3/4] *Added checks for extended semantics and exit gracefully
 in user passes. *Added and updated test cases. *Refactored verification logic.

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 80 ++++++++++++-------
 .../Linalg/Transforms/BlockPackMatmul.cpp     | 10 +++
 .../Linalg/Transforms/DropUnitDims.cpp        |  9 +++
 .../Linalg/Transforms/TransposeMatmul.cpp     |  8 ++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 15 ++--
 mlir/test/Dialect/Linalg/invalid.mlir         |  2 +-
 6 files changed, 85 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index db48c2baf009a8..aa3be4a3763e8f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3471,7 +3471,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
 /// It checks if the first result dimension is a function of the first
 /// dimension.
 static bool isValidBatchDim(AffineMap bcastMap) {
-  assert(bcastMap.getNumResults() == 3 && "Expected three result dim expr.");
   AffineExpr exp = bcastMap.getResult(0);
   return exp.isFunctionOfDim(0);
 }
@@ -3487,6 +3486,48 @@ static bool isValidOutputResultDim(AffineMap outputMap) {
          exp2.isFunctionOfDim(nPos);
 }
 
+// Check general validity of input indexing map.
+static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap,
+                                     AffineMap defaultIndexingMap, bool isLHS) {
+  // Check the result dims are valid.
+  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Unexpected dim expression in map result.";
+
+  // Check for valid number of result dims of input maps.
+  if (opIndexingMap.getNumResults() > 3)
+    return batchMatmulOp->emitOpError()
+           << "no. of result dim expression cannot exceed 3.";
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
+      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
+  } else if (!isValidBatchDim(opIndexingMap)) {
+    return batchMatmulOp->emitOpError()
+           << "Invalid batch dimension expression.";
+  }
+  return success();
+}
+
+/// This function checks if the given AffineMap for the output of a
+/// BatchMatmulOp has exactly 3 result dimensions and if the output map result
+/// dimensions are valid.
+static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
+                                     AffineMap opIndexingMap) {
+  if (opIndexingMap.getNumResults() != 3)
+    return batchMatmulOp->emitOpError()
+           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
+           << ").";
+
+  if (!isValidOutputResultDim(opIndexingMap))
+    return batchMatmulOp->emitOpError()
+           << "Invalid output map result dimension.";
+
+  return success();
+}
+
 /// Verifies the broadcast and transpose semantic specified by the explicit
 /// indexing map for the BatchMatmulOp \p op for each operand specified by \p
 /// opIndex.
@@ -3500,37 +3541,14 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
 
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
-  // Check general validity of indexing map results.
-  if (opIndex < 2) {
-    if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
-      return batchMatmulOp->emitOpError()
-             << "Unexpected dim expression in map result.";
-    // Check if the requested broadcast is valid.
-    if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
-      if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap,
-                                                   opIndex == 0)) {
-        return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
-      }
-    } else {
-      // Check for valid number of result dims of input maps.
-      if (opIndexingMap.getNumResults() != 3)
-        return batchMatmulOp->emitOpError()
-               << "no. of result dim expression cannot exceed 3.";
-
-      if (!isValidBatchDim(opIndexingMap))
-        return batchMatmulOp->emitOpError()
-               << "Invalid batch dimension expression.";
-    }
-  } else {
-    // Check for valid number of result dims of output map.
-    if (opIndexingMap.getNumResults() != 3)
-      return batchMatmulOp->emitOpError()
-             << "no. of result dim expression cannot exceed 3.";
 
-    if (!isValidOutputResultDim(opIndexingMap))
-      return batchMatmulOp->emitOpError()
-             << "Invalid output map result dimension.";
-  }
+  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
+    return failure();
+
+  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
+                             opIndex == 0)))
+    return failure();
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 57344f986480da..de05d0f7e5b832 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -138,6 +138,16 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
 FailureOr<PackResult>
 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                         const ControlBlockPackMatmulFn &controlPackMatmul) {
+  // Bail out on batch_matmul ops with extended semantics (user-defined
+  // indexing maps); this transform does not support them.
+  if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
+    if (batchMatmulOp->hasUserDefinedMaps()) {
+      return rewriter.notifyMatchFailure(
+          *batchMatmulOp,
+          "only batch_matmul ops with non-extended semantics are supported");
+    }
+  }
+
   if (linalgOp.hasPureBufferSemantics())
     return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index a5d4c7fe9908c5..a5ebe7628accda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -906,6 +906,15 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
 
   LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                 PatternRewriter &rewriter) const override {
+    // Bail out on batch_matmul ops with extended semantics (user-defined
+    // indexing maps); this transform does not support them.
+    if (std::is_same<FromOpTy, BatchMatmulOp>::value) {
+      if (contractionOp.hasUserDefinedMaps()) {
+        return rewriter.notifyMatchFailure(
+            contractionOp,
+            "only batch_matmul ops with non-extended semantics are supported");
+      }
+    }
 
     auto loc = contractionOp.getLoc();
     auto inputs = contractionOp.getDpsInputs();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 6b934f7e8157d4..8d12f8a98dbdd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -88,6 +88,14 @@ FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
                                    bool transposeLHS) {
+  // Bail out on batch_matmul ops with extended semantics (user-defined
+  // indexing maps); this transform does not support them.
+  if (batchMatmulOp.hasUserDefinedMaps()) {
+    return rewriter.notifyMatchFailure(
+        batchMatmulOp,
+        "only batch_matmul ops with non-extended semantics are supported");
+  }
+
   if (!bufferization::hasTensorSemantics(batchMatmulOp))
     return rewriter.notifyMatchFailure(
         batchMatmulOp, "only matmul ops with tensors are supported");
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 638238b5c38a60..42eb940006b96d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1007,21 +1007,22 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
 // CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 
 // CHECK-LABEL:   func.func @batch_matmul(
-// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<?x?x?xf32>, %[[VAL_1:.*]]: tensor<?x?x?xf32>,
-// CHECK-SAME:                            %[[VAL_2:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_2]] : tensor<?x?x?xf32>) {
+// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:                            %[[VAL_2:.*]]: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
+// CHECK:           %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<2x3x5xf32>, tensor<2x5x7xf32>) outs(%[[VAL_2]] : tensor<2x3x7xf32>) {
 // CHECK:           arith.mulf
 // CHECK:           arith.addf
 
-func.func @batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
   %0 = linalg.batch_matmul indexing_maps = [
                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                             affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                            ]
-    ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-    outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-  return %0 : tensor<?x?x?xf32>
+    ins(%arg0, %arg1: tensor<2x3x5xf32>, tensor<2x5x7xf32>)
+    outs(%arg2: tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
+  return %0 : tensor<2x3x7xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 5faba3a815b8b9..5554cb082dd85f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1292,7 +1292,7 @@ func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 // -----
 
 func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
-  // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+  // expected-error @+1 {{'linalg.batch_matmul' op expects 3 dims, but got (2).}}
   linalg.batch_matmul indexing_maps = [
                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                             affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,

>From 5044162a386fc8bd394b2bb2a5f447abfbe90b0f Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 15 Jan 2025 03:20:03 -0800
Subject: [PATCH 4/4] *Added logic and tests to verify the size of the supplied
 indexing_maps attribute.

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp |  4 ++++
 mlir/test/Dialect/Linalg/invalid.mlir    | 27 ++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index aa3be4a3763e8f..77d29c89d1d727 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3539,6 +3539,10 @@ verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
   SmallVector<AffineMap, 3> defaultIndexingMaps =
       batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
 
+  if (opIndexingMaps.size() != 3)
+    return batchMatmulOp->emitOpError()
+           << "Indexing_map attribute must have 3 affine maps.";
+
   auto opIndexingMap = opIndexingMaps[opIndex];
   auto defaultIndexingMap = defaultIndexingMaps[opIndex];
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 5554cb082dd85f..8bf2f858ff7c97 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1145,6 +1145,33 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
 
 // -----
 
+func.func @indexing_map_size_mismatch_batch_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?x?xf32>)
+    return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_matmul indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+    ]
+    ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs(%arg2: memref<?x?x?xf32>)
+    return
+}
+
+// -----
+
 func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
   // expected-error @+1 {{expected attribute value}}
   linalg.batch_matmul indexing_maps = [



More information about the Mlir-commits mailing list