@@ -132,7 +132,7 @@ def create_prediction_file(output_file, data_id, true_labels, predict_labels, pr
132132
133133 Args:
134134 output_file: The all classes predicted results provided by network.
135- data_id: The data record id info provided by class Data.
135+ data_id: The data record id info provided by dict < Data> .
136136 true_labels: The all true labels.
137137 predict_labels: The all predict labels by threshold.
138138 predict_scores: The all predict scores by threshold.
@@ -156,14 +156,14 @@ def create_prediction_file(output_file, data_id, true_labels, predict_labels, pr
156156
157157def get_onehot_label_threshold (scores , threshold = 0.5 ):
158158 """
159- Get the predicted onehot labels based on the threshold.
159+ Get the predicted one-hot labels based on the threshold.
160160 If there is no predict score greater than threshold, then choose the label which has the max predict score.
161161
162162 Args:
163163 scores: The all classes predicted scores provided by network.
164164 threshold: The threshold (default: 0.5).
165165 Returns:
166- predicted_onehot_labels: The predicted labels (onehot ).
166+ predicted_onehot_labels: The predicted labels (one-hot ).
167167 """
168168 predicted_onehot_labels = []
169169 scores = np .ndarray .tolist (scores )
@@ -183,13 +183,13 @@ def get_onehot_label_threshold(scores, threshold=0.5):
183183
184184def get_onehot_label_topk (scores , top_num = 1 ):
185185 """
186- Get the predicted onehot labels based on the topK number .
186+ Get the predicted one-hot labels based on the topK.
187187
188188 Args:
189189 scores: The all classes predicted scores provided by network.
190190 top_num: The max topK number (default: 5).
191191 Returns:
192- predicted_onehot_labels: The predicted labels (onehot ).
192+ predicted_onehot_labels: The predicted labels (one-hot ).
193193 """
194194 predicted_onehot_labels = []
195195 scores = np .ndarray .tolist (scores )
@@ -237,7 +237,7 @@ def get_label_threshold(scores, threshold=0.5):
237237
238238def get_label_topk (scores , top_num = 1 ):
239239 """
240- Get the predicted labels based on the topK number .
240+ Get the predicted labels based on the topK.
241241 Note: Only Used in `test_model.py`
242242
243243 Args:
0 commit comments