[Mlir-commits] [mlir] [mlir][linalg] fix linalg.batch_reduce_matmul auto cast (PR #102585)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 9 01:44:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: zhicong zhong (zhczhong)
<details>
<summary>Changes</summary>
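Fix the auto-cast in `linalg.batch_reduce_matmul`: the computation changes from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`. Previously, `A` in its original element type was multiplied by the already-cast `B`, and only the product was cast. Now both operands are cast to the accumulator element type before they are multiplied.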
---
Full diff: https://github.com/llvm/llvm-project/pull/102585.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+13-14)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+1-2)
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+27)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..249b0f56477cc8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,5 +1,3 @@
-### AUTOGENERATED from core_named_ops.py
-### To regenerate, run: bin/update_core_linalg_named_ops.sh
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: copy
@@ -1908,25 +1906,25 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: C
- !ScalarExpression
scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
+ kind: binary
+ fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
- kind: binary
- fn_name: mul
+ kind: type
+ fn_name: cast_signed
+ type_var: U
operands:
- !ScalarExpression
scalar_arg: A
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
- !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
+ scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
@@ -6509,3 +6507,4 @@ structured_op: !LinalgStructuredOpConfig
scalar_const: '2.3283063999999999E-10 : f64'
- !ScalarExpression
scalar_arg: min
+
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..afb68b471d347a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -593,8 +593,7 @@ def batch_reduce_matmul(
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast_signed(
- U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
- )
+ U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
@linalg_structured_op
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..1e8f1435ca0fa5 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -329,6 +329,33 @@ func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
+// -----
+
+func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
+ linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+ outs(%out: memref<8x8xf32>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @generalize_batch_reduce_gemm_bf16
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
+// CHECK: %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
+// CHECK: %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
+
// -----
// CHECK-LABEL: generalize_linalg_map
``````````
</details>
https://github.com/llvm/llvm-project/pull/102585
More information about the Mlir-commits mailing list