[Mlir-commits] [mlir] cae13ff - [mlir][test] Fix how the number of flops is calculated

Uday Bondhugula llvmlistbot at llvm.org
Mon Nov 21 00:11:45 PST 2022


Author: Andrzej Warzynski
Date: 2022-11-21T13:40:24+05:30
New Revision: cae13ff416ec02f58164e76230371edc5e41be7f

URL: https://github.com/llvm/llvm-project/commit/cae13ff416ec02f58164e76230371edc5e41be7f
DIFF: https://github.com/llvm/llvm-project/commit/cae13ff416ec02f58164e76230371edc5e41be7f.diff

LOG: [mlir][test] Fix how the number of flops is calculated

Make sure that the number of repetitions is correctly incorporated when
calculating the number of floating-point operations.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D138382

Added: 
    

Modified: 
    mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir b/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
index 5ebafbe668020..8a427ddf4b398 100644
--- a/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
+++ b/mlir/test/mlir-cpu-runner/sgemm-naive-codegen.mlir
@@ -10,10 +10,10 @@ func.func @main() {
   linalg.fill ins(%cf1 : f32) outs(%A : memref<16x16xf32>)
   linalg.fill ins(%cf1 : f32) outs(%B : memref<16x16xf32>)
 
-  %reps = arith.constant 1 : index
+  %num_reps = arith.constant 5 : index
 
   %t_start = call @rtclock() : () -> f64
-  affine.for %arg0 = 0 to 5 {
+  affine.for %arg0 = 0 to %num_reps {
     linalg.fill ins(%cf1 : f32) outs(%C : memref<16x16xf32>)
     func.call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
   }
@@ -31,16 +31,19 @@ func.func @main() {
   %N = memref.dim %C, %c1 : memref<16x16xf32>
   %K = memref.dim %A, %c1 : memref<16x16xf32>
 
+  // num_flops_per_iter = 2*M*N*K
   %f1 = arith.muli %M, %N : index
   %f2 = arith.muli %f1, %K : index
+  %num_flops_per_iter = arith.muli %c2, %f2 : index
 
-  // 2*M*N*K.
-  %f3 = arith.muli %c2, %f2 : index
-  %num_flops = arith.muli %reps, %f3 : index
-  %num_flops_i = arith.index_cast %num_flops : index to i16
-  %num_flops_f = arith.sitofp %num_flops_i : i16 to f64
-  %flops = arith.divf %num_flops_f, %t : f64
-  call @printFlops(%flops) : (f64) -> ()
+  // num_flops_total = num_flops_per_iter * num_reps
+  %num_flops_total = arith.muli %num_flops_per_iter, %num_reps: index
+
+  // Print the number of flops per second
+  %num_flops_total_i = arith.index_cast %num_flops_total : index to i16
+  %num_flops_total_f = arith.uitofp %num_flops_total_i : i16 to f64
+  %flops_per_s = arith.divf %num_flops_total_f, %t : f64
+  call @printFlops(%flops_per_s) : (f64) -> ()
 
   memref.dealloc %A : memref<16x16xf32>
   memref.dealloc %B : memref<16x16xf32>
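
The arithmetic behind the change can be checked with a short sketch. It is not part of the commit: the helper names are hypothetical, and the 16-bit wraparound is modelled by hand rather than by running MLIR. Before the change, `%reps` was 1 while the timed loop ran 5 times, so the flop count was 5x too small. After the change, the total is 2*16*16*16*5 = 40960, which no longer fits in a signed 16-bit integer. That is a plausible reason for the switch from `arith.sitofp` to `arith.uitofp`, although the commit message does not say so.

```python
# Hypothetical helpers, for illustration only; not taken from the commit.

def total_flops(m, n, k, num_reps):
    """Flops for num_reps runs of an MxNxK matmul (2*M*N*K per run)."""
    return 2 * m * n * k * num_reps


def as_signed_i16(value):
    """Reinterpret the low 16 bits of value as a signed integer."""
    low = value & 0xFFFF
    return low - 0x10000 if low >= 0x8000 else low


def as_unsigned_i16(value):
    """Reinterpret the low 16 bits of value as an unsigned integer."""
    return value & 0xFFFF


M = N = K = 16

# Old test: %reps = 1, but the affine.for loop ran 5 times.
flops_old = total_flops(M, N, K, 1)

# New test: %num_reps = 5 drives both the loop bound and the flop count.
flops_new = total_flops(M, N, K, 5)

print("old count:", flops_old)
print("new count:", flops_new)
print("new count read as signed i16:", as_signed_i16(flops_new))
print("new count read as unsigned i16:", as_unsigned_i16(flops_new))
```

Under these assumptions the new total survives the `index` to `i16` cast only when the 16 bits are read as unsigned. Larger matrices or more repetitions would overflow the 16-bit type either way.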


        

