[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Thu Sep 12 09:32:37 PDT 2024


https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/104783

>From 121cc0b59230659273b62be9765fb06a4816d9b2 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 21 Aug 2024 05:34:27 -0700
Subject: [PATCH] [mlir][linalg] Extends 'linalg.matmul' named op to define
 broadcast and transpose semantic.

Goals:
1. To add syntax to matmul without changing any of the existing syntax
   expectations for current usage. matmul is still just matmul.

2. To expose broadcast and transpose semantics on the three matmul
   variations: matmul, batch_matmul and batch_reduce_matmul.

Scope of this patch:
Expose broadcast and transpose semantics on 'linalg.matmul' only;
batch_matmul and batch_reduce_matmul are not covered by this patch.

The broadcast and transpose semantics are as follows:

By default, 'linalg.matmul' behavior remains as is. Broadcast and
transpose semantics can be applied by specifying the explicit attribute
'indexing_maps' as shown below. This is a list attribute, so the list
must include all the maps if specified.

Example Transpose:
linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
               ]
               ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
               outs(%arg2: memref<3x7xf32>)

Example Broadcast (each result of a broadcast input's map must use the
reduction dimension d2):
linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  72 ----
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   1 +
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 146 +++++++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  91 +++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 309 +++++++++++++++++-
 .../linalg/opdsl/ops/core_named_ops.py        |  17 -
 .../Dialect/Linalg/generalize-named-ops.mlir  | 167 ++++++++++
 mlir/test/Dialect/Linalg/invalid.mlir         |  94 ++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 219 +++++++++++++
 mlir/test/python/dialects/linalg/ops.py       |  75 -----
 10 files changed, 1014 insertions(+), 177 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..97b90333e2b200 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 5b6a90f806bedd..c58c64fbb5554a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -400,4 +400,5 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
   let hasVerifier = 1;
 }
 
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 31f29139247267..4402d8bc92b305 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -535,6 +535,152 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+
+  let summary = [{Performs a matrix multiplication of two 2D inputs.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Without an explicit 'indexing_maps' attribute the op uses the default
+    matmul maps A: (d0, d2), B: (d2, d1), C: (d0, d1), i.e. no broadcast or
+    transpose. Broadcast and transpose semantics can be applied by specifying
+    the explicit attribute 'indexing_maps' as shown below. This is a list
+    attribute, so the list must include all the maps if specified. Each
+    result of a broadcast input's map must use the reduction dimension 'd2'.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+
+    Example Broadcast and transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+    ```
+  }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Populates the output parameter \p indexingMaps with the typical matmul
+      /// indexing maps.
+      void getDefaultIndexingMaps(SmallVectorImpl<AffineMap> &indexingMaps);
+
+      /// Infers the dimension sizes and populates them into \p dimSizeMap.
+      /// Input parameter \p allTypes is a map of shapes for each operand.
+      void inferDimensionSizes(llvm::DenseMap<unsigned, ShapedType> &allTypes,
+                       llvm::DenseMap<unsigned, unsigned> &dimSizeMap);
+
+      /// Infers the broadcasted dimensions and populates \p broadcastDims, which
+      /// maps each result position of \p defaultMap to a pair of a Boolean
+      /// (is the dimension broadcast) and the corresponding AffineDimExpr
+      /// position. \p explicitMap is the map parsed from the op and
+      /// \p defaultMap is the typical map for the same input operand.
+      void inferBroadcastDimensions(AffineMap explicitMap, AffineMap defaultMap,
+                       DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastDims);
+
+      /// Constructs and returns a broadcasted type for the input operand
+      /// identified by input parameter \p opIndex.
+      ShapedType constructBroadcastedType(unsigned opIndex);
+
+      /// Returns true if the input operand identified by \p opIndex needs
+      /// broadcasting.
+      bool hasBroadcastSemantic(unsigned opIndex);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 0c48a5aeb26a26..e67dc1b6856cb1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -17,10 +17,15 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 
 using namespace mlir;
@@ -1139,9 +1144,45 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
          operandNumber - start;
 }
 
+/// Verifies the broadcasted type \p dst constructed for the source type \p src
+/// of an operand of the MatmulOp \p op that uses broadcast semantics.
+/// \p broadcastedDims maps each result position of the operand's default
+/// indexing map to a pair (is-broadcast, AffineDimExpr position).
+///
+/// Emits an error on \p op and fails if either type is not a ranked shaped
+/// type, if rank(src) plus the number of broadcast dimensions differs from
+/// rank(dst), or if a broadcast position is not below rank(dst).
+static LogicalResult verifyBroadcastSemantic(
+    Operation *op, Type &src, Type &dst,
+    DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastedDims) {
+  // `dst` may be null when the broadcasted type could not be constructed, and
+  // `src` may be a non-shaped operand; `cast<>` would assert on either.
+  auto srcType = llvm::dyn_cast_if_present<ShapedType>(src);
+  auto dstType = llvm::dyn_cast_if_present<ShapedType>(dst);
+  if (!srcType || !dstType || !srcType.hasRank() || !dstType.hasRank())
+    return op->emitOpError()
+           << "expected ranked shaped types to verify broadcast semantic";
+
+  int64_t inputRank = srcType.getRank();
+  int64_t initRank = dstType.getRank();
+
+  int64_t numBroadcastedDims = llvm::count_if(
+      broadcastedDims, [](const auto &entry) { return entry.second.first; });
+  if (inputRank + numBroadcastedDims != initRank)
+    return op->emitOpError()
+           << "input rank plus added dimensions does not "
+              "match init rank. input rank: "
+           << inputRank << ", dimensions size: " << numBroadcastedDims
+           << ", init rank: " << initRank;
+
+  // `lookup` (unlike operator[]) never inserts into the map. `dim` is
+  // unsigned, so only the upper bound of the range needs checking.
+  for (unsigned dim = 0, e = broadcastedDims.size(); dim < e; ++dim) {
+    if (broadcastedDims.lookup(dim).first &&
+        static_cast<int64_t>(dim) >= initRank)
+      return op->emitOpError()
+             << "dimension " << dim << " is out of range. expected range: [0, "
+             << initRank - 1 << "], got: " << dim;
+  }
+
+  return success();
+}
+
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1161,9 +1202,11 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  bool hasBroadcast = false;
+  MatmulOp matmulOp = dyn_cast<MatmulOp>(op);
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
-
+    unsigned opIndex = opOperand.getOperandNumber();
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
       return op->emitOpError("unexpected symbols in indexing_map #")
@@ -1177,13 +1220,48 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
-    if (indexingMap.getNumResults() != rank)
+    // Check if matmulOp need broadcast.
+    if (matmulOp)
+      hasBroadcast = matmulOp.hasBroadcastSemantic(opIndex);
+
+    if (!hasBroadcast && indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
-  }
 
+    // For verification of broadcast/transpose, create a temporary broadcasted
+    // type and try to verify the constructed type from the provided types and
+    // indexing maps.
+    if (hasBroadcast) {
+      llvm::SmallVector<AffineMap, 3> defaultIndexingMaps;
+      matmulOp.getDefaultIndexingMaps(defaultIndexingMaps);
+
+      // Check for valid broadcast request.
+      SmallVector<AffineMap, 3> opIndexingMaps =
+          matmulOp.getIndexingMapsArray();
+      for (unsigned dim = 0; dim < opIndexingMaps[opIndex].getNumResults();
+           dim++) {
+        AffineExpr exp = opIndexingMaps[opIndex].getResult(dim);
+        // Invalid map if dim expr 'd2' not found.
+        if (!exp.isFunctionOfDim(2)) {
+          return op->emitOpError() << "Invalid broadcast requested (d2).";
+        }
+      }
+
+      // Construct the broadcasted type
+      Type originalType = opOperand.get().getType();
+      DenseMap<unsigned, std::pair<bool, unsigned>> broadcastedDims;
+      matmulOp.inferBroadcastDimensions(opIndexingMaps[opIndex],
+                                        defaultIndexingMaps[opIndex],
+                                        broadcastedDims);
+      Type broadcastedType = matmulOp.constructBroadcastedType(opIndex);
+      if (failed(verifyBroadcastSemantic(op, originalType, broadcastedType,
+                                         broadcastedDims)))
+        return failure();
+    }
+    hasBroadcast = false;
+  }
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1193,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b888005625eda7..1e9d90089ce2ba 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,16 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +154,30 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Appends the default (non-transposed, non-broadcast) matmul indexing maps to
+/// \p indexingMaps as AffineMapAttrs, in operand order:
+///   A: (d0, d1, d2) -> (d0, d2)
+///   B: (d0, d1, d2) -> (d2, d1)
+///   C: (d0, d1, d2) -> (d0, d1)
+/// The maps are built directly rather than parsed from strings, which avoids
+/// running the textual parser on every call and cannot fail.
+static void getDefaultIndexingMap(MLIRContext *context,
+                                  SmallVectorImpl<Attribute> &indexingMaps) {
+  AffineExpr d0, d1, d2;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(
+      AffineMapAttr::get(AffineMap::get(3, 0, {d0, d2}, context)));
+  indexingMaps.push_back(
+      AffineMapAttr::get(AffineMap::get(3, 0, {d2, d1}, context)));
+  indexingMaps.push_back(
+      AffineMapAttr::get(AffineMap::get(3, 0, {d0, d1}, context)));
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +188,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    getDefaultIndexingMap(b.getContext(), indexingMapsAttrVal);
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +333,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMaps;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMaps.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMaps.empty()) {
+    getDefaultIndexingMap(result.getContext(), indexingMaps);
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMaps));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -335,7 +406,7 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
       /*elidedAttrs=*/{"operandSegmentSizes",
                        // See generated code in
                        // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                       "linalg.memoized_indexing_maps", "indexing_maps"});
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3383,3 +3454,229 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  // Loops d0 and d1 are parallel; loop d2 is the reduction.
+  SmallVector<utils::IteratorType> iteratorTypes(2,
+                                                 utils::IteratorType::parallel);
+  iteratorTypes.push_back(utils::IteratorType::reduction);
+  return iteratorTypes;
+}
+
+/// The region block takes three arguments: one element each of the two
+/// inputs and of the output.
+unsigned MatmulOp::getNumRegionArgs() {
+  return 3;
+}
+
+/// Returns the library call name computed by the shared
+/// `generateLibraryCallName` helper for this operation.
+std::string MatmulOp::getLibraryCallName() {
+  Operation *op = getOperation();
+  return generateLibraryCallName(op);
+}
+
+/// The indexing maps are carried by the op's `indexing_maps` attribute rather
+/// than being fixed, so they are always reported as dynamic.
+bool MatmulOp::hasDynamicIndexingMaps() {
+  return true;
+}
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+///
+/// The block arguments are (a, b, c): elements of the two inputs and of the
+/// output. The body yields `c + cast(a) * cast(b)`, where both casts convert
+/// to the type of `c` using the `cast` attribute from \p attrs
+/// (TypeFn::cast_signed when the attribute is absent or not a TypeFnAttr).
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  // Both operands are promoted to the accumulator type before multiplying.
+  Type accType = block.getArgument(2).getType();
+  Value lhs = helper.buildTypeFn(castVal, accType, block.getArgument(0));
+  Value rhs = helper.buildTypeFn(castVal, accType, block.getArgument(1));
+  Value product = helper.buildBinaryFn(BinaryFn::mul, lhs, rhs);
+  Value sum =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), product);
+  SmallVector<Value> yields;
+  yields.push_back(sum);
+  helper.yieldOutputs(yields);
+}
+
+/// Appends the three typical matmul indexing maps to \p indexingMaps, in
+/// operand order: A -> (d0, d2), B -> (d2, d1), C -> (d0, d1).
+void MatmulOp::getDefaultIndexingMaps(
+    SmallVectorImpl<AffineMap> &indexingMaps) {
+  MLIRContext *ctx = getContext();
+  AffineExpr m, n, k;
+  bindDims(ctx, m, n, k);
+  AffineMap lhsMap = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {m, k},
+                                    ctx);
+  AffineMap rhsMap = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {k, n},
+                                    ctx);
+  AffineMap outMap = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {m, n},
+                                    ctx);
+  indexingMaps.append({lhsMap, rhsMap, outMap});
+}
+
+/// Returns true if the input operand identified by \p opIndex needs
+/// broadcasting, i.e. its explicit indexing map has fewer results than the
+/// typical matmul map for that operand. Only the two inputs (indices 0 and 1)
+/// can be broadcast; any other index returns false.
+bool MatmulOp::hasBroadcastSemantic(unsigned opIndex) {
+  if (opIndex > 1)
+    return false;
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  // `indexing_maps` defaults to an empty list and may not be verified yet;
+  // never index past its end.
+  if (opIndex >= explicitMaps.size())
+    return false;
+  SmallVector<AffineMap, 3> defaultMaps;
+  getDefaultIndexingMaps(defaultMaps);
+  return explicitMaps[opIndex].getNumResults() <
+         defaultMaps[opIndex].getNumResults();
+}
+
+/// Infers the sizes of the loop dimensions (d0, d1, d2) and populates them
+/// into \p dimSizeMap. Input parameter \p allTypes maps each operand index
+/// (0 and 1 for the inputs, 2 for the output) to its shaped type.
+///
+/// d0 and d1 are read from the output. d2 is read from dim 0 of the first
+/// input whose rank is lower than the output's (a broadcast input); if no
+/// input qualifies, d2 is left absent from \p dimSizeMap.
+///
+/// NOTE(review): \p dimSizeMap stores `unsigned`, so ShapedType::kDynamic
+/// extents are truncated and cannot be told apart from static sizes.
+void MatmulOp::inferDimensionSizes(
+    llvm::DenseMap<unsigned, ShapedType> &allTypes,
+    llvm::DenseMap<unsigned, unsigned> &dimSizeMap) {
+  assert(allTypes.count(0) && allTypes.count(1) && allTypes.count(2) &&
+         "Expected types for operands 0, 1 and 2");
+  ShapedType outputType = allTypes.lookup(2);
+  assert(allTypes.lookup(0).getRank() > 0 &&
+         allTypes.lookup(1).getRank() > 0 && "Input rank must be positive");
+  assert(outputType.getRank() == 2 && "Output rank must be 2");
+
+  dimSizeMap[0] = static_cast<unsigned>(outputType.getDimSize(0));
+  dimSizeMap[1] = static_cast<unsigned>(outputType.getDimSize(1));
+
+  // Get dimension size for 'd2' from the input that needs broadcast. Only the
+  // two inputs are candidates (the loop bound is not the output rank), and
+  // `lookup` avoids inserting null types into the caller's map.
+  int64_t outputRank = outputType.getRank();
+  for (unsigned i = 0; i < 2; ++i) {
+    ShapedType inputType = allTypes.lookup(i);
+    if (inputType && inputType.getRank() < outputRank) {
+      dimSizeMap[2] = static_cast<unsigned>(inputType.getDimSize(0));
+      return;
+    }
+  }
+}
+
+/// Infers the broadcasted dimensions of one input operand and populates
+/// \p broadcastDims. For each result position `i` of \p defaultMap (the
+/// typical matmul map for the operand), `broadcastDims[i]` is `{true, pos}`
+/// when the dimension `d<pos>` used at that position does not appear in
+/// \p explicitMap (the map parsed from the op), and `{false, 0}` otherwise.
+void MatmulOp::inferBroadcastDimensions(
+    AffineMap explicitMap, AffineMap defaultMap,
+    DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastDims) {
+  assert(!explicitMap.isEmpty() && "Expected non empty map");
+  assert(!defaultMap.isEmpty() && "Expected non empty map");
+
+  // Collects the positions of the plain AffineDimExpr results of `map`, in
+  // result order and without duplicates.
+  auto collectDims = [](AffineMap map) {
+    llvm::SetVector<unsigned> dims;
+    for (AffineExpr expr : map.getResults())
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
+        dims.insert(dimExpr.getPosition());
+    return dims;
+  };
+  llvm::SetVector<unsigned> typicalDims = collectDims(defaultMap);
+  llvm::SetVector<unsigned> providedDims = collectDims(explicitMap);
+
+  for (unsigned i = 0, e = typicalDims.size(); i < e; ++i) {
+    unsigned dim = typicalDims[i];
+    broadcastDims[i] = providedDims.count(dim) ? std::make_pair(false, 0u)
+                                               : std::make_pair(true, dim);
+  }
+}
+
+/// Constructs and returns the broadcasted type of the input operand
+/// identified by \p opIndex: the shape that operand has under the typical
+/// (non-transposed, non-broadcast) matmul indexing map.
+///   opIndex 0: [size(d0), size(d2)]
+///   opIndex 1: [size(d2), size(d1)]
+/// size(d0) and size(d1) come from the output. size(d2) comes from dim 0 of
+/// the first input whose rank is lower than the output's, and is dynamic if
+/// there is no such input. Dynamic extents stay ShapedType::kDynamic.
+/// Returns a null ShapedType if an operand is missing or is not a ranked
+/// shaped type, if the output is not 2-D, or if the operand is neither a
+/// ranked tensor nor a memref.
+ShapedType MatmulOp::constructBroadcastedType(unsigned opIndex) {
+  assert(opIndex < 2 && "Operand index out of range");
+  if (getNumOperands() < 3)
+    return ShapedType();
+
+  SmallVector<ShapedType, 3> types;
+  for (unsigned i = 0; i < 3; ++i) {
+    auto type = dyn_cast<ShapedType>(getOperand(i).getType());
+    if (!type || !type.hasRank())
+      return ShapedType();
+    types.push_back(type);
+  }
+  ShapedType outputType = types[2];
+  ShapedType inputType = types[opIndex];
+  int64_t outputRank = outputType.getRank();
+  if (outputRank != 2)
+    return ShapedType();
+
+  // Sizes of the loop dimensions d0, d1 and d2. They are computed here as
+  // int64_t instead of through inferDimensionSizes, whose unsigned map would
+  // truncate ShapedType::kDynamic into a bogus static extent.
+  int64_t loopSizes[3] = {outputType.getDimSize(0), outputType.getDimSize(1),
+                          ShapedType::kDynamic};
+  for (unsigned i = 0; i < 2; ++i) {
+    if (types[i].getRank() > 0 && types[i].getRank() < outputRank) {
+      loopSizes[2] = types[i].getDimSize(0);
+      break;
+    }
+  }
+
+  // Every result of the typical map is a plain dimension, so the broadcasted
+  // shape is the loop sizes read through that map.
+  SmallVector<AffineMap, 3> defaultIndexingMaps;
+  getDefaultIndexingMaps(defaultIndexingMaps);
+  SmallVector<int64_t, 4> newShape;
+  for (AffineExpr expr : defaultIndexingMaps[opIndex].getResults())
+    newShape.push_back(loopSizes[cast<AffineDimExpr>(expr).getPosition()]);
+
+  Type elementType = inputType.getElementType();
+  if (isa<RankedTensorType>(inputType))
+    return RankedTensorType::get(newShape, elementType);
+  if (auto memrefType = dyn_cast<MemRefType>(inputType))
+    return MemRefType::get(newShape, elementType, MemRefLayoutAttrInterface(),
+                           memrefType.getMemorySpace());
+  // Other shaped types are unsupported; callers must handle the null result.
+  return ShapedType();
+}
+
+/// Parses a MatmulOp with the shared named-structured-op parser, which also
+/// accepts the optional leading `indexing_maps = [...]` clause.
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  unsigned numRegionArgs = MatmulOp::getNumRegionArgs();
+  auto regionBuilder = MatmulOp::getRegionBuilder();
+  return parseNamedStructuredOp(parser, result, numRegionArgs, regionBuilder);
+}
+/// Prints a MatmulOp. A non-default `indexing_maps` list is printed first, as
+/// ` indexing_maps = [...]`, because parseNamedStructuredOp only accepts that
+/// clause before the attribute dictionary and the ins/outs operands. Printing
+/// it after the operands produced IR that could not be parsed back. The
+/// attribute itself is elided from the dictionary by printNamedStructuredOp.
+void MatmulOp::print(OpAsmPrinter &p) {
+  SmallVector<Attribute, 3> defaultMaps;
+  getDefaultIndexingMap(getContext(), defaultMaps);
+  ArrayAttr indexingMaps = getIndexingMaps();
+  if (!llvm::equal(indexingMaps, defaultMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(indexingMaps, p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+}
+
+/// Delegates folding to memref::foldMemRefCast; no folded results are ever
+/// appended.
+LogicalResult MatmulOp::fold(FoldAdaptor /*adaptor*/,
+                             SmallVectorImpl<OpFoldResult> & /*results*/) {
+  Operation *op = getOperation();
+  return memref::foldMemRefCast(op);
+}
+/// Reports memory effects: none for pure tensor semantics, otherwise the
+/// generic LinalgOp effects.
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!hasPureTensorSemantics())
+    getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+/// Speculatability follows the generic LinalgOp implementation.
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+  auto linalgOp = cast<LinalgOp>(getOperation());
+  return getGenericSpeculatabilityImpl(linalgOp);
+}
+
+} // namespace linalg
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index e4a6ec7487bb2f..d5e79b4d3cb6dd 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
     O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
 
 
-@linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 1e8f1435ca0fa5..d782311f7542a2 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -29,6 +29,90 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>,
 
 // -----
 
+// Broadcast of A: the 1-D LHS is indexed only by d2 (the reduction dim), so
+// the same vector is reused for every output row (d0).
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_bcast_a(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// Broadcast of both operands: A and B are 1-D and indexed only by d2.
+func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_b(
+// CHECK-SAME:                                %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// Broadcast of B: the 1-D RHS is indexed only by d2, so its values do not
+// depend on the output column (d1).
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_b(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
 func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -891,3 +975,86 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
 
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// Transposed A: A is indexed by (d2, d0); generalization must carry that map
+// into the resulting linalg.generic.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_explicit(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// Transposed B: B is indexed by (d1, d2); generalization must carry that map
+// into the resulting linalg.generic.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// Transposed A (d2, d0) and transposed B (d1, d2); generalization must carry
+// both maps into the resulting linalg.generic.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c481a723c5623c..80f5b930ddaa89 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -361,6 +361,100 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 
 // -----
 
+// An empty leading entry in the indexing_maps list is rejected when parsing
+// the attribute list.
+func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+                      outs(%arg2 :memref<2x4xf32>)
+  return
+}
+
+// -----
+
+// With A indexed by (d2, d0), its 4x1 shape binds d2 to 4, which B's dim #0
+// (size 1, indexed by d2) contradicts.
+func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}}
+  %0 = linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return %0: tensor<4x64xf32>
+}
+
+// -----
+
+// With B indexed by (d1, d2), B's dim #1 must equal d2, which A (4x1 under
+// (d0, d2)) binds to 1; B has 64 there.
+func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}}
+  %0 = linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return %0: tensor<4x64xf32>
+}
+
+// -----
+
+// A's single-result map (d0) is rejected by the verifier as an invalid
+// broadcast (the diagnostic names d2).
+// NOTE(review): A is memref<3x5xf32> (rank 2, but its map has one result),
+// and B's map (d1, d2) binds d1 to 5 against the 7 in the output. The test
+// therefore relies on the broadcast check firing before the shape checks;
+// consider a rank-1 A and a consistent B so only one error is present.
+func.func @invalid_bcast_a(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested (d2)}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// A indexed by (d2, d0) binds d2 to 5, but the broadcast B (indexed by d2
+// alone) has size 7.
+func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// A's map (d1, d2) binds d1 to 3 from memref<3x5xf32>, contradicting the
+// output's dim #1 (7).
+// NOTE(review): despite the name, A's map here is (d1, d2), not the
+// transposed-A map (d2, d0); consider renaming the test.
+func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #1 to be 3, but found 7}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// indexing_maps must precede ins/outs; placed after outs(...) it is parsed as
+// the start of a new (unknown) operation.
+func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}}
+  linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>)
+                        indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+  return
+}
+
+// -----
+
 func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
   // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..7c5c41e622f459 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,225 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
 
 // -----
 
+// Transposed A: A is indexed by (d2, d0). The alias definitions below make
+// sure the non-default maps are still printed.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// Transposed B: B is indexed by (d1, d2).
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// Transposed A (d2, d0) and transposed B (d1, d2) together.
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// Broadcast of A: the 1-D A is indexed only by d2.
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+// NOTE(review): apart from its name, this test is identical to
+// @matmul_bcast_a (same maps and operand types). Either make it exercise a
+// different broadcast or drop it.
+func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a_dim1
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+// Broadcast of B: the 1-D B is indexed only by d2.
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+// Broadcast of both A and B (each indexed only by d2); the two identical input
+// maps are printed through a single shared alias.
+func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_b(
+// CHECK-SAME:                                %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// NOTE(review): apart from its name, this test is identical to
+// @matmul_bcast_b (same maps and operand types). Either make it exercise a
+// different broadcast or drop it.
+func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b_dim1
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+// Broadcast A (indexed by d2) combined with a transposed B (indexed by d1, d2).
+func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_transpose_b(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// Transposed A (indexed by d2, d0) combined with a broadcast B (indexed by d2).
+func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_b_transpose_a(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                          %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
 // CHECK-LABEL: func @matmul_transpose_b
 //       CHECK:   linalg.matmul_transpose_b
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 3bfbcf7d7f7c81..72045a07b2da80 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
 
     print(module)
 
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
-@run
-def testNamedStructuredOpGenericForm():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def named_form(lhs, rhs):
-                init_result = tensor.empty([4, 8], f32)
-                #      CHECK: "linalg.matmul"(%{{.*}})
-                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-                # CHECK-SAME:    operandSegmentSizes = array<i32: 2, 1>
-                # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-                # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-                return linalg.matmul(lhs, rhs, outs=[init_result])
-
-    module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
-@run
-def testNamedStructuredAsGenericOp():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def generic_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.generic
-                return linalg.matmul(
-                    lhs, rhs, outs=[init_result.result], emit_generic=True
-                )
-
-    print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
-@run
-def testOpResultFromOtherOp():
-    with Context(), Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def pass_an_op_directly(arg0, arg1):
-                one = arith.ConstantOp(F32Type.get(), 1.0)
-                # CHECK: %[[LHS:.*]] = linalg.fill
-                lhs = linalg.fill(one, outs=[arg0])
-                # CHECK: %[[RHS:.*]] = linalg.fill
-                rhs = linalg.fill(one, outs=[arg1])
-                # CHECK: %[[INIT:.*]] = tensor.empty
-                init = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.matmul
-                # CHECK: ins(%[[LHS]], %[[RHS]]
-                # CHECK: outs(%[[INIT]]
-                return linalg.matmul(lhs, rhs, outs=init)
-
-    print(module)
-
-
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():



More information about the Mlir-commits mailing list