[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Mon Aug 19 07:17:21 PDT 2024


https://github.com/shahidact created https://github.com/llvm/llvm-project/pull/104783

The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op to include per-operand transpose semantics, while also laying out a way to move op definitions from OpDSL to TableGen. Hence, it is implemented in TableGen. The transpose semantics are as follows.

By default 'linalg.matmul' behavior will remain as is. Transpose semantics can be applied per input operand by specifying the optional permutation attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd input) for each operand explicitly as needed. By default, no transpose is mandated for any of the input operands.

    Example:
    ```
    %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
              permutationA = [1, 0]
              permutationB = [0, 1]
    ```

>From c08e4110903ea4cb2ba5106e5dd373f92fe2030c Mon Sep 17 00:00:00 2001
From: mshahid <md.asghar.ahmad.shahid at intel.com>
Date: Thu, 8 Aug 2024 07:52:33 -0700
Subject: [PATCH] [mlir][linalg] Introduce transpose semantic to
 'linalg.matmul'.

The main goal of this patch is to extend the semantics of the 'linalg.matmul'
named op to include per-operand transpose semantics, while also laying out
a way to move op definitions from OpDSL to TableGen. Hence, it is
implemented in TableGen. The transpose semantics are as follows.

By default 'linalg.matmul' behavior will remain as is. Transpose semantics
can be applied per input operand by specifying the optional permutation
attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd
input) for each operand explicitly as needed. By default, no transpose is
mandated for any of the input operands.

    Example:
    ```
    %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2: memref<3x7xf32>)
              permutationA = [1, 0]
              permutationB = [0, 1]
    ```
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   |  72 ---------
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   1 +
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 100 ++++++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 152 +++++++++++++++++-
 .../linalg/opdsl/ops/core_named_ops.py        |  17 --
 .../Dialect/Linalg/generalize-named-ops.mlir  |  62 +++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       |  33 ++++
 mlir/test/python/dialects/linalg/ops.py       |  75 ---------
 8 files changed, 347 insertions(+), 165 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..8e2e827a12cc4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078e..4ca7c5f0f1f676 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -271,4 +271,5 @@ def Linalg_WinogradOutputTransformOp :
   let hasVerifier = 1;
 }
 
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..5e6940b42db976 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -534,6 +534,106 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", !listconcat([AttrSizedOperandSegments],
+  /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+    
+  let summary = [{Performs a matrix multiplication of two 2D inputs without transpose.}];
+  let description = [{Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Per input operand transpose can be performed by specifying the required permutation
+    attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd input) for
+    each operand explicitly. By default, no transpose is mandated for each input operand.
+
+    Example:
+    ```
+    %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+              outs(%arg2: memref<3x7xf32>)
+              permutationA = [1, 0]
+              permutationB = [0, 1]
+     ```
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationA,
+      ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationB,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs, "DenseI64ArrayAttr":$permutationA, "DenseI64ArrayAttr":$permutationB, "Attribute":$cast,
+       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("permutationA", permutationA);
+        $_state.addAttribute("permutationB", permutationB);
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      // Auto-generated.
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      ArrayAttr getIndexingMaps();
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 99b625d99fec2e..bef19e737ca6c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -303,6 +303,26 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  if (parser.parseOptionalKeyword("permutationA").succeeded()) {
+    if (parser.parseEqual())
+      return failure();
+
+    result.attributes.set("permutationA",
+                          DenseI64ArrayAttr::parse(parser, Type{}));
+  }
+
+  if (parser.parseOptionalKeyword("permutationB").succeeded()) {
+    if (parser.parseEqual())
+      return failure();
+
+    result.attributes.set("permutationB",
+                          DenseI64ArrayAttr::parse(parser, Type{}));
+  }
+
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -334,7 +354,8 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
       /*elidedAttrs=*/{"operandSegmentSizes",
                        // See generated code in
                        // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                       "linalg.memoized_indexing_maps", "permutationA",
+                       "permutationB"});
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -2980,3 +3001,132 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+                                          utils::IteratorType::parallel,
+                                          utils::IteratorType::reduction};
+}
+
+ArrayAttr MatmulOp::getIndexingMaps() {
+  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+  if (cached)
+    return cached;
+
+  MLIRContext *context = getContext();
+  SmallVector<AffineMap> maps;
+
+  unsigned numResults;
+  SmallVector<AffineExpr, 3> dimReplacements;
+  AffineMap originalMap =
+      llvm::cast<AffineMapAttr>(
+          mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d2)>", context))
+          .getValue();
+  numResults = originalMap.getNumResults();
+  for (unsigned i = 0; i < numResults; i++) {
+    AffineExpr expr = originalMap.getResult(getPermutationA()[i]);
+    dimReplacements.push_back(expr);
+  }
+
+  AffineMap newMap =
+      AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+                     dimReplacements, context);
+  maps.push_back(newMap);
+  maps.back() =
+      simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+  originalMap =
+      llvm::cast<AffineMapAttr>(
+          mlir::parseAttribute("affine_map<(d0, d1, d2)->(d2, d1)>", context))
+          .getValue();
+  numResults = originalMap.getNumResults();
+  dimReplacements.clear();
+  for (unsigned i = 0; i < numResults; i++) {
+    AffineExpr expr = originalMap.getResult(getPermutationB()[i]);
+    dimReplacements.push_back(expr);
+  }
+
+  newMap = AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+                          dimReplacements, context);
+  maps.push_back(newMap);
+  maps.back() =
+      simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+  maps.push_back(
+      llvm::cast<AffineMapAttr>(
+          mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d1)>", context))
+          .getValue());
+  maps.back() =
+      simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+  cached = Builder(context).getAffineMapArrayAttr(maps);
+  getOperation()->setAttr(memoizeAttr, cached);
+  return cached;
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+  return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                             ArrayRef<NamedAttribute> attrs) {
+  assert(3 > 0 && block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 (>=0) args");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
+  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+  Value value4 =
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  yields.push_back(value4);
+  helper.yieldOutputs(yields);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+                                MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+  if (!getPermutationA().empty())
+    printDenseI64ArrayAttr(p, getPermutationAAttrName(), getPermutationA());
+
+  if (!getPermutationB().empty())
+    printDenseI64ArrayAttr(p, getPermutationBAttrName(), getPermutationB());
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..7ef5de12de5ad3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
     O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
 
 
- at linalg_structured_op
-def matmul(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..7c95d9592481e6 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -864,3 +864,65 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
 
   return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_explicit(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                  %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0]
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:   func.func @matmul_transpose_b_explicit(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationB = [1, 0]
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0] permutationB = [1, 0]
+  return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..e702125667acc7 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,39 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
 
 // -----
 
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0]
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_b_explicit
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationB = [1, 0]
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_a_b_explicit
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<7x5xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0] permutationB = [1, 0]
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @matmul_transpose_b
 //       CHECK:   linalg.matmul_transpose_b
 //  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b147551c2e73db..72045a07b2da80 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,81 +84,6 @@ def named_form(lhs, rhs):
 
     print(module)
 
-
-# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
- at run
-def testNamedStructuredOpGenericForm():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def named_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
-                #      CHECK: "linalg.matmul"(%{{.*}})
-                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-                # CHECK-SAME:    operandSegmentSizes = array<i32: 2, 1>
-                # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-                # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
-                # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-                return linalg.matmul(lhs, rhs, outs=[init_result.result])
-
-    module.operation.print(print_generic_op_form=True)
-
-
-# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
- at run
-def testNamedStructuredAsGenericOp():
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def generic_form(lhs, rhs):
-                init_result = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.generic
-                return linalg.matmul(
-                    lhs, rhs, outs=[init_result.result], emit_generic=True
-                )
-
-    print(module)
-
-
-# CHECK-LABEL: TEST: testOpResultFromOtherOp
- at run
-def testOpResultFromOtherOp():
-    with Context(), Location.unknown():
-        module = Module.create()
-        f32 = F32Type.get()
-        with InsertionPoint(module.body):
-
-            @func.FuncOp.from_py_func(
-                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
-            )
-            def pass_an_op_directly(arg0, arg1):
-                one = arith.ConstantOp(F32Type.get(), 1.0)
-                # CHECK: %[[LHS:.*]] = linalg.fill
-                lhs = linalg.fill(one, outs=[arg0])
-                # CHECK: %[[RHS:.*]] = linalg.fill
-                rhs = linalg.fill(one, outs=[arg1])
-                # CHECK: %[[INIT:.*]] = tensor.empty
-                init = tensor.EmptyOp([4, 8], f32)
-                # CHECK: linalg.matmul
-                # CHECK: ins(%[[LHS]], %[[RHS]]
-                # CHECK: outs(%[[INIT]]
-                return linalg.matmul(lhs, rhs, outs=init)
-
-    print(module)
-
-
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():



More information about the Mlir-commits mailing list