[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Tue Oct 8 07:33:35 PDT 2024


https://github.com/shahidact updated https://github.com/llvm/llvm-project/pull/104783

>From 28f63eb6093400a7aede88f0058a7344b055e355 Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Wed, 21 Aug 2024 05:34:27 -0700
Subject: [PATCH] [mlir][linalg] Extends 'linalg.matmul' named op to define
 broadcast and transpose semantic.

Goals:
1. To add syntax to matmul without changing any of the existing syntax
   expectations for current usage. matmul is still just matmul.

2. To expose broadcast and transpose semantics on the three matmul
   variations: matmul, batch_matmul and batch_reduce_matmul.

Scope of this patch:
To expose broadcast and transpose semantics on the 'matmul'.

The broadcast and transpose semantic is as follows:

By default, 'linalg.matmul' behavior remains as is. Broadcast and
transpose semantics can be applied by specifying the explicit attribute
'indexing_maps' as shown below. This is a list attribute, so the list
must include all the maps if specified.

Example Transpose:
linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
               ]
               ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
               outs(%arg2: memref<3x7xf32>)

Example Broadcast:
linalg.matmul indexing_maps = [
               affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
               affine_map<(d0, d1, d2) -> (d2, d1)>,
               affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
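
One way to read the two examples above is as a loop nest over (d0, d1, d2)
in which each operand is indexed by the dims its map selects. The following
is a minimal Python sketch of that reading, not code from this patch. The
names `load`, `zeros` and `matmul_with_maps`, and the tuple encoding of maps
(for instance (2, 0) standing for (d0, d1, d2) -> (d2, d0)), are hypothetical
and chosen only for illustration.

```python
from itertools import product


def load(operand, dims, point):
    """Read `operand` at the indices that `dims` selects from loop `point`."""
    for d in dims:
        operand = operand[point[d]]
    return operand


def zeros(rows, cols):
    """Build a rows x cols nested list filled with zeros."""
    return [[0] * cols for _ in range(rows)]


def matmul_with_maps(a, b, c, maps, bounds):
    """Accumulate into `c` in place, indexing each operand through its map.

    `maps` holds one tuple per operand (A, B, C); `bounds` is (M, N, K).
    Assumes the output map always has two results, as in the examples.
    """
    map_a, map_b, map_c = maps
    for point in product(*(range(n) for n in bounds)):  # point = (d0, d1, d2)
        i, j = (point[d] for d in map_c)
        c[i][j] += load(a, map_a, point) * load(b, map_b, point)
    return c


# Small operands: M = 2, N = 2, K = 3.
A = [[1, 2, 3], [4, 5, 6]]      # M x K
A_T = [[1, 4], [2, 5], [3, 6]]  # K x M, i.e. A stored transposed
V = [1, 2, 3]                   # length-K vector, broadcast along d0
B = [[1, 0], [0, 1], [1, 1]]    # K x N
BOUNDS = (2, 2, 3)

default = matmul_with_maps(A, B, zeros(2, 2), [(0, 2), (2, 1), (0, 1)], BOUNDS)
transposed = matmul_with_maps(A_T, B, zeros(2, 2), [(2, 0), (2, 1), (0, 1)], BOUNDS)
broadcast = matmul_with_maps(V, B, zeros(2, 2), [(2,), (2, 1), (0, 1)], BOUNDS)
```

In this sketch the transposed form gives the same result as the default
form on the untransposed data, and the broadcast form reuses the same
length-K vector for every row of the output.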
---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  10 +
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  72 ------
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 134 ++++++++++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  88 ++++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 194 +++++++++++++-
 .../Linalg/Transforms/TransposeMatmul.cpp     |   6 +
 .../Linalg/Transforms/Vectorization.cpp       |   5 +
 .../NVGPU/TransformOps/NVGPUTransformOps.cpp  |   5 +
 .../linalg/opdsl/ops/core_named_ops.py        |  17 --
 .../Dialect/Linalg/generalize-named-ops.mlir  | 111 ++++++++
 mlir/test/Dialect/Linalg/invalid.mlir         | 159 ++++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 243 ++++++++++++++++++
 mlir/test/python/dialects/linalg/ops.py       |  75 ------
 13 files changed, 942 insertions(+), 177 deletions(-)
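
As a rough summary of the per-input checks this patch adds in
LinalgInterfaces.cpp, the sketch below models them in Python. It is a
simplification that assumes every map result is a plain dim expression.
`verify_input_map` and the tuple encoding of maps are hypothetical; only the
error strings are taken from the patch. Unlike the C++ code, which asserts on
a broadcast map without exactly one result, this sketch reports such a map as
an invalid broadcast.

```python
def verify_input_map(explicit, default, operand_rank, num_dims=3):
    """Return an error string, or None if the input map is accepted.

    Maps are tuples of loop-dim positions, e.g. (2, 0) for (d2, d0).
    """
    # Every result of the explicit map must also appear in the default map.
    if not set(explicit) <= set(default):
        return "Unexpected dim expression in map result."
    # Fewer results than the default map means a broadcast was requested.
    if len(explicit) < len(default):
        if operand_rank != len(explicit):
            return ("Mismatch between input rank and number of "
                    "results of given map.")
        # A broadcast input must be indexed by the last loop dim (d2) only.
        if len(explicit) != 1 or explicit[0] != num_dims - 1:
            return "Invalid broadcast requested, should be (d2)."
    return None


DEFAULT_LHS = (0, 2)  # (d0, d1, d2) -> (d0, d2)
DEFAULT_RHS = (2, 1)  # (d0, d1, d2) -> (d2, d1)

transpose_ok = verify_input_map((2, 0), DEFAULT_LHS, operand_rank=2)
broadcast_ok = verify_input_map((2,), DEFAULT_LHS, operand_rank=1)
broadcast_bad = verify_input_map((0,), DEFAULT_LHS, operand_rank=1)
```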

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..e80dbb2afb9ef7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -684,6 +684,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..97b90333e2b200 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 31f29139247267..972143cde996bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -535,6 +535,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+    
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+    }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so the list must include all
+    the maps if specified.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical matmul indexing characteristic.
+      SmallVector<AffineMap> getDefaultIndexingMaps();
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+      /// Check if the op has broadcast and/or transpose semantic. Returns true if the
+      /// user defined indexing maps are not equal to default map.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 40795879c3026d..53f25df5748705 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,13 +15,20 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1140,9 +1147,66 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
          operandNumber - start;
 }
 
+/// Returns true if every result AffineExpr of \p explicitMap also appears in
+/// the results of \p defaultMap.
+static bool isValidResultDimExprs(AffineMap explicitMap, AffineMap defaultMap) {
+  auto explicitRange = explicitMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if \p explicitMap is broadcasted with respect to
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
+  return explicitMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantics specified by the explicit
+/// indexing map for the MatmulOp \p op for each operand specified by \p
+/// opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(Operation *op,
+                                                  unsigned opIndex) {
+  MatmulOp matmulOp = dyn_cast<MatmulOp>(op);
+  assert(matmulOp && "Expected linalg.matmul");
+
+  // No verification for output required.
+  if (opIndex > 1)
+    return success();
+
+  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
+  SmallVector<AffineMap, 3> defaultIndexingMaps =
+      matmulOp.getDefaultIndexingMaps();
+
+  auto opIndexingMap = opIndexingMaps[opIndex];
+  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
+  // Check general validity of indexing map results.
+  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
+    return op->emitOpError() << "Unexpected dim expression in map result.";
+
+  OpOperand &opOperand = matmulOp->getOpOperand(opIndex);
+  auto inputType = cast<ShapedType>(opOperand.get().getType());
+  int64_t inputRank = inputType.getRank();
+  auto explicitNumRes = opIndexingMap.getNumResults();
+
+  // Check if the requested broadcast is valid.
+  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
+    if (inputRank != explicitNumRes)
+      return op->emitOpError() << "Mismatch between input rank and number of "
+                                  "results of given map.";
+    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
+      return op->emitOpError()
+             << "Invalid broadcast requested, should be (d2).";
+    }
+    return success();
+  }
+  return success();
+}
+
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1162,8 +1226,12 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user defined maps. This is required to guard
+  // the error conditions below, which assume default indexing maps.
+  bool isUserDefinedSemantic = linalgOp.hasUserDefinedMaps();
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    unsigned opIndex = opOperand.getOperandNumber();
 
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
@@ -1172,19 +1240,26 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
 
     // Domain must be consistent.
     unsigned numLoops = linalgOp.getNumLoops();
-    if (indexingMap.getNumDims() != numLoops)
+    if (!isUserDefinedSemantic && indexingMap.getNumDims() != numLoops)
       return op->emitOpError("expected indexing_map #")
              << opOperand.getOperandNumber() << " to have " << numLoops
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
-    if (indexingMap.getNumResults() != rank)
+
+    if (!isUserDefinedSemantic && indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
-  }
 
+    // Try to verify broadcast/transpose semantic requested through explicit
+    // user defined indexing map.
+    if (isUserDefinedSemantic) {
+      if (failed(verifyExtendedMatmulSemantic(op, opIndex)))
+        return failure();
+    }
+  }
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1194,9 +1269,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..6037612e4de949 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,16 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +154,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create a typical indexing map for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap, 3> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +194,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps, for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +339,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -335,7 +412,7 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
       /*elidedAttrs=*/{"operandSegmentSizes",
                        // See generated code in
                        // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                       "linalg.memoized_indexing_maps", "indexing_maps"});
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3459,108 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+                                          utils::IteratorType::parallel,
+                                          utils::IteratorType::reduction};
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+/// Check if the op has broadcast and/or transpose semantic. Returns true if the
+/// user defined indexing maps are not equal to default map.
+bool MatmulOp::hasUserDefinedMaps() {
+  SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
+  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+  return defaultMaps != explicitMaps;
+}
+
+/// Implements the block region builder for the MatmulOp. This is called by
+/// 'fillStructuredOpRegion'.
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(3 > 0 && block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+/// Returns a list of AffineMap with the typical matmul indexing characteristic.
+SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
+  MLIRContext *context = this->getContext();
+  return getDefaultIndexingMapsForMatmul(context);
+}
+
+/// Returns true if the given broadcast map \p bcastMap is valid for this op.
+bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
+  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
+  AffineExpr exp = bcastMap.getResult(0);
+  // Invalid map if the common dimension of matmul not found.
+  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+                                MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+
+  SmallVector<Attribute, 3> indexingMaps =
+      getDefaultIndexingMapAttr(getContext());
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability MatmulOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index aa0052ce47fa7b..707028c9e27778 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -31,6 +31,12 @@ using namespace mlir::linalg;
 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                      linalg::MatmulOp matmulOp,
                                                      bool transposeLHS) {
+  // Bail out on matmul ops with extended semantics (user-defined indexing
+  // maps); this transform does not support them.
+  if (matmulOp.hasUserDefinedMaps())
+    return rewriter.notifyMatchFailure(
+        matmulOp, "only matmul ops with non-extended semantics are supported");
+
   if (!bufferization::hasTensorSemantics(matmulOp))
     return rewriter.notifyMatchFailure(
         matmulOp, "only matmul ops with tensors are supported");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 09c6b2683b4388..e3f010d9cfb20b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2071,6 +2071,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
       return failure();
   }
 
+  // Bail out on matmul ops with extended semantics (user-defined indexing
+  // maps); this transform does not support them.
+  if (linalgOp.hasUserDefinedMaps())
+    return failure();
+
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 0c2275bbc4b224..a08c383417a0b4 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -821,6 +821,11 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
   bool fail = true;
   // TODO: more robust detection of matmulOp, with transposes etc.
   if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
+    // Bail out on matmul ops with extended semantics (user-defined indexing
+    // maps); this transform does not support them.
+    if (linalgOp.hasUserDefinedMaps())
+      return emitSilenceableError()
+             << "only matmul ops with non-extended semantics are supported";
     Location loc = linalgOp.getLoc();
     // TODO: more robust computation of laneId, for now assume a single warp.
     Value laneId = rewriter.create<gpu::ThreadIdOp>(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index e4a6ec7487bb2f..d5e79b4d3cb6dd 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
     O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
 
 
-@linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 1e8f1435ca0fa5..aba26c35931fd3 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -29,6 +29,34 @@ func.func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>,
 
 // -----
 
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_bcast_a(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
 func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -891,3 +919,86 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
 
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_explicit(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c481a723c5623c..6ee8712a302426 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -361,6 +361,165 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 
 // -----
 
+func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.matmul indexing_maps = [
+                       ,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+                      outs(%arg2 :memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_matmul_dim_a(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
+  // expected-error @+1 {{Unexpected dim expression in map result}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_matmul_dim_b(%arg0: memref<5x5xf32>, %arg1: memref<5x5xf32>, %arg2: memref<5x5xf32>) {
+  // expected-error @+1 {{Unexpected dim expression in map result}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2: memref<5x5xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_transpose_a_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 1}}
+  %0 = linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_transpose_b_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) -> tensor<4x64xf32> {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #1 to be 1, but found 64}}
+  %0 = linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>)
+                      outs(%init : tensor<4x64xf32>) -> tensor<4x64xf32>
+  return %0: tensor<4x64xf32>
+}
+
+// -----
+
+func.func @invalid_bcast_a(%arg0: memref<3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op Invalid broadcast requested, should be (d2)}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op Mismatch between input rank and number of results of given map}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op Mismatch between input rank and number of results of given map}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<7xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 5, but found 7}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_matmul_bcast_b_transpose_a_wrong_dim(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op Unexpected dim expression in map result.}}
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: tensor<1x64xf32>, %init: tensor<4x64xf32>) {
+  // expected-error @+2 {{custom op 'indexing_maps' is unknown (tried 'func.indexing_maps' as well)}}
+  linalg.matmul ins(%lhs, %rhs : tensor<4x1xf32>, tensor<1x64xf32>) outs(%init : tensor<4x64xf32>)
+                        indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+  return
+}
+
+// -----
+
 func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
   // expected-error @+1 {{inferred input/output operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
   linalg.conv_2d_nhwc_hwcf
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..65c18de8424771 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,249 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
 
 // -----
 
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// -----
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+                      
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                      ]
+                      ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                      outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_a(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_dim1(%arg0: memref<5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_a_dim1
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @matmul_bcast_a_b(%arg0: memref<5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_b(
+// CHECK-SAME:                                %[[VAL_0:.*]]: memref<5xf32>, %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_b_dim1(%arg0: memref<3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @matmul_bcast_b_dim1
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+
+// -----
+
+func.func @dynamic_matmul_bcast_a(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @dynamic_matmul_bcast_a(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: memref<?xf32>,
+// CHECK-SAME:                                      %[[VAL_1:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                      %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<?xf32>, memref<?x?xf32>) outs(%[[VAL_2]] : memref<?x?xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_a_transpose_b(%arg0: memref<5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_a_transpose_b(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @matmul_bcast_b_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>,
+                       affine_map<(d0, d1, d2) -> (d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_bcast_b_transpose_a(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                          %[[VAL_2:.*]]: memref<3x7xf32>) {
+// CHECK:           linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
 // CHECK-LABEL: func @matmul_transpose_b
 //       CHECK:   linalg.matmul_transpose_b
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 3bfbcf7d7f7c81..72045a07b2da80 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
 
     print(module)
 
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
-@run
-def testNamedStructuredOpGenericForm():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def named_form(lhs, rhs):
-                init_result = tensor.empty([4, 8], f32)
-                #      CHECK: "linalg.matmul"(%{{.*}})
-                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-                # CHECK-SAME:    operandSegmentSizes = array<i32: 2, 1>
-                # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-                # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-                return linalg.matmul(lhs, rhs, outs=[init_result])
-
-    module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
-@run
-def testNamedStructuredAsGenericOp():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def generic_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.generic
-                return linalg.matmul(
-                    lhs, rhs, outs=[init_result.result], emit_generic=True
-                )
-
-    print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
-@run
-def testOpResultFromOtherOp():
-    with Context(), Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def pass_an_op_directly(arg0, arg1):
-                one = arith.ConstantOp(F32Type.get(), 1.0)
-                # CHECK: %[[LHS:.*]] = linalg.fill
-                lhs = linalg.fill(one, outs=[arg0])
-                # CHECK: %[[RHS:.*]] = linalg.fill
-                rhs = linalg.fill(one, outs=[arg1])
-                # CHECK: %[[INIT:.*]] = tensor.empty
-                init = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.matmul
-                # CHECK: ins(%[[LHS]], %[[RHS]]
-                # CHECK: outs(%[[INIT]]
-                return linalg.matmul(lhs, rhs, outs=init)
-
-    print(module)
-
-
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():
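
For readers following the tests above, the intended effect of the explicit 'indexing_maps' can be sketched as a plain loop nest. This is only an illustrative reference model in Python, not code from the patch: the helper names `matmul_with_maps` and `_load` are hypothetical, and each affine map is encoded as a tuple of loop-dimension positions, e.g. `(2, 0)` stands for `affine_map<(d0, d1, d2) -> (d2, d0)>`.

```python
import itertools


def _load(buf, idx):
    """Index a nested list with a sequence of indices (hypothetical helper)."""
    for i in idx:
        buf = buf[i]
    return buf


def matmul_with_maps(a, b, c, maps, dims):
    """Accumulate c[map_c(d)] += a[map_a(d)] * b[map_b(d)] for every d = (d0, d1, d2).

    maps: (map_a, map_b, map_c), each a tuple of positions into d.
    dims: loop bounds (M, N, K) for (d0, d1, d2).
    """
    map_a, map_b, map_c = maps
    for d in itertools.product(*(range(n) for n in dims)):
        ia = [d[k] for k in map_a]
        ib = [d[k] for k in map_b]
        m, n = (d[k] for k in map_c)
        c[m][n] += _load(a, ia) * _load(b, ib)
    return c


# Transpose A: A is stored as K x M and read through (d2, d0).
a_t = [[1, 2], [3, 4], [5, 6]]  # K=3, M=2
b = [[1, 0], [0, 1], [1, 1]]    # K=3, N=2
transposed = matmul_with_maps(
    a_t, b, [[0, 0], [0, 0]], maps=((2, 0), (2, 1), (0, 1)), dims=(2, 2, 3)
)

# Broadcast A: A is a 1-D buffer of length K read through (d2),
# so the same values are reused for every row d0 of the output.
bcast = matmul_with_maps(
    [1, 2, 3], b, [[0, 0], [0, 0]], maps=((2,), (2, 1), (0, 1)), dims=(2, 2, 3)
)
```

Under this model the broadcast case produces identical output rows, because the A operand does not depend on d0. The sketch says nothing about how the verifier or the lowering in the patch is implemented.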


