Skip to content

Commit 8b5ebcc

Browse files
authored
[python-package] enforce keyword-only args in more internal functions (#7111)
* [python-package] enforce keyword-only args in more internal functions
* keyword-only args for more early-stopping arguments
1 parent 648bf89 commit 8b5ebcc

6 files changed

Lines changed: 77 additions & 47 deletions

File tree

‎python-package/lightgbm/basic.py‎

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def _is_1d_collection(data: Any) -> bool:
367367

368368

369369
def _list_to_1d_numpy(
370+
*,
370371
data: Any,
371372
dtype: "np.typing.DTypeLike",
372373
name: str,
@@ -1840,7 +1841,7 @@ def __del__(self) -> None:
18401841
except AttributeError:
18411842
pass
18421843

1843-
def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
1844+
def _create_sample_indices(self, *, total_nrow: int) -> np.ndarray:
18441845
"""Get an array of randomly chosen indices from this ``Dataset``.
18451846
18461847
Indices are sampled without replacement.
@@ -2167,26 +2168,26 @@ def _lazy_init(
21672168
)
21682169
)
21692170
elif isinstance(data, scipy.sparse.csr_matrix):
2170-
self.__init_from_csr(data, params_str, ref_dataset)
2171+
self.__init_from_csr(csr=data, params_str=params_str, ref_dataset=ref_dataset)
21712172
elif isinstance(data, scipy.sparse.csc_matrix):
2172-
self.__init_from_csc(data, params_str, ref_dataset)
2173+
self.__init_from_csc(csc=data, params_str=params_str, ref_dataset=ref_dataset)
21732174
elif isinstance(data, np.ndarray):
2174-
self.__init_from_np2d(data, params_str, ref_dataset)
2175+
self.__init_from_np2d(mat=data, params_str=params_str, ref_dataset=ref_dataset)
21752176
elif _is_pyarrow_table(data):
2176-
self.__init_from_pyarrow_table(data, params_str, ref_dataset)
2177+
self.__init_from_pyarrow_table(table=data, params_str=params_str, ref_dataset=ref_dataset)
21772178
elif isinstance(data, list) and len(data) > 0:
21782179
if _is_list_of_numpy_arrays(data):
2179-
self.__init_from_list_np2d(data, params_str, ref_dataset)
2180+
self.__init_from_list_np2d(mats=data, params_str=params_str, ref_dataset=ref_dataset)
21802181
elif _is_list_of_sequences(data):
2181-
self.__init_from_seqs(data, ref_dataset)
2182+
self.__init_from_seqs(seqs=data, ref_dataset=ref_dataset)
21822183
else:
21832184
raise TypeError("Data list can only be of ndarray or Sequence")
21842185
elif isinstance(data, Sequence):
2185-
self.__init_from_seqs([data], ref_dataset)
2186+
self.__init_from_seqs(seqs=[data], ref_dataset=ref_dataset)
21862187
else:
21872188
try:
21882189
csr = scipy.sparse.csr_matrix(data)
2189-
self.__init_from_csr(csr, params_str, ref_dataset)
2190+
self.__init_from_csr(csr=csr, params_str=params_str, ref_dataset=ref_dataset)
21902191
except BaseException as err:
21912192
raise TypeError(f"Cannot initialize Dataset from {type(data).__name__}") from err
21922193
if label is not None:
@@ -2225,7 +2226,7 @@ def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Ite
22252226
row = seq[id_in_seq]
22262227
yield row if row.flags["OWNDATA"] else row.copy()
22272228

2228-
def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
2229+
def __sample(self, *, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
22292230
"""Sample data from seqs.
22302231
22312232
Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats()
@@ -2234,7 +2235,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr
22342235
-------
22352236
sampled_rows, sampled_row_indices
22362237
"""
2237-
indices = self._create_sample_indices(total_nrow)
2238+
indices = self._create_sample_indices(total_nrow=total_nrow)
22382239

22392240
# Select sampled rows, transpose to column order.
22402241
sampled = np.array(list(self._yield_row_from_seqlist(seqs, indices)))
@@ -2255,6 +2256,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr
22552256

22562257
def __init_from_seqs(
22572258
self,
2259+
*,
22582260
seqs: List[Sequence],
22592261
ref_dataset: Optional[_DatasetHandle],
22602262
) -> "Dataset":
@@ -2275,7 +2277,7 @@ def __init_from_seqs(
22752277
param_str = _param_dict_to_str(self.get_params())
22762278
sample_cnt = _get_sample_count(total_nrow, param_str)
22772279

2278-
sample_data, col_indices = self.__sample(seqs, total_nrow)
2280+
sample_data, col_indices = self.__sample(seqs=seqs, total_nrow=total_nrow)
22792281
self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)
22802282

22812283
for seq in seqs:
@@ -2288,6 +2290,7 @@ def __init_from_seqs(
22882290

22892291
def __init_from_np2d(
22902292
self,
2293+
*,
22912294
mat: np.ndarray,
22922295
params_str: str,
22932296
ref_dataset: Optional[_DatasetHandle],
@@ -2315,6 +2318,7 @@ def __init_from_np2d(
23152318

23162319
def __init_from_list_np2d(
23172320
self,
2321+
*,
23182322
mats: List[np.ndarray],
23192323
params_str: str,
23202324
ref_dataset: Optional[_DatasetHandle],
@@ -2369,6 +2373,7 @@ def __init_from_list_np2d(
23692373

23702374
def __init_from_csr(
23712375
self,
2376+
*,
23722377
csr: scipy.sparse.csr_matrix,
23732378
params_str: str,
23742379
ref_dataset: Optional[_DatasetHandle],
@@ -2403,6 +2408,7 @@ def __init_from_csr(
24032408

24042409
def __init_from_csc(
24052410
self,
2411+
*,
24062412
csc: scipy.sparse.csc_matrix,
24072413
params_str: str,
24082414
ref_dataset: Optional[_DatasetHandle],
@@ -2437,6 +2443,7 @@ def __init_from_csc(
24372443

24382444
def __init_from_pyarrow_table(
24392445
self,
2446+
*,
24402447
table: pa_Table,
24412448
params_str: str,
24422449
ref_dataset: Optional[_DatasetHandle],
@@ -2466,6 +2473,7 @@ def __init_from_pyarrow_table(
24662473

24672474
@staticmethod
24682475
def _compare_params_for_warning(
2476+
*,
24692477
params: Dict[str, Any],
24702478
other_params: Dict[str, Any],
24712479
ignore_keys: Set[str],
@@ -2535,7 +2543,11 @@ def construct(self) -> "Dataset":
25352543
)
25362544
else:
25372545
# construct subset
2538-
used_indices = _list_to_1d_numpy(self.used_indices, dtype=np.int32, name="used_indices")
2546+
used_indices = _list_to_1d_numpy(
2547+
data=self.used_indices,
2548+
dtype=np.int32,
2549+
name="used_indices",
2550+
)
25392551
assert used_indices.flags.c_contiguous
25402552
if self.reference.group is not None:
25412553
group_info = np.array(self.reference.group).astype(np.int32, copy=False)
@@ -2803,9 +2815,9 @@ def set_field(
28032815
if field_name == "init_score":
28042816
dtype = np.float64
28052817
if _is_1d_collection(data):
2806-
data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)
2818+
data = _list_to_1d_numpy(data=data, dtype=dtype, name=field_name)
28072819
elif _is_2d_collection(data):
2808-
data = _data_to_2d_numpy(data, dtype=dtype, name=field_name)
2820+
data = _data_to_2d_numpy(data=data, dtype=dtype, name=field_name)
28092821
data = data.ravel(order="F")
28102822
else:
28112823
raise TypeError(
@@ -2817,7 +2829,7 @@ def set_field(
28172829
dtype = np.int32
28182830
else:
28192831
dtype = np.float32
2820-
data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)
2832+
data = _list_to_1d_numpy(data=data, dtype=dtype, name=field_name)
28212833

28222834
ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr]
28232835
if data.dtype == np.float32 or data.dtype == np.float64:
@@ -3058,7 +3070,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
30583070
elif _is_pyarrow_array(label):
30593071
label_array = label
30603072
else:
3061-
label_array = _list_to_1d_numpy(label, dtype=np.float32, name="label")
3073+
label_array = _list_to_1d_numpy(data=label, dtype=np.float32, name="label")
30623074
self.set_field("label", label_array)
30633075
self.label = self.get_field("label") # original values can be modified at cpp side
30643076
return self
@@ -3091,7 +3103,7 @@ def set_weight(
30913103
# Set field
30923104
if self._handle is not None and weight is not None:
30933105
if not _is_pyarrow_array(weight):
3094-
weight = _list_to_1d_numpy(weight, dtype=np.float32, name="weight")
3106+
weight = _list_to_1d_numpy(data=weight, dtype=np.float32, name="weight")
30953107
self.set_field("weight", weight)
30963108
self.weight = self.get_field("weight") # original values can be modified at cpp side
30973109
return self
@@ -3141,7 +3153,7 @@ def set_group(
31413153
self.group = group
31423154
if self._handle is not None and group is not None:
31433155
if not _is_pyarrow_array(group):
3144-
group = _list_to_1d_numpy(group, dtype=np.int32, name="group")
3156+
group = _list_to_1d_numpy(data=group, dtype=np.int32, name="group")
31453157
self.set_field("group", group)
31463158
# original values can be modified at cpp side
31473159
constructed_group = self.get_field("group")
@@ -3167,7 +3179,7 @@ def set_position(
31673179
"""
31683180
self.position = position
31693181
if self._handle is not None and position is not None:
3170-
position = _list_to_1d_numpy(position, dtype=np.int32, name="position")
3182+
position = _list_to_1d_numpy(data=position, dtype=np.int32, name="position")
31713183
self.set_field("position", position)
31723184
return self
31733185

@@ -3884,6 +3896,7 @@ def _get_node_index(
38843896
return f"{tree_num}{node_type}{node_num}"
38853897

38863898
def _get_split_feature(
3899+
*,
38873900
tree: Dict[str, Any],
38883901
feature_names: Optional[List[str]],
38893902
) -> Optional[str]:
@@ -3907,7 +3920,7 @@ def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
39073920
node["left_child"] = None
39083921
node["right_child"] = None
39093922
node["parent_index"] = parent_node
3910-
node["split_feature"] = _get_split_feature(tree, feature_names)
3923+
node["split_feature"] = _get_split_feature(tree=tree, feature_names=feature_names)
39113924
node["split_gain"] = None
39123925
node["threshold"] = None
39133926
node["decision_type"] = None
@@ -4132,11 +4145,12 @@ def update(
41324145
else:
41334146
if not self.__set_objective_to_none:
41344147
self.reset_parameter({"objective": "none"}).__set_objective_to_none = True
4135-
grad, hess = fobj(self.__inner_predict(0), self.train_set)
4136-
return self.__boost(grad, hess)
4148+
grad, hess = fobj(self.__inner_predict(data_idx=0), self.train_set)
4149+
return self.__boost(grad=grad, hess=hess)
41374150

41384151
def __boost(
41394152
self,
4153+
*,
41404154
grad: np.ndarray,
41414155
hess: np.ndarray,
41424156
) -> bool:
@@ -4171,8 +4185,8 @@ def __boost(
41714185
if self.__num_class > 1:
41724186
grad = grad.ravel(order="F")
41734187
hess = hess.ravel(order="F")
4174-
grad = _list_to_1d_numpy(grad, dtype=np.float32, name="gradient")
4175-
hess = _list_to_1d_numpy(hess, dtype=np.float32, name="hessian")
4188+
grad = _list_to_1d_numpy(data=grad, dtype=np.float32, name="gradient")
4189+
hess = _list_to_1d_numpy(data=hess, dtype=np.float32, name="hessian")
41764190
assert grad.flags.c_contiguous
41774191
assert hess.flags.c_contiguous
41784192
if len(grad) != len(hess):
@@ -5178,7 +5192,7 @@ def __inner_eval(
51785192
for eval_function in feval:
51795193
if eval_function is None:
51805194
continue
5181-
feval_ret = eval_function(self.__inner_predict(data_idx), cur_data)
5195+
feval_ret = eval_function(self.__inner_predict(data_idx=data_idx), cur_data)
51825196
if isinstance(feval_ret, list):
51835197
for eval_name, val, is_higher_better in feval_ret:
51845198
ret.append((data_name, eval_name, val, is_higher_better))
@@ -5187,7 +5201,7 @@ def __inner_eval(
51875201
ret.append((data_name, eval_name, val, is_higher_better))
51885202
return ret
51895203

5190-
def __inner_predict(self, data_idx: int) -> np.ndarray:
5204+
def __inner_predict(self, *, data_idx: int) -> np.ndarray:
51915205
"""Predict for training and validation dataset."""
51925206
if data_idx >= self.__num_dataset:
51935207
raise ValueError("Data_idx should be smaller than number of dataset")

‎python-package/lightgbm/callback.py‎

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,16 +301,16 @@ def _reset_storages(self) -> None:
301301
self.best_score: List[float] = []
302302
self.best_iter: List[int] = []
303303
self.best_score_list: List[_ListOfEvalResultTuples] = []
304-
self.cmp_op: List[Callable[[float, float], bool]] = []
304+
self.cmp_op: List[Callable[[float, float, float], bool]] = []
305305
self.first_metric = ""
306306

307-
def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
307+
def _gt_delta(self, *, curr_score: float, best_score: float, delta: float) -> bool:
308308
return curr_score > best_score + delta
309309

310-
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
310+
def _lt_delta(self, *, curr_score: float, best_score: float, delta: float) -> bool:
311311
return curr_score < best_score - delta
312312

313-
def _is_train_set(self, dataset_name: str, env: CallbackEnv) -> bool:
313+
def _is_train_set(self, *, dataset_name: str, env: CallbackEnv) -> bool:
314314
"""Check, by name, if a given Dataset is the training data."""
315315
# for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
316316
# and those metrics are considered for early stopping
@@ -413,7 +413,9 @@ def __call__(self, env: CallbackEnv) -> None:
413413
first_time_updating_best_score_list = self.best_score_list == []
414414
for i in range(len(env.evaluation_result_list)):
415415
dataset_name, metric_name, metric_value, *_ = env.evaluation_result_list[i]
416-
if first_time_updating_best_score_list or self.cmp_op[i](metric_value, self.best_score[i]):
416+
if first_time_updating_best_score_list or self.cmp_op[i]( # type: ignore[call-arg]
417+
curr_score=metric_value, best_score=self.best_score[i]
418+
):
417419
self.best_score[i] = metric_value
418420
self.best_iter[i] = env.iteration
419421
if first_time_updating_best_score_list:

‎python-package/lightgbm/dask.py‎

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def _get_dask_client(client: Optional[Client]) -> Client:
113113

114114

115115
def _assign_open_ports_to_workers(
116+
*,
116117
client: Client,
117118
workers: List[str],
118119
) -> Tuple[Dict[str, Future], Dict[str, int]]:
@@ -165,7 +166,11 @@ def _remove_list_padding(*args: Any) -> List[List[Any]]:
165166
return [[z for z in arg if z is not None] for arg in args]
166167

167168

168-
def _pad_eval_names(lgbm_model: LGBMModel, required_names: List[str]) -> LGBMModel:
169+
def _pad_eval_names(
170+
*,
171+
lgbm_model: LGBMModel,
172+
required_names: List[str],
173+
) -> LGBMModel:
169174
"""Append missing (key, value) pairs to a LightGBM model's evals_result_ and best_score_ OrderedDict attrs based on a set of required eval_set names.
170175
171176
Allows users to rely on expected eval_set names being present when fitting DaskLGBM estimators with ``eval_set``.
@@ -356,12 +361,12 @@ def _train_part(
356361

357362
if n_evals:
358363
# ensure that expected keys for evals_result_ and best_score_ exist regardless of padding.
359-
model = _pad_eval_names(model, required_names=evals_result_names)
364+
model = _pad_eval_names(lgbm_model=model, required_names=evals_result_names)
360365

361366
return model if return_model else None
362367

363368

364-
def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
369+
def _split_to_parts(*, data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
365370
parts = data.to_delayed()
366371
if isinstance(parts, np.ndarray):
367372
if is_matrix:
@@ -372,7 +377,11 @@ def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
372377
return parts
373378

374379

375-
def _machines_to_worker_map(machines: str, worker_addresses: Iterable[str]) -> Dict[str, int]:
380+
def _machines_to_worker_map(
381+
*,
382+
machines: str,
383+
worker_addresses: Iterable[str],
384+
) -> Dict[str, int]:
376385
"""Create a worker_map from machines list.
377386
378387
Given ``machines`` and a list of Dask worker addresses, return a mapping where the keys are
@@ -773,7 +782,8 @@ def _train(
773782
else:
774783
_log_info("Finding random open ports for workers")
775784
worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(
776-
client, list(worker_map.keys())
785+
client=client,
786+
workers=list(worker_map.keys()),
777787
)
778788

779789
machines = ",".join(
@@ -1091,20 +1101,21 @@ def _lgb_dask_fit(
10911101
)
10921102

10931103
self.set_params(**model.get_params()) # type: ignore[attr-defined]
1094-
self._lgb_dask_copy_extra_params(model, self) # type: ignore[attr-defined]
1104+
self._lgb_dask_copy_extra_params(source=model, dest=self) # type: ignore[attr-defined]
10951105

10961106
return self
10971107

10981108
def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
10991109
params = self.get_params() # type: ignore[attr-defined]
11001110
params.pop("client", None)
11011111
model = model_factory(**params)
1102-
self._lgb_dask_copy_extra_params(self, model)
1112+
self._lgb_dask_copy_extra_params(source=self, dest=model)
11031113
model._other_params.pop("client", None)
11041114
return model
11051115

11061116
@staticmethod
11071117
def _lgb_dask_copy_extra_params(
1118+
*,
11081119
source: Union["_DaskLGBMModel", LGBMModel],
11091120
dest: Union["_DaskLGBMModel", LGBMModel],
11101121
) -> None:

0 commit comments

Comments (0)