[Mlir-commits] [mlir] 558d7ad - [mlir][linalg] fix linalg.batch_reduce_matmul auto cast (#102585)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Aug 11 23:38:01 PDT 2024


Author: zhicong zhong
Date: 2024-08-12T14:37:57+08:00
New Revision: 558d7adaae4871134a87457bd07e21fdbe001c08

URL: https://github.com/llvm/llvm-project/commit/558d7adaae4871134a87457bd07e21fdbe001c08
DIFF: https://github.com/llvm/llvm-project/commit/558d7adaae4871134a87457bd07e21fdbe001c08.diff

LOG: [mlir][linalg] fix linalg.batch_reduce_matmul auto cast (#102585)

Fix the auto-cast of `linalg.batch_reduce_matmul` from
`cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`,
so that both operands are cast before the multiplication.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..8cb698096ef5b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1908,25 +1908,25 @@ structured_op: !LinalgStructuredOpConfig
           scalar_arg: C
         - !ScalarExpression
           scalar_fn:
-            kind: type
-            fn_name: cast_signed
-            type_var: U
+            kind: binary
+            fn_name: mul
             operands:
             - !ScalarExpression
               scalar_fn:
-                kind: binary
-                fn_name: mul
+                kind: type
+                fn_name: cast_signed
+                type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
                 - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: B
+                  scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..e4a6ec7487bb2f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -592,8 +592,8 @@ def batch_reduce_matmul(
     """
     domain(D.b, D.m, D.n, D.k)
     implements(ContractionOpInterface)
-    C[D.m, D.n] += TypeFn.cast_signed(
-        U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
     )
 
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..1e8f1435ca0fa5 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -329,6 +329,33 @@ func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %
 // CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
 // CHECK:         linalg.yield %[[ADD]] : f32
 
+// -----
+
+func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: memref<7x9x8xbf16>, %out: memref<8x8xf32>) {
+  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+                             outs(%out: memref<8x8xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @generalize_batch_reduce_gemm_bf16
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xbf16>, memref<7x9x8xbf16>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: bf16, %[[BBARG1:.+]]: bf16, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[EXTBF16_0:.+]] = arith.extf %[[BBARG0]] : bf16 to f32
+// CHECK:         %[[EXTBF16_1:.+]] = arith.extf %[[BBARG1]] : bf16 to f32
+// CHECK:         %[[MUL:.+]] = arith.mulf %[[EXTBF16_0]], %[[EXTBF16_1]] : f32
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
+
 // -----
 
 // CHECK-LABEL: generalize_linalg_map


        


More information about the Mlir-commits mailing list