[llvm] 3edd897 - fix mlgo regalloc test model generation for tflite

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 9 12:37:00 PDT 2022


Author: yundiqian
Date: 2022-08-09T12:36:28-07:00
New Revision: 3edd8978c3129d15e364abb3632a0db478891415

URL: https://github.com/llvm/llvm-project/commit/3edd8978c3129d15e364abb3632a0db478891415
DIFF: https://github.com/llvm/llvm-project/commit/3edd8978c3129d15e364abb3632a0db478891415.diff

LOG: fix mlgo regalloc test model generation for tflite

While moving from the TF C API to TFLite, we found that the TFLite argmax op does not work with int64 inputs, so cast the int64 inputs to int32 to make the TFLite argmax op work.

Differential Revision: https://reviews.llvm.org/D131462

Added: 
    

Modified: 
    llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
index 476163d6b5b3b..11bc3f259ddee 100644
--- a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
+++ b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
@@ -46,7 +46,7 @@ def build_mock_model(path):
   module.var = tf.Variable(0, dtype=tf.int64)
 
   def action(*inputs):
-    result = tf.math.argmax(inputs[0]['mask'], axis=-1) + module.var
+    result = tf.math.argmax(tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
     return {POLICY_DECISION_LABEL: result}
   module.action = tf.function()(action)
   action = {

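The change relies on the cast preserving which index holds the maximum. Below is a minimal pure-Python sketch of that reasoning, with no TensorFlow involved. It assumes, as the name suggests, that `mask` holds 0/1 values. The helper names and the sample mask are hypothetical. Every value of such a mask fits in int32, so narrowing it before `argmax` cannot change the result. `tf.math.argmax` returns `tf.int64` by default (`output_type=tf.int64`), which is why adding the int64 `module.var` still works after the change.

```python
# Pure-Python illustration (no TensorFlow) of the pattern in the diff:
#   tf.math.argmax(tf.cast(mask, tf.int32), axis=-1) + module.var
# Helper names are hypothetical; they only show why the cast is lossless here.

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def checked_int32(row):
    """Return the row unchanged if every value fits in int32, else raise.

    This range check is an addition of the sketch, not something the diff does;
    it states the precondition under which an int64 -> int32 cast loses nothing.
    """
    for v in row:
        if not INT32_MIN <= v <= INT32_MAX:
            raise OverflowError(f"{v} does not fit in int32")
    return list(row)


def argmax_last_axis(matrix):
    """Index of the largest element in each row (first index wins on ties in this sketch)."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


mask = [[0, 0, 1, 0], [1, 1, 0, 0]]  # hypothetical 0/1 mask, one row per batch entry
var = 0  # stands in for module.var, an int64 variable initialised to 0

before = argmax_last_axis(mask)                                  # argmax on the original values
after = argmax_last_axis([checked_int32(row) for row in mask])   # argmax after the "cast"
result = [i + var for i in after]                                # mirrors "+ module.var"
print(before, after, result)  # [2, 0] [2, 0] [2, 0]
```

Values outside the int32 range would be the case where such a cast is no longer safe. A 0/1 mask never reaches that range.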