Commit 6e14a9f

[DOC] Update the README.md
1 parent 7d46da8 commit 6e14a9f

3 files changed

Lines changed: 180 additions & 122 deletions

File tree

README.md

Lines changed: 0 additions & 109 deletions
This file was deleted.

README.rst

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
Ensemble-Pytorch
================

Implementation of scikit-learn-like ensemble methods in PyTorch.

News
----

- Apart from installing from source, you can now install
  Ensemble-PyTorch with:

  ::

      pip install -i https://test.pypi.org/simple/ torchensemble

- A pre-release version of the documentation is available at
  https://ensemble-pytorch.readthedocs.io/en/latest/.

Methods
-------

- **FusionClassifier** / **FusionRegressor**

  - In ``Fusion``, the outputs of all base estimators are first
    aggregated into an average output. A loss is then computed between
    this average output and the ground truth, and all base estimators
    are jointly trained with back-propagation.

- **VotingClassifier** / **VotingRegressor**

  - In ``Voting``, each base estimator is trained independently.
    Majority voting is used for classification, and the average of the
    predictions from all base estimators is used for regression.

- **BaggingClassifier** / **BaggingRegressor**

  - The training stage of ``Bagging`` is similar to that of ``Voting``.
    In addition, sampling with replacement is used when training each
    base estimator, to introduce more diversity.

- **GradientBoostingClassifier** / **GradientBoostingRegressor**

  - In ``GradientBoosting``, each newly-added base estimator is fitted,
    via least-squares regression, to the negative gradient of the loss
    with respect to the output of the previously fitted base
    estimators.

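The two aggregation rules above (averaging outputs and majority voting) can be sketched in plain Python. This is a minimal illustration of the rules only, not the library's implementation, and the helper names are made up:

.. code:: python

    from collections import Counter

    def average_output(outputs):
        """Fusion / regression: average the outputs of all base estimators."""
        n = len(outputs)
        return [sum(col) / n for col in zip(*outputs)]

    def majority_vote(predictions):
        """Voting, classification: return the most common predicted label."""
        return Counter(predictions).most_common(1)[0][0]

    # Three base estimators, two classes
    print(average_output([[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]))  # [0.5, 0.5]
    print(majority_vote([0, 0, 1]))  # 0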
Installation
------------

Installing the Ensemble-Pytorch package is simple: just clone this repo
and run ``setup.py``.

::

    $ git clone https://github.com/AaronX121/Ensemble-Pytorch.git
    $ cd Ensemble-Pytorch
    $ pip install -r requirements.txt
    $ python setup.py install

Minimal example on how to use
-----------------------------

.. code:: python

    """
    - Please see the scripts in `examples` for details on how to use
    - Please see the implementations in `torchensemble` for details on the ensemble methods
    - Please feel free to open an issue if you have any problem or feature request
    """

    from torchensemble import ensemble_method  # import an ensemble method (e.g., VotingClassifier)

    # Define the base estimator
    base_estimator = torch.nn.Module(...)  # class of the base estimator (e.g., a CNN)

    # Define the ensemble model
    model = ensemble_method(estimator=base_estimator,   # base estimator
                            n_estimators=10,            # number of base estimators
                            output_dim=output_dim,      # e.g., number of classes for classification
                            lr=learning_rate,           # learning rate of the optimizer
                            weight_decay=weight_decay,  # weight decay of model parameters
                            epochs=epochs)              # number of training epochs

    # Load data
    train_loader = DataLoader(...)
    test_loader = DataLoader(...)

    # Train
    model.fit(train_loader)

    # Evaluate
    model.predict(test_loader)

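The ``GradientBoosting`` procedure described under Methods can also be sketched stage-wise in plain Python. With squared loss, the negative gradient of the loss with respect to the current prediction is simply the residual. This is a toy illustration, not the library's implementation: the constant base estimator below (the least-squares fit of a constant is the mean) stands in for a real neural network, and the helper names are made up.

.. code:: python

    def fit_constant(targets):
        """Toy base estimator: the least-squares fit of a constant is the mean."""
        mean = sum(targets) / len(targets)
        return lambda x: mean

    def gradient_boost(xs, ys, n_estimators=3):
        """Stage-wise boosting: each new estimator fits the negative gradient
        (the residual, under squared loss) of the current ensemble output."""
        preds = [0.0] * len(xs)
        estimators = []
        for _ in range(n_estimators):
            residuals = [y - p for y, p in zip(ys, preds)]   # negative gradient
            est = fit_constant(residuals)                    # least-squares fit
            estimators.append(est)
            preds = [p + est(x) for p, x in zip(preds, xs)]  # update the ensemble
        return estimators, preds

    estimators, preds = gradient_boost([1, 2, 3], [1.0, 2.0, 3.0])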
Benchmarks
----------

- **Classification on CIFAR-10**

  - The table below presents the classification accuracy of different
    ensemble classifiers on the testing data of **CIFAR-10**
  - Each classifier uses **10** LeNet-5 models (with ReLU activation
    and Dropout) as its base estimators
  - Each base estimator is trained over **100** epochs, with batch size
    **128**, learning rate **1e-3**, and weight decay **5e-4**
  - Experiment results can be reproduced by running
    ``./examples/classification_cifar10_cnn.py``

+--------------------------------+-------------+-----------------+-----------------+
| Model Name                     | Params (MB) | Testing Acc (%) | Improvement (%) |
+================================+=============+=================+=================+
| **Single LeNet-5**             | 0.32        | 73.04           | -               |
+--------------------------------+-------------+-----------------+-----------------+
| **FusionClassifier**           | 3.17        | 78.75           | + 5.71          |
+--------------------------------+-------------+-----------------+-----------------+
| **VotingClassifier**           | 3.17        | 80.08           | + 7.04          |
+--------------------------------+-------------+-----------------+-----------------+
| **BaggingClassifier**          | 3.17        | 78.75           | + 5.71          |
+--------------------------------+-------------+-----------------+-----------------+
| **GradientBoostingClassifier** | 3.17        | 80.82           | + 7.78          |
+--------------------------------+-------------+-----------------+-----------------+

- **Regression on YearPredictionMSD**

  - The table below presents the mean squared error (MSE) of different
    ensemble regressors on the testing data of **YearPredictionMSD**
  - Each regressor uses **10** multi-layer perceptron (MLP) models
    (with ReLU activation and Dropout) as its base estimators, and the
    network architecture is fixed as ``Input-128-128-Output``
  - Each base estimator is trained over **50** epochs, with batch size
    **256**, learning rate **1e-3**, and weight decay **5e-4**
  - Experiment results can be reproduced by running
    ``./examples/regression_YearPredictionMSD_mlp.py``

+-------------------------------+-------------+-------------+-------------+
| Model Name                    | Params (MB) | Testing MSE | Improvement |
+===============================+=============+=============+=============+
| **Single MLP**                | 0.11        | 0.83        | -           |
+-------------------------------+-------------+-------------+-------------+
| **FusionRegressor**           | 1.08        | 0.73        | - 0.10      |
+-------------------------------+-------------+-------------+-------------+
| **VotingRegressor**           | 1.08        | 0.69        | - 0.14      |
+-------------------------------+-------------+-------------+-------------+
| **BaggingRegressor**          | 1.08        | 0.70        | - 0.13      |
+-------------------------------+-------------+-------------+-------------+
| **GradientBoostingRegressor** | 1.08        | 0.71        | - 0.12      |
+-------------------------------+-------------+-------------+-------------+

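The Improvement columns in the tables above are simply the difference between each ensemble and its single-model baseline. For the CIFAR-10 accuracies, for example:

.. code:: python

    single = 73.04  # Single LeNet-5 testing accuracy
    ensembles = {'FusionClassifier': 78.75, 'VotingClassifier': 80.08,
                 'BaggingClassifier': 78.75, 'GradientBoostingClassifier': 80.82}

    # Recompute the Improvement (%) column as ensemble accuracy minus baseline
    improvements = {name: round(acc - single, 2) for name, acc in ensembles.items()}
    print(improvements)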
Package dependencies
--------------------

- joblib>=0.11
- scikit-learn>=0.23.0
- torch>=0.4.1
- torchvision>=0.2.2

TODO
~~~~

Listed below are some of the things planned next. I would be very happy
to have someone join me in making this library better.

- Add ``StackingClassifier`` and ``StackingRegressor``.
- Add ``SoftGradientBoostingClassifier`` and
  ``SoftGradientBoostingRegressor``.
- Add more training options, such as the type of optimizer.
- Add more callbacks to ``predict``.
- Add PyTest scripts.
- Upload to PyPI.
- Build the documentation.

setup.py

Lines changed: 18 additions & 13 deletions
@@ -7,8 +7,8 @@
 
 here = path.abspath(path.dirname(__file__))
 
-# Get the long description from README.md
-with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+# Get the long description from README.rst
+with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
     long_description = f.read()
 
 # get the dependencies and installs
@@ -54,28 +54,33 @@ def run(self):
 
 setup(
     name='torchensemble',
-    version='1.0.0',
-    author='AaronX121',
-
+    maintainer='Yi-Xuan Xu',
+    maintainer_email='xuyx@lamda.nju.edu.cn',
     description=('Implementations of scikit-learn like ensemble methods in'
                  ' Pytorch'),
-    long_description=long_description,
-    long_description_content_type='text/markdown',
+    license='BSD 3-Clause',
     url='https://github.com/AaronX121/Ensemble-Pytorch',
+    project_urls={
+        'Bug Tracker': 'https://github.com/AaronX121/Ensemble-Pytorch/issues',
+        'Documentation': 'https://ensemble-pytorch.readthedocs.io/en/latest/',
+        'Source Code': 'https://github.com/AaronX121/Ensemble-Pytorch'},
+    version='0.0.1',
+    long_description=long_description,
     classifiers=[
         'Intended Audience :: Science/Research',
         'Intended Audience :: Developers',
-        'License :: OSI Approved',
-        'Programming Language :: C',
-        'Programming Language :: Python',
         'Topic :: Software Development',
         'Topic :: Scientific/Engineering',
+        'Development Status :: 4 - Beta',
         'Operating System :: Microsoft :: Windows',
         'Operating System :: POSIX',
         'Operating System :: Unix',
-        'Operating System :: MacOS'],
-    keywords='Ensemble Learning',
-
+        'Operating System :: MacOS',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8'],
+    keywords=['PyTorch', 'Ensemble Learning'],
     packages=find_packages(),
     cmdclass=cmdclass,
     install_requires=install_requires,