[llvm] 3edd897 - fix mlgo regalloc test model generation for tflite
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 9 12:37:00 PDT 2022
Author: yundiqian
Date: 2022-08-09T12:36:28-07:00
New Revision: 3edd8978c3129d15e364abb3632a0db478891415
URL: https://github.com/llvm/llvm-project/commit/3edd8978c3129d15e364abb3632a0db478891415
DIFF: https://github.com/llvm/llvm-project/commit/3edd8978c3129d15e364abb3632a0db478891415.diff
LOG: fix mlgo regalloc test model generation for tflite
While moving from the TF C API to TFLite, we found that the TFLite argmax op does not work for int64 inputs, so cast the int64 inputs to int32 to make the TFLite argmax op work.
Differential Revision: https://reviews.llvm.org/D131462
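The reasoning behind the fix can be illustrated without TensorFlow. Argmax only compares values, so a cast that keeps every value intact leaves the selected index unchanged. Assuming the `mask` entries are 0/1 flags (inferred from the name, not stated in the commit), narrowing them from int64 to int32 is lossless. The sketch below is a hypothetical pure-Python stand-in: `cast_to_int32`, `argmax_last_axis`, and `mock_action` are illustrative names, not functions in the LLVM script or in the TensorFlow API.

```python
# Illustrative sketch (pure Python, no TensorFlow): why casting the mask from
# int64 to int32 before argmax does not change the mock model's decision.
# All names here are hypothetical stand-ins, not part of the LLVM script.

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def cast_to_int32(row):
    """Mimic a narrowing int64 -> int32 cast, rejecting out-of-range values."""
    for value in row:
        if not INT32_MIN <= value <= INT32_MAX:
            raise OverflowError(f"{value} does not fit in int32")
    return list(row)


def argmax_last_axis(rows):
    """Per-row index of a maximal element (this sketch picks the first on ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


def mock_action(mask_rows, var=0):
    """Analogue of the test model's action: argmax(cast(mask)) + module.var."""
    narrowed = [cast_to_int32(row) for row in mask_rows]
    return [index + var for index in argmax_last_axis(narrowed)]


if __name__ == "__main__":
    mask = [[0, 0, 1, 0], [1, 0, 0, 0]]  # assumed 0/1 eviction-candidate mask
    # Every 0/1 entry survives the cast unchanged, so argmax agrees.
    assert argmax_last_axis(mask) == argmax_last_axis(
        [cast_to_int32(row) for row in mask]
    )
    print(mock_action(mask))  # [2, 0]
```

Only the argmax input is narrowed. `tf.math.argmax` returns `tf.int64` indices by default (its `output_type` parameter), so in the patched model the result can still be added to the int64 `module.var`. TensorFlow does not guarantee which index wins on ties; the first-maximum rule above is just this sketch's choice.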
Added:
Modified:
llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
index 476163d6b5b3b..11bc3f259ddee 100644
--- a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
+++ b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
@@ -46,7 +46,7 @@ def build_mock_model(path):
   module.var = tf.Variable(0, dtype=tf.int64)
   def action(*inputs):
-    result = tf.math.argmax(inputs[0]['mask'], axis=-1) + module.var
+    result = tf.math.argmax(tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
     return {POLICY_DECISION_LABEL: result}
   module.action = tf.function()(action)
   action = {