Skip to content

Commit 038c7f8

Browse files
Update Python (3.13) and PyTorch (2.9) versions (#1121)
Also: - drop Python 3.9 (EOL) - drop PyTorch 2.5 - update some tests - for torch < 2.9, downgrade to triton 3.4 - fix indentation error in GP classes docstrings (similar to #1080)
1 parent 90add47 commit 038c7f8

7 files changed

Lines changed: 53 additions & 100 deletions

File tree

‎.github/workflows/testing.yml‎

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ jobs:
2020
strategy:
2121
fail-fast: false # don't cancel all jobs when one fails
2222
matrix:
23-
python_version: ['3.9', '3.10', '3.11', '3.12']
24-
torch_version: ['2.5.1+cpu', '2.6.0+cpu', '2.7.1+cpu', '2.8.0+cpu']
23+
python_version: ['3.10', '3.11', '3.12', '3.13']
24+
torch_version: ['2.6.0+cpu', '2.7.1+cpu', '2.8.0+cpu', '2.9.0+cpu']
2525
os: [ubuntu-latest]
2626

2727
steps:
@@ -31,18 +31,17 @@ jobs:
3131
with:
3232
python-version: ${{ matrix.python_version }}
3333
- name: Install dependencies
34-
# TODO remove triton 3.1 install if torch < 2.6 no longer supported, see #1093
3534
# TODO remove triton 3.2 install if torch < 2.7 no longer supported
3635
run: |
3736
python -m pip install --upgrade pip
3837
python -m pip install '.[test,docs,dev,extended]'
3938
python -m pip install pytest-pretty
4039
python -m pip install torch==${{ matrix.torch_version }} -f https://download.pytorch.org/whl/torch
4140
TORCH_VERSION_MAJOR_MINOR=$(python -c "import torch; v=torch.__version__.split('+')[0]; print('.'.join(v.split('.')[:2]))")
42-
if [[ $(echo "$TORCH_VERSION_MAJOR_MINOR < 2.6" | bc -l) -eq 1 ]]; then
43-
python -m pip install "triton==3.1"
44-
elif [[ $(echo "$TORCH_VERSION_MAJOR_MINOR < 2.7" | bc -l) -eq 1 ]]; then
41+
if [[ $(echo "$TORCH_VERSION_MAJOR_MINOR < 2.7" | bc -l) -eq 1 ]]; then
4542
python -m pip install "triton==3.2"
43+
elif [[ $(echo "$TORCH_VERSION_MAJOR_MINOR < 2.9" | bc -l) -eq 1 ]]; then
44+
python -m pip install "triton==3.4"
4645
fi
4746
python -m pip list
4847
- name: Install skorch

‎README.rst‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,10 @@ instructions for PyTorch, visit the `PyTorch website
233233
<http://pytorch.org/>`__. skorch officially supports the last four
234234
minor PyTorch versions, which currently are:
235235

236-
- 2.5.1
237236
- 2.6.0
238237
- 2.7.1
239238
- 2.8.0
239+
- 2.9.0
240240

241241
However, that doesn't mean that older versions don't work, just that
242242
they aren't tested. Since skorch mostly relies on the stable part of

‎docs/user/installation.rst‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ instructions for PyTorch, visit the `PyTorch website
9393
<http://pytorch.org/>`__. skorch officially supports the last four
9494
minor PyTorch versions, which currently are:
9595

96-
- 2.5.1
9796
- 2.6.0
9897
- 2.7.1
9998
- 2.8.0
99+
- 2.9.0
100100

101101
However, that doesn't mean that older versions don't work, just that
102102
they aren't tested. Since skorch mostly relies on the stable part of

‎pyproject.toml‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dynamic = ["version", "readme"]
1515
description = "scikit-learn compatible neural network library for pytorch"
1616
license = "BSD-3-Clause"
1717
license-files = ["LICENSE"]
18-
requires-python = ">=3.9"
18+
requires-python = ">=3.10"
1919
dependencies =[
2020
"numpy>=1.13.3",
2121
"scikit-learn>=0.22.0",
@@ -31,7 +31,6 @@ authors = [
3131
{name = "skorch Developers"}
3232
]
3333
classifiers = [ #note:classifiers are recommended for packages to be visible on PyPI
34-
"Programming Language :: Python :: 3.9",
3534
"Programming Language :: Python :: 3.10",
3635
"Programming Language :: Python :: 3.11",
3736
"Programming Language :: Python :: 3.12",

‎skorch/classifier.py‎

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@
4242
"""
4343

4444
def get_neural_net_clf_doc(doc):
45-
indentation = " "
4645
# dedent/indent roundtrip required for consistent indentation in both
4746
# Python <3.13 and Python >=3.13
48-
# Because <3.13 => not automatic dedent, but it is the case in >=3.13
49-
doc = neural_net_clf_doc_start + " " + textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
47+
# Because <3.13 => no automatic dedent, but it is the case in >=3.13
48+
indentation = " "
49+
doc = textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
50+
51+
doc = neural_net_clf_doc_start + " " + doc
5052
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
5153
start, end = pattern.search(doc).span()
5254
doc = doc[:start] + neural_net_clf_additional_text + doc[end:]
@@ -253,11 +255,13 @@ def predict(self, X):
253255
is used by ``predict`` and ``predict_proba`` for classification."""
254256

255257
def get_neural_net_binary_clf_doc(doc):
256-
indentation = " "
257258
# dedent/indent roundtrip required for consistent indentation in both
258259
# Python <3.13 and Python >=3.13
259-
# Because <3.13 => not automatic dedent, but it is the case in >=3.13
260-
doc = neural_net_binary_clf_doc_start + " " + textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
260+
# Because <3.13 => no automatic dedent, but it is the case in >=3.13
261+
indentation = " "
262+
doc = textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
263+
264+
doc = neural_net_binary_clf_doc_start + " " + doc
261265
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
262266
start, end = pattern.search(doc).span()
263267
doc = doc[:start] + neural_net_binary_clf_criterion_text + doc[end:]

‎skorch/probabilistic.py‎

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pickle
1010
import re
11+
import textwrap
1112

1213
import gpytorch
1314
import numpy as np
@@ -471,7 +472,6 @@ def _predict(self, X):
471472
Module : gpytorch.models.ExactGP (class or instance)
472473
The module needs to return a
473474
:class:`~gpytorch.distributions.MultivariateNormal` distribution.
474-
475475
"""
476476

477477
exact_gp_regr_criterion_text = """
@@ -483,7 +483,6 @@ def _predict(self, X):
483483
criterion : gpytorch.mlls.ExactMarginalLogLikelihood
484484
The objective function to learn the posterior of the GP regressor.
485485
Usually doesn't need to be changed.
486-
487486
"""
488487

489488
exact_gp_regr_batch_size_text = """
@@ -492,7 +491,6 @@ def _predict(self, X):
492491
Mini-batch size. For exact GPs, it must be set to -1, since the exact
493492
solution cannot deal with batching. To make use of batching, use
494493
:class:`.GPRegressor` in conjunction with a variational strategy.
495-
496494
"""
497495

498496
# this is the same text for exact and approximate GP regression
@@ -505,37 +503,40 @@ def _predict(self, X):
505503
None. There is no default train split for GP regressors because random
506504
splitting is typically not desired, e.g. because there is a temporal
507505
relationship between samples.
508-
509506
"""
510507

511508
# this is the same text for all GPs
512509
gp_likelihood_attribute_text = """
513-
514510
likelihood_: torch module (instance)
515511
The instantiated likelihood.
516-
517512
"""
518513

519514

520515
def get_exact_gp_regr_doc(doc):
521516
"""Customizes the net docs to avoid duplication."""
517+
# dedent/indent roundtrip required for consistent indentation in both
518+
# Python <3.13 and Python >=3.13
519+
# Because <3.13 => no automatic dedent, but it is the case in >=3.13
520+
indentation = " "
521+
doc = textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
522+
522523
params_start_idx = doc.find(' Parameters\n ----------')
523524
doc = doc[params_start_idx:]
524-
doc = exact_gp_regr_doc_start + " " + doc
525+
doc = exact_gp_regr_doc_start + doc
525526

526-
pattern = re.compile(r'(\n\s+)(module .*\n)(\s.+){1,99}')
527+
pattern = re.compile(r'(\n\s+)(module .*\n)(\s.+|.){1,99}')
527528
start, end = pattern.search(doc).span()
528529
doc = doc[:start] + exact_gp_regr_module_text + doc[end:]
529530

530-
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
531+
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
531532
start, end = pattern.search(doc).span()
532533
doc = doc[:start] + exact_gp_regr_criterion_text + doc[end:]
533534

534-
pattern = re.compile(r'(\n\s+)(batch_size .*\n)(\s.+){1,99}')
535+
pattern = re.compile(r'(\n\s+)(batch_size .*\n)(\s.+|.){1,99}')
535536
start, end = pattern.search(doc).span()
536537
doc = doc[:start] + exact_gp_regr_batch_size_text + doc[end:]
537538

538-
pattern = re.compile(r'(\n\s+)(train_split .*\n)(\s.+){1,99}')
539+
pattern = re.compile(r'(\n\s+)(train_split .*\n)(\s.+|.){1,99}')
539540
start, end = pattern.search(doc).span()
540541
doc = doc[:start] + gp_regr_train_split_text + doc[end:]
541542

@@ -672,7 +673,6 @@ def fit(self, X, y=None, **fit_params):
672673
Module : gpytorch.models.ApproximateGP (class or instance)
673674
The GPyTorch module; in contrast to exact GP, the return distribution does
674675
not need to be Gaussian.
675-
676676
"""
677677

678678
gp_regr_criterion_text = """
@@ -684,25 +684,30 @@ def fit(self, X, y=None, **fit_params):
684684
criterion : gpytorch.mlls.VariationalELBO
685685
The objective function to learn the approximate posterior of the GP
686686
regressor.
687-
688687
"""
689688

690689

691690
def get_gp_regr_doc(doc):
692691
"""Customizes the net docs to avoid duplication."""
692+
# dedent/indent roundtrip required for consistent indentation in both
693+
# Python <3.13 and Python >=3.13
694+
# Because <3.13 => no automatic dedent, but it is the case in >=3.13
695+
indentation = " "
696+
doc = textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
697+
693698
params_start_idx = doc.find(' Parameters\n ----------')
694699
doc = doc[params_start_idx:]
695-
doc = gp_regr_doc_start + " " + doc
700+
doc = gp_regr_doc_start + doc
696701

697-
pattern = re.compile(r'(\n\s+)(module .*\n)(\s.+){1,99}')
702+
pattern = re.compile(r'(\n\s+)(module .*\n)(\s.+|.){1,99}')
698703
start, end = pattern.search(doc).span()
699704
doc = doc[:start] + gp_regr_module_text + doc[end:]
700705

701-
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
706+
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
702707
start, end = pattern.search(doc).span()
703708
doc = doc[:start] + gp_regr_criterion_text + doc[end:]
704709

705-
pattern = re.compile(r'(\n\s+)(train_split .*\n)(\s.+){1,99}')
710+
pattern = re.compile(r'(\n\s+)(train_split .*\n)(\s.+|.){1,99}')
706711
start, end = pattern.search(doc).span()
707712
doc = doc[:start] + gp_regr_train_split_text + doc[end:]
708713

@@ -744,7 +749,6 @@ def __init__(
744749
Module : gpytorch.models.ApproximateGP (class or instance)
745750
The GPyTorch module; in contrast to exact GP, the return distribution does
746751
not need to be Gaussian.
747-
748752
"""
749753

750754
gp_binary_clf_criterion_text = """
@@ -756,21 +760,26 @@ def __init__(
756760
criterion : gpytorch.mlls.VariationalELBO
757761
The objective function to learn the approximate posterior of the GP
758762
binary classification.
759-
760763
"""
761764

762765

763766
def get_gp_binary_clf_doc(doc):
764767
"""Customizes the net docs to avoid duplication."""
768+
# dedent/indent roundtrip required for consistent indentation in both
769+
# Python <3.13 and Python >=3.13
770+
# Because <3.13 => no automatic dedent, but it is the case in >=3.13
771+
indentation = " "
772+
doc = textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
773+
765774
params_start_idx = doc.find(' Parameters\n ----------')
766775
doc = doc[params_start_idx:]
767-
doc = gp_binary_clf_doc_start + " " + doc
776+
doc = gp_binary_clf_doc_start + doc
768777

769-
pattern = re.compile(r'(\n\s+)(module .*\n)(\s.+){1,99}')
778+
pattern = re.compile(r'(\n\s+)(module .*\n)(\s.+|.){1,99}')
770779
start, end = pattern.search(doc).span()
771780
doc = doc[:start] + gp_binary_clf_module_text + doc[end:]
772781

773-
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
782+
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
774783
start, end = pattern.search(doc).span()
775784
doc = doc[:start] + gp_binary_clf_criterion_text + doc[end:]
776785

‎skorch/tests/test_net.py‎

Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -511,16 +511,11 @@ def test_pickle_save_and_load_mixed_devices(
511511
# 1. one for the failed load
512512
# 2. for switching devices on the net instance
513513
# remove possible future warning about weights_only=False
514-
# TODO: remove filter when torch<=2.4 is dropped
515-
w_list = [
516-
warning for warning in w.list
517-
if "weights_only=False" not in warning.message.args[0]
518-
]
519-
assert len(w_list) == 2
520-
assert w_list[0].message.args[0] == (
514+
assert len(w.list) == 2
515+
assert w.list[0].message.args[0] == (
521516
'Requested to load data to CUDA but no CUDA devices '
522517
'are available. Loading on device "cpu" instead.')
523-
assert w_list[1].message.args[0] == (
518+
assert w.list[1].message.args[0] == (
524519
'Setting self.device = {} since the requested device ({}) '
525520
'is not available.'.format(load_dev, save_dev))
526521

@@ -3085,46 +3080,10 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
30853080
del call_kwargs['map_location'] # we're not interested in that
30863081
assert call_kwargs == expected_kwargs
30873082

3088-
def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
3083+
def test_torch_load_kwargs_auto_weights_true(
30893084
self, net_cls, module_cls, monkeypatch, tmp_path
30903085
):
3091-
# Same test as
3092-
# test_torch_load_kwargs_auto_weights_only_false_when_load_params but
3093-
# without monkeypatching get_default_torch_load_kwargs. The default is
3094-
# weights_only=False.
30953086
# See discussion in 1063.
3096-
from skorch._version import Version
3097-
3098-
# TODO remove once torch 2.5.0 is no longer supported
3099-
if Version(torch.__version__) >= Version('2.6.0'):
3100-
pytest.skip("Test only for torch < v2.6.0")
3101-
3102-
net = net_cls(module_cls).initialize()
3103-
net.save_params(f_params=tmp_path / 'params.pkl')
3104-
state_dict = net.module_.state_dict()
3105-
expected_kwargs = {"weights_only": False}
3106-
3107-
mock_torch_load = Mock(return_value=state_dict)
3108-
monkeypatch.setattr(torch, "load", mock_torch_load)
3109-
net.load_params(f_params=tmp_path / 'params.pkl')
3110-
3111-
call_kwargs = mock_torch_load.call_args_list[0].kwargs
3112-
del call_kwargs['map_location'] # we're not interested in that
3113-
assert call_kwargs == expected_kwargs
3114-
3115-
def test_torch_load_kwargs_auto_weights_true_pytorch_ge_2_6(
3116-
self, net_cls, module_cls, monkeypatch, tmp_path
3117-
):
3118-
# Same test as
3119-
# test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6 but
3120-
# with weights_only=True, since it's the new default
3121-
# See discussion in 1063.
3122-
from skorch._version import Version
3123-
3124-
# TODO remove once torch 2.5.0 is no longer supported
3125-
if Version(torch.__version__) < Version('2.6.0'):
3126-
pytest.skip("Test only for torch >= 2.6.0")
3127-
31283087
net = net_cls(module_cls).initialize()
31293088
net.save_params(f_params=tmp_path / 'params.pkl')
31303089
state_dict = net.module_.state_dict()
@@ -4332,16 +4291,6 @@ def test_compile_missing_dunder_in_prefix_arguments(
43324291
).initialize()
43334292

43344293
def test_fit_and_predict_with_compile(self, net_cls, module_cls, data):
4335-
if not hasattr(torch, 'compile'):
4336-
pytest.skip(reason="torch.compile not available")
4337-
4338-
# python 3.12 requires torch >= 2.4 to support compile
4339-
# TODO: remove once we remove support for torch < 2.4
4340-
from skorch._version import Version
4341-
4342-
if Version(torch.__version__) < Version('2.4.0') and sys.version_info >= (3, 12):
4343-
pytest.skip(reason="When using Python 3.12, torch.compile requires torch >= 2.4")
4344-
43454294
# use real torch.compile, not mocked, can be a bit slow
43464295
X, y = data
43474296
net = net_cls(module_cls, max_epochs=1, compile=True).initialize()
@@ -4362,13 +4311,6 @@ def test_binary_classifier_with_compile(self, data):
43624311
# because of a failing isinstance check
43634312
from skorch import NeuralNetBinaryClassifier
43644313

4365-
# python 3.12 requires torch >= 2.4 to support compile
4366-
# TODO: remove once we remove support for torch < 2.4
4367-
from skorch._version import Version
4368-
4369-
if Version(torch.__version__) < Version('2.4.0') and sys.version_info >= (3, 12):
4370-
pytest.skip(reason="When using Python 3.12, torch.compile requires torch >= 2.4")
4371-
43724314
X, y = data[0], data[1].astype(np.float32)
43734315

43744316
class MyNet(nn.Module):

0 commit comments

Comments
 (0)