[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Mon Sep 9 05:39:34 PDT 2024


https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/104783

>From 2ba79c95895138ded921e172dd071aca79a403ad Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 21 Aug 2024 05:34:27 -0700
Subject: [PATCH] [mlir][linalg] Extends 'linalg.matmul' named op to define
 broadcast and transpose semantic.

Goals:
1. To add syntax to matmul without changing any of the existing syntax
   expectations for current usage. matmul is still just matmul.

2. To expose broadcast and transpose semantics on the three matmul
   variations: matmul, batch_matmul and batch_reduce_matmul.

Scope of this patch:
To expose broadcast and transpose semantics on the 'matmul'.

The broadcast and transpose semantics are as follows:

By default 'linalg.matmul' behavior will remain as is. Broadcast and
Transpose semantics can be applied by specifying the explicit attribute
'indexing_maps' as shown below. This is a list attribute, so the list
must include all the maps if specified.

Example Transpose:
linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
               ]
               ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
               outs(%arg2: memref<3x7xf32>)

Example Broadcast:
linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d0)>,     // broadcast
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  72 ----
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   1 +
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 127 +++++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  75 +++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 328 +++++++++++++++++-
 .../linalg/opdsl/ops/core_named_ops.py        |  17 -
 .../Dialect/Linalg/generalize-named-ops.mlir  | 169 +++++++++
 mlir/test/Dialect/Linalg/invalid.mlir         |  94 +++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 221 ++++++++++++
 mlir/test/python/dialects/linalg/ops.py       |  75 ----
 10 files changed, 1002 insertions(+), 177 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..97b90333e2b200 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 5b6a90f806bedd..c58c64fbb5554a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -400,4 +400,5 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
   let hasVerifier = 1;
 }
 
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..14fd48b03bd0ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -534,6 +534,133 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", !listconcat([AttrSizedOperandSegments],
+  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+    
+  let summary = [{Performs a matrix multiplication of two 2D inputs without broadcast or transpose.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d0)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d1)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      void getDefaultIndexingMaps(SmallVectorImpl<AffineMap> &indexingMaps);
+      void inferDimensionSizes(llvm::DenseMap<unsigned, ShapedType> &allTypes,
+                       ArrayRef<AffineMap> indexingMaps,
+                       llvm::DenseMap<unsigned, unsigned> &dimSizeMap);
+      void inferBroadcastDimensions(AffineMap explicitMap, AffineMap defaultMap,
+                       DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastDims);
+      ShapedType constructBroadcastedType(unsigned opIndex);
+      bool hasBroadcastSemantic(unsigned opIndex);
+      
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 0c48a5aeb26a26..c38d5c5a5df764 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -17,10 +17,15 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 
 using namespace mlir;
@@ -1139,9 +1144,40 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
          operandNumber - start;
 }
 
+static LogicalResult verifyBroadcastSemantic(
+    Operation *op, Type &src, Type &dst,
+    DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastedDims) {
+  auto srcType = cast<ShapedType>(src);
+  auto dstType = cast<ShapedType>(dst);
+
+  int64_t inputRank = srcType.getRank();
+  int64_t initRank = dstType.getRank();
+
+  unsigned numBroadcastedDims = 0;
+  for (auto it : broadcastedDims) {
+    if (it.second.first) {
+      numBroadcastedDims++;
+    }
+  }
+  if ((size_t)inputRank + numBroadcastedDims != (size_t)initRank)
+    return op->emitOpError()
+           << "input rank plus added dimensions does not "
+              "match init rank. input rank: "
+           << inputRank << ", dimensions size: " << numBroadcastedDims
+           << ", init rank: " << initRank;
+
+  for (unsigned dim = 0; dim < broadcastedDims.size(); dim++) {
+    if (broadcastedDims[dim].first && (dim < 0 || dim >= initRank))
+      return op->emitOpError()
+             << "dimension " << dim << " is out of range. expected range: [0, "
+             << initRank - 1 << "], got: " << dim;
+  }
+
+  return success();
+}
+
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1161,9 +1197,11 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  bool hasBroadcast = false;
+  MatmulOp matmulOp = dyn_cast<MatmulOp>(op);
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
-
+    unsigned opIndex = opOperand.getOperandNumber();
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
       return op->emitOpError("unexpected symbols in indexing_map #")
@@ -1177,13 +1215,37 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
-    if (indexingMap.getNumResults() != rank)
+    // Check if matmulOp need broadcast.
+    if (matmulOp)
+      hasBroadcast = matmulOp.hasBroadcastSemantic(opIndex);
+
+    if (!hasBroadcast && indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
-  }
 
+    // For verification of broadcast/transpose, create a temporary broadcasted
+    // type and try to verify the constructed type from the provided types and
+    // indexing maps.
+    if (hasBroadcast) {
+      llvm::SmallVector<AffineMap, 3> defaultIndexingMaps;
+      matmulOp.getDefaultIndexingMaps(defaultIndexingMaps);
+      // Construct the broadcasted type
+      SmallVector<AffineMap, 3> opIndexingMaps =
+          matmulOp.getIndexingMapsArray();
+      Type originalType = opOperand.get().getType();
+      DenseMap<unsigned, std::pair<bool, unsigned>> broadcastedDims;
+      matmulOp.inferBroadcastDimensions(opIndexingMaps[opIndex],
+                                        defaultIndexingMaps[opIndex],
+                                        broadcastedDims);
+      Type broadcastedType = matmulOp.constructBroadcastedType(opIndex);
+      if (failed(verifyBroadcastSemantic(op, originalType, broadcastedType,
+                                         broadcastedDims)))
+        return failure();
+    }
+    hasBroadcast = false;
+  }
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1193,9 +1255,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..87e565bc3be540 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -36,12 +37,16 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -148,15 +153,30 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+static void getDefaultIndexingMap(MLIRContext *context,
+                                  SmallVectorImpl<Attribute> &indexingMaps) {
+  Attribute mapAttr;
+  mapAttr = llvm::cast<AffineMapAttr>(
+      mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d2)>", context));
+  indexingMaps.push_back(mapAttr);
+  mapAttr = llvm::cast<AffineMapAttr>(
+      mlir::parseAttribute("affine_map<(d0, d1, d2)->(d2, d1)>", context));
+  indexingMaps.push_back(mapAttr);
+  mapAttr = llvm::cast<AffineMapAttr>(
+      mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d1)>", context));
+  indexingMaps.push_back(mapAttr);
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -167,6 +187,23 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  StringAttr OpName = b.getStringAttr(state.name.getStringRef());
+  if (OpName.getValue() == "linalg.matmul") {
+    if (indexingMaps.has_value()) {
+      for (mlir::AffineMap map : *indexingMaps) {
+        // Convert each AffineMap to an AffineMapAttr
+        indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+      }
+      state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+    } else {
+      getDefaultIndexingMap(b.getContext(), indexingMapsAttrVal);
+      state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+    }
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -298,11 +335,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMaps;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMaps.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMaps.empty()) {
+    getDefaultIndexingMap(result.getContext(), indexingMaps);
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMaps));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -334,7 +408,7 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
       /*elidedAttrs=*/{"operandSegmentSizes",
                        // See generated code in
                        // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                       "linalg.memoized_indexing_maps", "indexing_maps"});
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3326,3 +3400,245 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+                                          utils::IteratorType::parallel,
+                                          utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(3 > 0 && block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+/// Gives default indexing map for matmul operation.
+void MatmulOp::getDefaultIndexingMaps(
+    SmallVectorImpl<AffineMap> &indexingMaps) {
+  MLIRContext *context = this->getContext();
+  AffineExpr d0, d1, d2;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+}
+
+/// Returns true, if the given operand needs broadcast semantic.
+bool MatmulOp::hasBroadcastSemantic(unsigned opIndex) {
+  if (opIndex == 2)
+    return false;
+  SmallVector<AffineMap, 3> defaultMaps;
+  SmallVector<AffineMap, 3> explicitMaps;
+  getDefaultIndexingMaps(defaultMaps);
+  explicitMaps = getIndexingMapsArray();
+  if (explicitMaps[opIndex].getNumResults() <
+      defaultMaps[opIndex].getNumResults())
+    return true;
+  return false;
+}
+
+/// Infer sizes of all the dim expression 'd0, d1, d2' using given input types
+/// and indexing maps.
+void MatmulOp::inferDimensionSizes(
+    llvm::DenseMap<unsigned, ShapedType> &allTypes,
+    ArrayRef<AffineMap> indexingMaps,
+    llvm::DenseMap<unsigned, unsigned> &dimSizeMap) {
+  assert(!indexingMaps.empty() && "Expected non empty map");
+  assert(!allTypes.empty() && "Expected non empty types");
+  assert(allTypes[0].getRank() > 0 && allTypes[1].getRank() > 0 &&
+         "Input rank must be positive");
+  assert(allTypes[2].getRank() == 2 && "Output rank must be 2");
+
+  dimSizeMap[0] = allTypes[2].getDimSize(0);
+  dimSizeMap[1] = allTypes[2].getDimSize(1);
+
+  // Try to infer size of dim expression 'd2' from LHS type.
+  bool sizeFound = false;
+  if (allTypes[0].getRank() < allTypes[2].getRank()) {
+    AffineMap rhsMap = indexingMaps[1];
+    unsigned resNum = rhsMap.getNumResults();
+    for (unsigned i = 0; i < resNum; i++) {
+      // For each result or mapped dim expression
+      AffineExpr exp = rhsMap.getResult(i);
+      // Check if the dim exp is the 3rd expression 'd2'.
+      if (exp.isFunctionOfDim(2)) {
+        dimSizeMap[2] = allTypes[1].getDimSize(i);
+        sizeFound = true;
+        break;
+      }
+    }
+  }
+
+  // Try to infer size of dim expression 'd2' from RHS type.
+  if (!sizeFound && allTypes[1].getRank() < allTypes[2].getRank()) {
+    AffineMap lhsMap = indexingMaps[0];
+    unsigned resNum = lhsMap.getNumResults();
+    // For each result or mapped dim expression, check if the dim exp is the 3rd
+    // dim expression 'd2'.
+    for (unsigned i = 0; i < resNum; i++) {
+      AffineExpr exp = lhsMap.getResult(i);
+      if (exp.isFunctionOfDim(2)) {
+        dimSizeMap[2] = allTypes[0].getDimSize(i);
+        sizeFound = true;
+        break;
+      }
+    }
+  }
+
+  if (!sizeFound)
+    dimSizeMap[2] = ShapedType::kDynamic;
+}
+
+void MatmulOp::inferBroadcastDimensions(
+    AffineMap explicitMap, AffineMap defaultMap,
+    DenseMap<unsigned, std::pair<bool, unsigned>> &broadcastDims) {
+  assert(!explicitMap.isEmpty() && "Expected non empty map");
+  assert(!defaultMap.isEmpty() && "Expected non empty map");
+
+  llvm::SetVector<unsigned> typicalDims, providedDims, broadcastedDims;
+  // Build set of dimensions using default matmul indexing map
+  for (auto expr : defaultMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      typicalDims.insert(dimExpr.getPosition());
+    }
+  }
+
+  // Build set of dimensions from explicitly provided indexing map
+  for (auto expr : explicitMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      providedDims.insert(dimExpr.getPosition());
+    }
+  }
+
+  // Compute set difference to get broadcast dimensions
+  broadcastedDims = llvm::set_difference(typicalDims, providedDims);
+
+  // Update broadcastDims map
+  for (unsigned i = 0; i < typicalDims.size(); i++) {
+    broadcastDims[i] = {false, 0};
+    if (!providedDims.count(typicalDims[i])) {
+      broadcastDims[i] = {true, typicalDims[i]};
+    }
+  }
+}
+
+ShapedType MatmulOp::constructBroadcastedType(unsigned opIndex) {
+  assert(opIndex < 2 && "Operand index out of range");
+  DenseMap<unsigned, ShapedType> allTypes;
+  allTypes[0] = cast<ShapedType>(this->getOperand(0).getType());
+  allTypes[1] = cast<ShapedType>(this->getOperand(1).getType());
+  allTypes[2] = cast<ShapedType>(this->getOperand(2).getType());
+
+  ShapedType outputType = allTypes[2];
+  ShapedType inputType = allTypes[opIndex];
+  SmallVector<AffineMap, 3> defaultIndexingMaps;
+  DenseMap<unsigned, std::pair<bool, unsigned>> broadcastDims;
+  SmallVector<AffineMap, 3> indexingMaps = this->getIndexingMapsArray();
+
+  getDefaultIndexingMaps(defaultIndexingMaps);
+  inferBroadcastDimensions(indexingMaps[opIndex], defaultIndexingMaps[opIndex],
+                           broadcastDims);
+
+  AffineMap inputMap = indexingMaps[opIndex];
+  AffineMap defaultInputMap = defaultIndexingMaps[opIndex];
+  // Initialize new shape for input operand requiring broadcast.
+  unsigned numDims = outputType.getRank();
+  SmallVector<int64_t, 4> newShape(numDims, ShapedType::kDynamic);
+
+  DenseMap<unsigned, unsigned> dimSizeMap;
+  inferDimensionSizes(allTypes, indexingMaps, dimSizeMap);
+
+  // Fill in the known dimensions using defaultIndexingMap
+  unsigned inputDimIdx = 0;
+  for (unsigned i = 0; i < defaultInputMap.getNumResults(); ++i) {
+    if (!broadcastDims[i].first) {
+      if (auto dimExpr = dyn_cast<AffineDimExpr>(defaultInputMap.getResult(i)))
+        newShape[i] = dimSizeMap[dimExpr.getPosition()];
+    }
+  }
+
+  // Fill in the broadcast dimension.
+  for (unsigned i = 0; i < broadcastDims.size(); i++) {
+    if (broadcastDims[i].first) {
+      newShape[i] = dimSizeMap[broadcastDims[i].second];
+    }
+  }
+
+  // Create the new ShapedType
+  if (auto tensorType = dyn_cast<RankedTensorType>(inputType)) {
+    return RankedTensorType::get(newShape, tensorType.getElementType());
+  } else if (auto memrefType = dyn_cast<MemRefType>(inputType)) {
+    return MemRefType::get(newShape, memrefType.getElementType(),
+                           MemRefLayoutAttrInterface(),
+                           memrefType.getMemorySpace());
+  } else {
+    llvm::errs() << "Error: Unsupported ShapedType\n";
+    return ShapedType();
+  }
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+                                MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+
+  SmallVector<Attribute, 3> indexingMaps;
+  getDefaultIndexingMap(getContext(), indexingMaps);
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index e4a6ec7487bb2f..d5e79b4d3cb6dd 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
     O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
 
 
-@linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 1e8f1435ca0fa5..365c65d9018d5d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -29,6 +29,92 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>,
 
 // -----
 
+func.func @matmul_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_bcast_a(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<3xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<3xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_a_b(%arg0: memref<3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_b(
+// CHECK-SAME:                                %[[VAL_0:.*]]: memref<3xf32>,
+// CHECK-SAME:                                %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_b(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
 func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -891,3 +977,86 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
 
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_explicit(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c481a723c5623c..e27772c2ec2622 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -361,6 +361,100 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 
 // -----
 
+func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+                      outs(%arg2 :memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}}
+  %0 = linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}}
+  %0 = linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_bcast_a(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{input rank plus added dimensions does not match init rank. input rank: 2, dimensions size: 1, init rank: 2}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}}
+  linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>)
+                        indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+  return
+}
+
+// -----
+
 func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
   // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..c527aed2140979 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,227 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
 
 // -----
 
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a_dim1
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_b(%arg0: memref<3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_b(
+// CHECK-SAME:                                %[[VAL_0:.*]]: memref<3xf32>,
+// CHECK-SAME:                                %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+// NOTE(review): maps and operand types are identical to @matmul_bcast_b above; confirm whether a (d1) broadcast of B was intended here.
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b_dim1
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_transpose_b(%arg0: memref<3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_transpose_b(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<3xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_b_transpose_a(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<7xf32>,
+// CHECK-SAME:                                          %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
 // CHECK-LABEL: func @matmul_transpose_b
 //       CHECK:   linalg.matmul_transpose_b
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 3bfbcf7d7f7c81..72045a07b2da80 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
 
     print(module)
 
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
-@run
-def testNamedStructuredOpGenericForm():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def named_form(lhs, rhs):
-                init_result = tensor.empty([4, 8], f32)
-                #      CHECK: "linalg.matmul"(%{{.*}})
-                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-                # CHECK-SAME:    operandSegmentSizes = array<i32: 2, 1>
-                # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-                # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-                return linalg.matmul(lhs, rhs, outs=[init_result])
-
-    module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
-@run
-def testNamedStructuredAsGenericOp():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def generic_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.generic
-                return linalg.matmul(
-                    lhs, rhs, outs=[init_result.result], emit_generic=True
-                )
-
-    print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
-@run
-def testOpResultFromOtherOp():
-    with Context(), Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def pass_an_op_directly(arg0, arg1):
-                one = arith.ConstantOp(F32Type.get(), 1.0)
-                # CHECK: %[[LHS:.*]] = linalg.fill
-                lhs = linalg.fill(one, outs=[arg0])
-                # CHECK: %[[RHS:.*]] = linalg.fill
-                rhs = linalg.fill(one, outs=[arg1])
-                # CHECK: %[[INIT:.*]] = tensor.empty
-                init = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.matmul
-                # CHECK: ins(%[[LHS]], %[[RHS]]
-                # CHECK: outs(%[[INIT]]
-                return linalg.matmul(lhs, rhs, outs=init)
-
-    print(module)
-
-
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():



More information about the Mlir-commits mailing list