From 60c707e02af7bd950401faeadcc6c0e1b39b2978 Mon Sep 17 00:00:00 2001 From: crlotwhite <76243689+crlotwhite@users.noreply.github.com> Date: Fri, 15 Jul 2022 01:42:47 +0800 Subject: [PATCH] fix(tf2): predict_classes deprecated (#271) --- tf2/tf2-06-1-softmax_classifier.py | 4 ++-- tf2/tf2-06-2-softmax_zoo_classifier.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tf2/tf2-06-1-softmax_classifier.py b/tf2/tf2-06-1-softmax_classifier.py index 67ebf896..b150690e 100644 --- a/tf2/tf2-06-1-softmax_classifier.py +++ b/tf2/tf2-06-1-softmax_classifier.py @@ -48,10 +48,10 @@ print('--------------') # or use argmax embedded method, predict_classes c = tf.model.predict(np.array([[1, 1, 0, 1]])) -c_onehot = tf.model.predict_classes(np.array([[1, 1, 0, 1]])) +c_onehot = np.argmax(c, axis=-1) print(c, c_onehot) print('--------------') all = tf.model.predict(np.array([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]])) -all_onehot = tf.model.predict_classes(np.array([[1, 11, 7, 9], [1, 3, 4, 3], [1, 1, 0, 1]])) +all_onehot = np.argmax(all, axis=-1) print(all, all_onehot) diff --git a/tf2/tf2-06-2-softmax_zoo_classifier.py b/tf2/tf2-06-2-softmax_zoo_classifier.py index 01fcb290..d5fa27a4 100644 --- a/tf2/tf2-06-2-softmax_zoo_classifier.py +++ b/tf2/tf2-06-2-softmax_zoo_classifier.py @@ -28,9 +28,9 @@ # Single data test test_data = np.array([[0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0]]) # expected prediction == 3 (feathers) -print(tf.model.predict(test_data), tf.model.predict_classes(test_data)) +print(tf.model.predict(test_data), np.argmax(tf.model.predict(test_data), axis=-1)) # Full x_data test -pred = tf.model.predict_classes(x_data) +pred = np.argmax(tf.model.predict(x_data), axis=-1) for p, y in zip(pred, y_data.flatten()): print("[{}] Prediction: {} True Y: {}".format(p == int(y), p, int(y)))