Skip to content

Commit 28c4c92

Browse files
committed
Update data_helpers.py
1 parent 1e2c832 commit 28c4c92

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

‎utils/data_helpers.py‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def create_prediction_file(output_file, data_id, true_labels, predict_labels, pr
132132
133133
Args:
134134
output_file: The all classes predicted results provided by network.
135-
data_id: The data record id info provided by class Data.
135+
data_id: The data record id info provided by dict <Data>.
136136
true_labels: The all true labels.
137137
predict_labels: The all predict labels by threshold.
138138
predict_scores: The all predict scores by threshold.
@@ -156,14 +156,14 @@ def create_prediction_file(output_file, data_id, true_labels, predict_labels, pr
156156

157157
def get_onehot_label_threshold(scores, threshold=0.5):
158158
"""
159-
Get the predicted onehot labels based on the threshold.
159+
Get the predicted one-hot labels based on the threshold.
160160
If there is no predict score greater than threshold, then choose the label which has the max predict score.
161161
162162
Args:
163163
scores: The all classes predicted scores provided by network.
164164
threshold: The threshold (default: 0.5).
165165
Returns:
166-
predicted_onehot_labels: The predicted labels (onehot).
166+
predicted_onehot_labels: The predicted labels (one-hot).
167167
"""
168168
predicted_onehot_labels = []
169169
scores = np.ndarray.tolist(scores)
@@ -183,13 +183,13 @@ def get_onehot_label_threshold(scores, threshold=0.5):
183183

184184
def get_onehot_label_topk(scores, top_num=1):
185185
"""
186-
Get the predicted onehot labels based on the topK number.
186+
Get the predicted one-hot labels based on the topK.
187187
188188
Args:
189189
scores: The all classes predicted scores provided by network.
190190
top_num: The max topK number (default: 5).
191191
Returns:
192-
predicted_onehot_labels: The predicted labels (onehot).
192+
predicted_onehot_labels: The predicted labels (one-hot).
193193
"""
194194
predicted_onehot_labels = []
195195
scores = np.ndarray.tolist(scores)
@@ -237,7 +237,7 @@ def get_label_threshold(scores, threshold=0.5):
237237

238238
def get_label_topk(scores, top_num=1):
239239
"""
240-
Get the predicted labels based on the topK number.
240+
Get the predicted labels based on the topK.
241241
Note: Only Used in `test_model.py`
242242
243243
Args:

0 commit comments

Comments
 (0)