[Mlir-commits] [mlir] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (PR #104783)
llvmlistbot at llvm.org
Mon Aug 19 07:18:08 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Md Asghar Ahmad Shahid (shahidact)
<details>
<summary>Changes</summary>
The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op to include per-operand transpose semantics, while also laying out a way to move op definitions from OpDSL to TableGen. Hence, the op is implemented in TableGen. The transpose semantics are as follows.
By default, the behavior of 'linalg.matmul' remains as is. A transpose can be applied per input operand by explicitly specifying the optional permutation attributes ('permutationA' for the 1st input and 'permutationB' for the 2nd input) as needed. By default, no transpose is applied to either input operand.
Example:
```
%val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
outs(%arg2: memref<3x7xf32>)
permutationA = [1, 0]
permutationB = [0, 1]
```
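For reference, after generalization the example above corresponds to a `linalg.generic` whose first indexing map is permuted according to `permutationA`. The sketch below is only an illustration (operand names mirror the example, and the indexing maps match those checked in the updated generalize-named-ops.mlir test), not the verbatim output of the generalization pass:
```
#mapA = affine_map<(d0, d1, d2) -> (d2, d0)> // A indexed with permutationA = [1, 0]
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)> // B indexed as usual (permutationB = [0, 1])
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>

linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
                iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
    outs(%arg2 : memref<3x7xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  %0 = arith.mulf %a, %b : f32
  %1 = arith.addf %c, %0 : f32
  linalg.yield %1 : f32
}
```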
---
Patch is 23.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104783.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-72)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+1)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+100)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+151-1)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-17)
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+62)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+33)
- (modified) mlir/test/python/dialects/linalg/ops.py (-75)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..8e2e827a12cc4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul
- cpp_class_name: MatmulOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078e..4ca7c5f0f1f676 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -271,4 +271,5 @@ def Linalg_WinogradOutputTransformOp :
let hasVerifier = 1;
}
+
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..5e6940b42db976 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -534,6 +534,106 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", !listconcat([AttrSizedOperandSegments],
+ /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+ let summary = [{Performs a matrix multiplication of two 2D inputs without transpose.}];
+ let description = [{Numeric casting is performed on the operands to the inner multiply,
+ promoting them to the same data type as the accumulator/output.
+
+ Per input operand transpose can be performed by specifying the required permutation
+ attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd input) for
+ each operand explicitly. By default, no transpose is mandated for each input operand.
+
+ Example:
+ ```
+ %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ permutationA = [1, 0]
+ permutationB = [0, 1]
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationA,
+ ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationB,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion();
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, "DenseI64ArrayAttr":$permutationA, "DenseI64ArrayAttr":$permutationB, "Attribute":$cast,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("permutationA", permutationA);
+ $_state.addAttribute("permutationB", permutationB);
+ $_state.addAttribute("cast", cast);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Auto-generated.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ ArrayAttr getIndexingMaps();
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 99b625d99fec2e..bef19e737ca6c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -303,6 +303,26 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
+ if (parser.parseOptionalKeyword("permutationA").succeeded()) {
+ if (parser.parseEqual())
+ return failure();
+
+ result.attributes.set("permutationA",
+ DenseI64ArrayAttr::parse(parser, Type{}));
+ }
+
+ if (parser.parseOptionalKeyword("permutationB").succeeded()) {
+ if (parser.parseEqual())
+ return failure();
+
+ result.attributes.set("permutationB",
+ DenseI64ArrayAttr::parse(parser, Type{}));
+ }
+
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
// TODO: consider merging results parsing into region parsing.
// Need to wait for declarative assembly resolution to decide.
SmallVector<Type, 1> outputTensorsTypes;
@@ -334,7 +354,8 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
/*elidedAttrs=*/{"operandSegmentSizes",
// See generated code in
// LinalgNamedStructuredOps.yamlgen.cpp.inc
- "linalg.memoized_indexing_maps"});
+ "linalg.memoized_indexing_maps", "permutationA",
+ "permutationB"});
// Printing is shared with generic ops, except for the region and
// attributes.
@@ -2980,3 +3001,132 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+ utils::IteratorType::parallel,
+ utils::IteratorType::reduction};
+}
+
+ArrayAttr MatmulOp::getIndexingMaps() {
+ static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+ ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+ if (cached)
+ return cached;
+
+ MLIRContext *context = getContext();
+ SmallVector<AffineMap> maps;
+
+ unsigned numResults;
+ SmallVector<AffineExpr, 3> dimReplacements;
+ AffineMap originalMap =
+ llvm::cast<AffineMapAttr>(
+ mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d2)>", context))
+ .getValue();
+ numResults = originalMap.getNumResults();
+ for (unsigned i = 0; i < numResults; i++) {
+ AffineExpr expr = originalMap.getResult(getPermutationA()[i]);
+ dimReplacements.push_back(expr);
+ }
+
+ AffineMap newMap =
+ AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+ dimReplacements, context);
+ maps.push_back(newMap);
+ maps.back() =
+ simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+ originalMap =
+ llvm::cast<AffineMapAttr>(
+ mlir::parseAttribute("affine_map<(d0, d1, d2)->(d2, d1)>", context))
+ .getValue();
+ numResults = originalMap.getNumResults();
+ dimReplacements.clear();
+ for (unsigned i = 0; i < numResults; i++) {
+ AffineExpr expr = originalMap.getResult(getPermutationB()[i]);
+ dimReplacements.push_back(expr);
+ }
+
+ newMap = AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+ dimReplacements, context);
+ maps.push_back(newMap);
+ maps.back() =
+ simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+ maps.push_back(
+ llvm::cast<AffineMapAttr>(
+ mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d1)>", context))
+ .getValue());
+ maps.back() =
+ simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+ cached = Builder(context).getAffineMapArrayAttr(maps);
+ getOperation()->setAttr(memoizeAttr, cached);
+ return cached;
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(3 > 0 && block.getNumArguments() == 3 &&
+ "MatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
+ Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+ MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+ if (!getPermutationA().empty())
+ printDenseI64ArrayAttr(p, getPermutationAAttrName(), getPermutationA());
+
+ if (!getPermutationB().empty())
+ printDenseI64ArrayAttr(p, getPermutationBAttrName(), getPermutationB());
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..7ef5de12de5ad3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
- at linalg_structured_op
-def matmul(
- A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
@linalg_structured_op
def quantized_matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..7c95d9592481e6 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -864,3 +864,65 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationB = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0] permutationB = [1, 0]
+ return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..e702125667acc7 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,39 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
// -----
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_b_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationB = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_a_b_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins...
[truncated]
``````````
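As a usage note (a rough equivalence sketch, assuming the dedicated transpose variant remains available alongside this change), the new attribute expresses the same computation as the existing `linalg.matmul_transpose_b` named op:
```
// Existing dedicated named op:
linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                          outs(%arg2 : memref<3x7xf32>)

// Same computation with the extended 'linalg.matmul' from this patch:
linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
              outs(%arg2 : memref<3x7xf32>)
              permutationB = [1, 0]
```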
</details>
https://github.com/llvm/llvm-project/pull/104783
More information about the Mlir-commits mailing list