|
1 | 1 | import numpy as np |
2 | 2 | import sympy |
3 | 3 | import matplotlib.pyplot as plt |
| 4 | +import timeit |
4 | 5 |
|
5 | 6 | import KalmanFilter |
| 7 | +import KalmanFilterPurePython |
6 | 8 |
|
| 9 | +# Calculate parameters for Kalman filter |
7 | 10 | t = sympy.symbols('t', real = True, positive=True) |
8 | 11 | s = sympy.symbols('s', real = True, positive=True) |
9 | 12 | w = sympy.symbols('w', real = True, positive=True) |
@@ -31,7 +34,8 @@ def TransformSystemsDynamicsMatrix(F): |
31 | 34 | class data: |
32 | 35 | pass |
33 | 36 |
|
34 | | -data.time = np.arange(0, 10*2*np.pi/w, Ts) |
| 37 | +data.time = np.arange(0, 100*2*np.pi/w, Ts) # 100 periods of oscillation |
| 38 | +print("{} Points to be filtered".format(len(data.time))) |
35 | 39 | noiseAmp = 5 |
36 | 40 | data.trueSignal = 300 + 5*np.sin(w*data.time + 0) |
37 | 41 | data.signal = data.trueSignal + np.random.normal(0, noiseAmp, len(data.trueSignal)) |
@@ -59,12 +63,40 @@ class data: |
59 | 63 |
|
60 | 64 | print("Executing Filtering") |
61 | 65 |
|
| 66 | +start_time = timeit.default_timer() |
62 | 67 | ResultData = kf.FilterData(data.signal, len(data.signal)) |
| 68 | +elapsed = timeit.default_timer() - start_time |
| 69 | +print("C++ filtering took {}s".format(elapsed)) |
63 | 70 |
|
| 71 | +print("Constructing Python Filter") |
| 72 | + |
| 73 | +KF2 = KalmanFilterPurePython.KalmanFilterLinear(A, B, H, x_init, P, Q, R) |
| 74 | + |
| 75 | +print("Executing Python Filtering") |
| 76 | + |
| 77 | +start_time = timeit.default_timer() |
| 78 | +KalmanPredictionArray = [] |
| 79 | +KalmanErrorArray = [] |
| 80 | +KalmanGainArray = [] |
| 81 | +for i, x in enumerate(data.signal): |
| 82 | + X_state = np.matrix('{0} ; {1}'.format(x, 0)) |
| 83 | + KF2.Step(np.matrix([0]), np.matrix([x])) |
| 84 | + KalmanPredictionArray.append(KF2.current_state_estimate) |
| 85 | + KalmanErrorArray.append(KF2.current_prob_estimate) |
| 86 | + KalmanGainArray.append(KF2.KG) |
| 87 | + |
| 88 | +KalmanPredictionArray = np.array(KalmanPredictionArray) |
| 89 | +KalmanErrorArray = np.array(KalmanErrorArray) |
| 90 | +KalmanGainArray = np.array(KalmanGainArray) |
| 91 | + |
| 92 | +x = KalmanPredictionArray[:, 0] + KalmanPredictionArray[:, 2] |
| 93 | +elapsed = timeit.default_timer() - start_time |
| 94 | +print("C++ filtering took {}s".format(elapsed)) |
64 | 95 |
|
65 | 96 | plt.figure(figsize=[10, 10]) |
66 | 97 | plt.plot(data.time, data.signal, label="noisy sine wave", color='blue', alpha=0.6) |
67 | 98 | plt.plot(data.time, data.trueSignal, label="pure source sine wave", lw=3, color='red', alpha=0.8) |
| 99 | +plt.plot(data.time, x, label="Pure Python Kalman Filter Output", lw=3, color='darkblue', alpha=0.9) |
68 | 100 | plt.plot(data.time, ResultData, label="C++ Kalman Filter Output", lw=3, color='darkred', alpha=0.9) |
69 | 101 |
|
70 | 102 | plt.legend(loc="best") |
|
0 commit comments