[Mlir-commits] [mlir] [mlir][linalg] Fix `linalg.matmul_transpose_a` def. (PR #97690)

llvmlistbot at llvm.org
Thu Jul 4 01:16:42 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Jerry Shih (JerryShih)

<details>
<summary>Changes</summary>

The `matmul_transpose_a` inputs should have shapes `KxM` (for `A`) and `KxN` (for `B`), so that the output `C` is `MxN`.

---
Full diff: https://github.com/llvm/llvm-project/pull/97690.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+1-1) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index fad234a9dcae9..abb79278eddd4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1336,7 +1336,7 @@ structured_op: !LinalgStructuredOpConfig
     name: C
     kind: output_tensor
     type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
   - !LinalgOperandDefConfig
     name: cast
     kind: type_fn_attr
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 43410aaa6af1b..59b3ba914eaab 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -429,8 +429,8 @@ def quantized_matmul(
 
 @linalg_structured_op
 def matmul_transpose_a(
-    A=TensorDef(T1, S.K, S.N),
-    B=TensorDef(T2, S.K, S.M),
+    A=TensorDef(T1, S.K, S.M),
+    B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
     cast=TypeFnAttrDef(default=TypeFn.cast_signed),
 ):

``````````

</details>
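
To illustrate the shape convention the fix establishes, here is a minimal pure-Python sketch of the computation, assuming the op computes `C[m, n] = sum over k of A[k, m] * B[k, n]`, as the `KxM`, `KxN`, and `MxN` operand definitions suggest. The helper name `matmul_transpose_a_ref` is hypothetical and is not part of MLIR or its Python bindings. It only checks the shapes on plain nested lists and ignores the `cast` attribute.

```python
def matmul_transpose_a_ref(a, b):
    """Reference for C = transpose(A) * B on nested lists.

    Assumed shapes (illustrative, not an MLIR API):
      a: K x M, b: K x N, result: M x N.
    """
    k = len(a)
    if k == 0 or len(b) != k:
        raise ValueError("A and B must share the leading (K) dimension")
    m = len(a[0])
    n = len(b[0])
    # The reduction runs over the shared leading dimension K.
    return [
        [sum(a[kk][i] * b[kk][j] for kk in range(k)) for j in range(n)]
        for i in range(m)
    ]


# K = 2, M = 3, N = 4
a = [[1, 2, 3],
     [4, 5, 6]]
b = [[1, 0, 2, 1],
     [0, 1, 1, 2]]
c = matmul_transpose_a_ref(a, b)
print(len(c), len(c[0]))  # dimensions of the result: M, N
```

With `A` given as `KxM` and `B` as `KxN`, the result has `M` rows and `N` columns, which matches the `C=TensorDef(U, S.M, S.N, output=True)` operand in the diff above.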


https://github.com/llvm/llvm-project/pull/97690

