Skip to content

Commit 566b3b1

Browse files
committed
pid: optimize calculations with pre-computed gains and vec3 operations
1 parent 5c94e15 commit 566b3b1

File tree

6 files changed

+136
-78
lines changed

6 files changed

+136
-78
lines changed

src/core/looptime.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ void looptime_update() {
8787

8888
state.looptime_us = CYCLES_TO_US(time_cycles() - last_loop_cycles);
8989
state.looptime = state.looptime_us * 1e-6f;
90-
// 0.0032f is there for legacy purposes, should be 0.001f = looptime
91-
state.timefactor = 0.0032f / state.looptime;
90+
// looptime_inverse is the loop frequency (1/looptime)
91+
state.looptime_inverse = 1.0f / state.looptime;
9292
state.loop_counter++;
9393

9494
last_loop_cycles = time_cycles();

src/flight/angle_pid.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ float angle_pid(int x) {
1717
const float angle_error_abs = fabsf(state.angle_error.axis[x]);
1818

1919
const float small_angle = (1 - angle_error_abs) * state.angle_error.axis[x] * profile.pid.small_angle.kp // P term weighted
20-
+ ((state.angle_error.axis[x] - lasterror.axis[x]) * profile.pid.small_angle.kd * (1 - angle_error_abs) * state.timefactor); // D term weighted
20+
+ ((state.angle_error.axis[x] - lasterror.axis[x]) * profile.pid.small_angle.kd * (1 - angle_error_abs) * state.looptime_inverse); // D term weighted
2121

2222
const float big_angle = angle_error_abs * state.angle_error.axis[x] * profile.pid.big_angle.kp // P term weighted
23-
+ ((state.angle_error.axis[x] - lasterror.axis[x]) * profile.pid.big_angle.kd * angle_error_abs * state.timefactor); // D term weighted
23+
+ ((state.angle_error.axis[x] - lasterror.axis[x]) * profile.pid.big_angle.kd * angle_error_abs * state.looptime_inverse); // D term weighted
2424

2525
lasterror.axis[x] = state.angle_error.axis[x];
2626

src/flight/control.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ typedef struct {
4444
float looptime; // looptime in seconds
4545
float looptime_us; // looptime in us
4646
float looptime_autodetect; // desired looptime in us
47-
float timefactor; // timefactor for pid calc
47+
float looptime_inverse; // 1/looptime for derivative calculations
4848
uint32_t loop_counter; // number of loops ran
4949

5050
float uptime; // running sum of looptimes
@@ -116,7 +116,7 @@ typedef struct {
116116
MEMBER(looptime, float) \
117117
MEMBER(looptime_us, float) \
118118
MEMBER(looptime_autodetect, float) \
119-
MEMBER(timefactor, float) \
119+
MEMBER(looptime_inverse, float) \
120120
MEMBER(loop_counter, uint32_t) \
121121
MEMBER(uptime, float) \
122122
MEMBER(armtime, float) \

src/flight/pid.c

Lines changed: 86 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flight/filter.h"
1212
#include "io/led.h"
1313
#include "util/util.h"
14+
#include "util/vector.h"
1415

1516
#define PID_SIZE 3
1617

@@ -21,20 +22,19 @@
2122
#define RELAX_FACTOR_YAW (RELAX_FACTOR_YAW_DEG * DEGTORAD)
2223

2324
/// output limit
24-
static const float out_limit[PID_SIZE] = {0.8, 0.8, 0.6};
25+
static const vec3_t out_limit = {.roll = 0.8f, .pitch = 0.8f, .yaw = 0.6f};
2526

2627
// limit of integral term (abs)
27-
static const float integral_limit[PID_SIZE] = {0.8, 0.8, 0.6};
28+
static const vec3_t integral_limit = {.roll = 0.8f, .pitch = 0.8f, .yaw = 0.6f};
2829

29-
static const float pid_scales[PID_SIZE][PID_SIZE] = {
30-
// roll, pitch, yaw
31-
{1.0f / 628.0f, 1.0f / 628.0f, 1.0f / 314.0f}, // kp
32-
{1.0f / 50.0f, 1.0f / 50.0f, 1.0f / 50.0f}, // ki
33-
{1.0f / 120.0f, 1.0f / 120.0f, 1.0f / 120.0f}, // kd
30+
static const vec3_t pid_scales[PID_SIZE] = {
31+
{.roll = 1.0f / 628.0f, .pitch = 1.0f / 628.0f, .yaw = 1.0f / 314.0f}, // kp
32+
{.roll = 1.0f / 100.0f, .pitch = 1.0f / 100.0f, .yaw = 1.0f / 100.0f}, // ki - includes historical 0.5x scaling from Silverware
33+
{.roll = 1.0f / 37500.0f, .pitch = 1.0f / 37500.0f, .yaw = 1.0f / 37500.0f}, // kd - includes 0.0032 constant (0.0032 / 120 = 1 / 37500)
3434
};
3535

36-
static float lastrate[PID_SIZE] = {0, 0, 0};
37-
static float lastsetpoint[PID_SIZE] = {0, 0, 0};
36+
static vec3_t lastrate = {.roll = 0, .pitch = 0, .yaw = 0};
37+
static vec3_t lastsetpoint = {.roll = 0, .pitch = 0, .yaw = 0};
3838

3939
static vec3_t ierror;
4040
static vec3_t last_error;
@@ -58,31 +58,41 @@ void pid_init() {
5858
}
5959

6060
// (iwindup = 0 windup is not allowed) (iwindup = 1 windup is allowed)
61-
static inline float pid_compute_iterm_windup(uint8_t x, float pid_output) {
62-
if ((pid_output >= out_limit[x]) && (state.error.axis[x] > 0)) {
63-
return 0.0f;
64-
}
65-
if ((pid_output <= -out_limit[x]) && (state.error.axis[x] < 0)) {
66-
return 0.0f;
67-
}
61+
static inline vec3_t pid_compute_iterm_windup_vec(const vec3_t *pid_output) {
62+
vec3_t windup = {.roll = 1.0f, .pitch = 1.0f, .yaw = 1.0f};
6863

69-
#ifdef ITERM_RELAX // Roll - Pitch Setpoint based I term relax method
70-
static float avg_setpoint[3] = {0, 0, 0};
71-
if (x < 2) {
72-
lpf(&avg_setpoint[x], state.setpoint.axis[x], lpfcalc(state.looptime, 1.0f / (float)RELAX_FREQUENCY_HZ)); // 11 Hz filter
73-
const float hpfSetpoint = fabsf(state.setpoint.axis[x] - avg_setpoint[x]);
74-
return max(1.0f - hpfSetpoint / RELAX_FACTOR, 0.0f);
75-
}
64+
for (uint8_t x = 0; x < PID_SIZE; x++) {
65+
if ((pid_output->axis[x] >= out_limit.axis[x]) && (state.error.axis[x] > 0)) {
66+
windup.axis[x] = 0.0f;
67+
} else if ((pid_output->axis[x] <= -out_limit.axis[x]) && (state.error.axis[x] < 0)) {
68+
windup.axis[x] = 0.0f;
69+
}
70+
#ifdef ITERM_RELAX
71+
else {
72+
static vec3_t avg_setpoint = {.roll = 0, .pitch = 0, .yaw = 0};
73+
static float lpf_coeff = 0;
74+
static float lpf_coeff_yaw = 0;
75+
if (lpf_coeff == 0) {
76+
lpf_coeff = lpfcalc(state.looptime, 1.0f / (float)RELAX_FREQUENCY_HZ);
77+
lpf_coeff_yaw = lpfcalc(state.looptime, 1.0f / (float)RELAX_FREQUENCY_HZ_YAW);
78+
}
79+
if (x < 2) {
80+
lpf(&avg_setpoint.axis[x], state.setpoint.axis[x], lpf_coeff);
81+
const float hpfSetpoint = fabsf(state.setpoint.axis[x] - avg_setpoint.axis[x]);
82+
windup.axis[x] = max(1.0f - hpfSetpoint / RELAX_FACTOR, 0.0f);
83+
}
7684
#ifdef ITERM_RELAX_YAW
77-
else { // axis is yaw
78-
lpf(&avg_setpoint[x], state.setpoint.axis[x], lpfcalc(state.looptime, 1.0f / (float)RELAX_FREQUENCY_HZ_YAW)); // 25 Hz filter
79-
const float hpfSetpoint = fabsf(state.setpoint.axis[x] - avg_setpoint[x]);
80-
return max(1.0f - hpfSetpoint / RELAX_FACTOR_YAW, 0.0f);
81-
}
85+
else {
86+
lpf(&avg_setpoint.axis[x], state.setpoint.axis[x], lpf_coeff_yaw);
87+
const float hpfSetpoint = fabsf(state.setpoint.axis[x] - avg_setpoint.axis[x]);
88+
windup.axis[x] = max(1.0f - hpfSetpoint / RELAX_FACTOR_YAW, 0.0f);
89+
}
8290
#endif
91+
}
8392
#endif
93+
}
8494

85-
return 1.0f;
95+
return windup;
8696
}
8797

8898
static inline float pid_filter_dterm(uint8_t x, float dterm) {
@@ -92,26 +102,29 @@ static inline float pid_filter_dterm(uint8_t x, float dterm) {
92102
return dterm;
93103
}
94104

95-
static inline bool pid_should_enable_iterm(uint8_t x) {
105+
static inline vec3_t pid_should_enable_iterm_vec() {
96106
static bool stick_movement = false;
97107
if (!flags.arm_state) {
98108
// disarmed, disable, flag no movement
99109
stick_movement = false;
100-
return false;
110+
return (vec3_t){{0.0f, 0.0f, 0.0f}};
101111
}
102112
if (flags.in_air) {
103113
// in-air, enable, flag no movement
104114
stick_movement = false;
105-
return true;
115+
return (vec3_t){{1.0f, 1.0f, 1.0f}};
106116
}
107117

108-
if (fabsf(state.setpoint.axis[x]) > 0.1f) {
109-
// record first stick crossing
118+
// Check for stick movement on any axis
119+
if (fabsf(state.setpoint.roll) > 0.1f ||
120+
fabsf(state.setpoint.pitch) > 0.1f ||
121+
fabsf(state.setpoint.yaw) > 0.1f) {
110122
stick_movement = true;
111123
}
112124

113-
// enable if we recored stick movement previously
114-
return stick_movement;
125+
// enable if we recorded stick movement previously
126+
const float enable = stick_movement ? 1.0f : 0.0f;
127+
return (vec3_t){{enable, enable, enable}};
115128
}
116129

117130
static inline float pid_voltage_compensation() {
@@ -159,30 +172,36 @@ void pid_calc() {
159172
// rotates errors
160173
ierror = vec3_rotate(ierror, state.gyro_delta_angle);
161174

162-
#pragma GCC unroll 3
163-
for (uint8_t x = 0; x < PID_SIZE; x++) {
164-
const float current_kp = profile_current_pid_rates()->kp.axis[x] * pid_scales[0][x];
165-
const float current_ki = profile_current_pid_rates()->ki.axis[x] * pid_scales[1][x];
166-
const float current_kd = profile_current_pid_rates()->kd.axis[x] * pid_scales[2][x];
175+
// Calculate deltas for derivatives
176+
const vec3_t setpoint_delta = vec3_sub(state.setpoint, lastsetpoint);
177+
const vec3_t gyro_delta = vec3_sub(state.gyro, lastrate);
167178

168-
// P term
169-
state.pid_p_term.axis[x] = state.error.axis[x] * current_kp;
179+
// Pre-calculate all PID gains using vec3 operations
180+
const pid_rate_t *rates = profile_current_pid_rates();
181+
const vec3_t current_kp = vec3_mul(vec3_mul_elem(rates->kp, pid_scales[0]), v_compensation);
170182

171-
// Pid Voltage Comp applied to P term only
172-
state.pid_p_term.axis[x] *= v_compensation;
183+
// Pre-calculate common terms
184+
const float ki_looptime = state.looptime * (1.0f / 3.0f); // Simpson's rule constant * looptime
173185

174-
// I term
175-
const float iterm_windup = pid_compute_iterm_windup(x, pid_output.axis[x]);
176-
if (!pid_should_enable_iterm(x)) {
177-
// wind down integral while we are still on ground and we do not get any input from the sticks
178-
ierror.axis[x] *= 0.98f;
179-
}
180-
// SIMPSON_RULE_INTEGRAL
181-
// assuming similar time intervals
182-
ierror.axis[x] = ierror.axis[x] + 0.5f * (1.0f / 3.0f) * (last_error2.axis[x] + 4 * last_error.axis[x] + state.error.axis[x]) * current_ki * iterm_windup * state.looptime;
183-
ierror.axis[x] = constrain(ierror.axis[x], -integral_limit[x], integral_limit[x]);
184-
last_error2.axis[x] = last_error.axis[x];
185-
last_error.axis[x] = state.error.axis[x];
186+
// Pre-multiply Ki and Kd with their time factors
187+
const vec3_t current_ki = vec3_mul(vec3_mul_elem(rates->ki, pid_scales[1]), ki_looptime);
188+
const vec3_t current_kd = vec3_mul(vec3_mul_elem(rates->kd, pid_scales[2]), state.looptime_inverse);
189+
const vec3_t iterm_enable = pid_should_enable_iterm_vec();
190+
const vec3_t iterm_windup = pid_compute_iterm_windup_vec(&pid_output);
191+
const bool rx_filter_enabled = state.rx_filter_hz > 0.1f;
192+
193+
#pragma GCC unroll 3
194+
#pragma GCC ivdep // no loop dependencies, safe to vectorize
195+
for (uint8_t x = 0; x < PID_SIZE; x++) {
196+
// P term (voltage compensation already applied to current_kp)
197+
state.pid_p_term.axis[x] = state.error.axis[x] * current_kp.axis[x];
198+
199+
// I term - combine decay and simpson integration
200+
const float simpson_sum = last_error2.axis[x] + 4.0f * last_error.axis[x] + state.error.axis[x];
201+
const float ierror_delta = simpson_sum * current_ki.axis[x] * iterm_windup.axis[x];
202+
// If iterm disabled, decay by 0.98, otherwise add the delta
203+
ierror.axis[x] = iterm_enable.axis[x] ? (ierror.axis[x] + ierror_delta) : (ierror.axis[x] * 0.98f);
204+
ierror.axis[x] = constrain(ierror.axis[x], -integral_limit.axis[x], integral_limit.axis[x]);
186205

187206
state.pid_i_term.axis[x] = ierror.axis[x];
188207

@@ -194,19 +213,23 @@ void pid_calc() {
194213
transition_setpoint_weight = (fabsf(state.rx_filtered.axis[x]) * (stick_transition[x] / stick_accelerator[x])) + (1 - stick_transition[x]);
195214
}
196215

197-
float setpoint_derivative = (state.setpoint.axis[x] - lastsetpoint[x]) * current_kd * state.timefactor;
198-
if (state.rx_filter_hz > 0.1f) {
216+
float setpoint_derivative = setpoint_delta.axis[x] * current_kd.axis[x];
217+
if (rx_filter_enabled) {
199218
setpoint_derivative = filter_lp_pt1_step(&rx_filter, &rx_filter_state[x], setpoint_derivative);
200219
}
201-
lastsetpoint[x] = state.setpoint.axis[x];
202220

203-
const float gyro_derivative = (state.gyro.axis[x] - lastrate[x]) * current_kd * state.timefactor * tda_compensation;
204-
lastrate[x] = state.gyro.axis[x];
221+
const float gyro_derivative = gyro_delta.axis[x] * current_kd.axis[x] * tda_compensation;
205222

206223
const float dterm = (setpoint_derivative * stick_accelerator[x] * transition_setpoint_weight) - (gyro_derivative);
207224
state.pid_d_term.axis[x] = pid_filter_dterm(x, dterm);
208225

209226
state.pidoutput.axis[x] = pid_output.axis[x] = state.pid_p_term.axis[x] + state.pid_i_term.axis[x] + state.pid_d_term.axis[x];
210-
state.pidoutput.axis[x] = constrain(state.pidoutput.axis[x], -out_limit[x], out_limit[x]);
227+
state.pidoutput.axis[x] = constrain(state.pidoutput.axis[x], -out_limit.axis[x], out_limit.axis[x]);
211228
}
229+
230+
// Update history
231+
last_error2 = last_error;
232+
last_error = state.error;
233+
lastsetpoint = state.setpoint;
234+
lastrate = state.gyro;
212235
}

src/util/vector.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,41 @@ cbor_result_t cbor_decode_compact_vec3_t(cbor_value_t *dec, compact_vec3_t *vec)
3131
void vec3_from_array(vec3_t *out, float *in);
3232
void vec3_compress(compact_vec3_t *out, vec3_t *in, float scale);
3333

34-
// Rodrigues' rotation formula, used in hot paths, lets inline
34+
// Helper functions for vec3_rotate
35+
static inline vec3_t vec3_add(vec3_t a, vec3_t b) {
36+
return (vec3_t){{a.axis[0] + b.axis[0],
37+
a.axis[1] + b.axis[1],
38+
a.axis[2] + b.axis[2]}};
39+
}
40+
41+
static inline vec3_t vec3_sub(vec3_t a, vec3_t b) {
42+
return (vec3_t){{a.axis[0] - b.axis[0],
43+
a.axis[1] - b.axis[1],
44+
a.axis[2] - b.axis[2]}};
45+
}
46+
47+
static inline vec3_t vec3_mul(vec3_t v, float s) {
48+
return (vec3_t){{v.axis[0] * s,
49+
v.axis[1] * s,
50+
v.axis[2] * s}};
51+
}
52+
53+
static inline vec3_t vec3_mul_elem(vec3_t a, vec3_t b) {
54+
return (vec3_t){{a.axis[0] * b.axis[0],
55+
a.axis[1] * b.axis[1],
56+
a.axis[2] * b.axis[2]}};
57+
}
58+
59+
static inline float vec3_dot(vec3_t a, vec3_t b) {
60+
return a.axis[0] * b.axis[0] + a.axis[1] * b.axis[1] + a.axis[2] * b.axis[2];
61+
}
62+
63+
static inline vec3_t vec3_cross(vec3_t a, vec3_t b) {
64+
return (vec3_t){{a.axis[1] * b.axis[2] - a.axis[2] * b.axis[1],
65+
a.axis[2] * b.axis[0] - a.axis[0] * b.axis[2],
66+
a.axis[0] * b.axis[1] - a.axis[1] * b.axis[0]}};
67+
}
68+
3569
static inline vec3_t vec3_rotate(const vec3_t vec, const vec3_t rot) {
3670
return (vec3_t){{
3771
vec.axis[0] - vec.axis[1] * rot.axis[2] + vec.axis[2] * rot.axis[1],

test/test_native/test_pid.c

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ static void pid_setUp(void) {
4141

4242
// Initialize control state
4343
state.looptime = 0.000125f; // 8kHz
44+
state.looptime_inverse = 1.0f / state.looptime; // 8000Hz
4445
state.vbat_filtered = 4.0f;
4546
state.vbat_compensated = 1.0f;
4647

@@ -117,8 +118,8 @@ void test_pid_derivative_calculation(void) {
117118
profile.pid.stick_rates[0].accelerator.roll = 1.0f;
118119
profile.pid.stick_rates[0].transition.roll = 1.0f;
119120

120-
// Initialize state.timefactor (needed for D-term)
121-
state.timefactor = 0.0032f / state.looptime; // Original timefactor calculation
121+
// Initialize state.looptime_inverse (needed for D-term)
122+
state.looptime_inverse = 1.0f / state.looptime;
122123

123124
// Initialize with no gyro rate
124125
state.gyro.roll = 0.0f;
@@ -147,8 +148,8 @@ void test_pid_dterm_setpoint_response(void) {
147148
profile.pid.stick_rates[0].accelerator.roll = 1.0f;
148149
profile.pid.stick_rates[0].transition.roll = 0.0f; // No transition weighting
149150

150-
// Initialize state.timefactor
151-
state.timefactor = 0.0032f / state.looptime;
151+
// Initialize state.looptime_inverse
152+
state.looptime_inverse = 1.0f / state.looptime;
152153

153154
// Set some stick input so transition weight isn't 0
154155
state.rx_filtered.roll = 0.0f; // No stick input
@@ -179,8 +180,8 @@ void test_pid_dterm_combined_response(void) {
179180
profile.pid.stick_rates[0].accelerator.roll = 1.0f;
180181
profile.pid.stick_rates[0].transition.roll = 0.0f;
181182

182-
// Initialize state.timefactor
183-
state.timefactor = 0.0032f / state.looptime;
183+
// Initialize state.looptime_inverse
184+
state.looptime_inverse = 1.0f / state.looptime;
184185

185186
// Initialize
186187
state.gyro.roll = 0.0f;
@@ -210,8 +211,8 @@ void test_pid_dterm_stick_weighting(void) {
210211
profile.pid.stick_rates[0].accelerator.roll = 2.0f; // Higher than 1
211212
profile.pid.stick_rates[0].transition.roll = 0.0f; // Ensure setpoint has effect
212213

213-
// Initialize state.timefactor
214-
state.timefactor = 0.0032f / state.looptime;
214+
// Initialize state.looptime_inverse
215+
state.looptime_inverse = 1.0f / state.looptime;
215216

216217
// Set some stick input
217218
state.rx_filtered.roll = 0.5f;

0 commit comments

Comments
 (0)