AirLibrary/Resilience/
Retry.rs1use std::{
13 collections::HashMap,
14 sync::Arc,
15 time::{Duration, Instant},
16};
17
18use serde::{Deserialize, Serialize};
19use tokio::sync::{Mutex, broadcast};
20
21use crate::dev_log;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
25pub enum ErrorClass {
26 Transient,
28
29 NonRetryable,
31
32 RateLimited,
34
35 ServerError,
37
38 Unknown,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct RetryPolicy {
45 pub MaxRetries:u32,
46
47 pub InitialIntervalMs:u64,
48
49 pub MaxIntervalMs:u64,
50
51 pub BackoffMultiplier:f64,
52
53 pub JitterFactor:f64,
55
56 pub BudgetPerMinute:u32,
57
58 pub ErrorClassification:HashMap<String, ErrorClass>,
59}
60
61impl Default for RetryPolicy {
62 fn default() -> Self {
63 let mut ErrorClassification = HashMap::new();
64
65 ErrorClassification.insert("timeout".to_string(), ErrorClass::Transient);
66
67 ErrorClassification.insert("connection_refused".to_string(), ErrorClass::Transient);
68
69 ErrorClassification.insert("connection_reset".to_string(), ErrorClass::Transient);
70
71 ErrorClassification.insert("rate_limit_exceeded".to_string(), ErrorClass::RateLimited);
72
73 ErrorClassification.insert("authentication_failed".to_string(), ErrorClass::NonRetryable);
74
75 ErrorClassification.insert("unauthorized".to_string(), ErrorClass::NonRetryable);
76
77 ErrorClassification.insert("not_found".to_string(), ErrorClass::NonRetryable);
78
79 ErrorClassification.insert("server_error".to_string(), ErrorClass::ServerError);
80
81 ErrorClassification.insert("internal_server_error".to_string(), ErrorClass::ServerError);
82
83 ErrorClassification.insert("service_unavailable".to_string(), ErrorClass::ServerError);
84
85 ErrorClassification.insert("gateway_timeout".to_string(), ErrorClass::Transient);
86
87 Self {
88 MaxRetries:3,
89
90 InitialIntervalMs:1000,
91
92 MaxIntervalMs:32000,
93
94 BackoffMultiplier:2.0,
95
96 JitterFactor:0.1,
97
98 BudgetPerMinute:100,
99
100 ErrorClassification,
101 }
102 }
103}
104
/// Sliding-window limiter counting how many retries were granted in the last minute.
#[derive(Debug, Clone)]
struct RetryBudget {
	/// Instants at which retries were granted, in the order they were granted.
	Attempts:Vec<Instant>,
	/// Maximum number of retries granted within any one window.
	MaxPerMinute:u32,
}

impl RetryBudget {
	/// Length of the sliding window the budget is measured over.
	const WINDOW:Duration = Duration::from_secs(60);

	/// Creates an empty budget allowing `MaxPerMinute` retries per window.
	fn new(MaxPerMinute:u32) -> Self { Self { Attempts:Vec::new(), MaxPerMinute } }

	/// Tries to consume one retry from the budget.
	///
	/// First drops every recorded attempt that is at least one window old. If fewer than
	/// `MaxPerMinute` attempts remain, the current instant is recorded and `true` is returned;
	/// otherwise nothing is recorded and `false` is returned.
	fn can_retry(&mut self) -> bool {
		let Now = Instant::now();

		// Measure each attempt's age instead of computing `Now - WINDOW`: subtracting a
		// Duration from an Instant panics when the result would precede the clock's origin
		// (possible on some platforms shortly after boot).
		self.Attempts.retain(|&Attempt| Now.duration_since(Attempt) < Self::WINDOW);

		if self.Attempts.len() < self.MaxPerMinute as usize {
			self.Attempts.push(Now);

			true
		} else {
			false
		}
	}
}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct RetryEvent {
137 pub Service:String,
138
139 pub Attempt:u32,
140
141 pub ErrorClass:ErrorClass,
142
143 pub DelayMs:u64,
144
145 pub Success:bool,
146
147 pub ErrorMessage:Option<String>,
148}
149
150#[derive(Debug)]
153pub struct RetryManager {
154 Policy:RetryPolicy,
155
156 Budgets:Arc<Mutex<HashMap<String, RetryBudget>>>,
157
158 EventTx:Arc<broadcast::Sender<RetryEvent>>,
159}
160
161impl RetryManager {
162 pub fn new(policy:RetryPolicy) -> Self {
163 let (EventTx, _) = broadcast::channel(1000);
164
165 Self {
166 Policy:policy,
167
168 Budgets:Arc::new(Mutex::new(HashMap::new())),
169
170 EventTx:Arc::new(EventTx),
171 }
172 }
173
174 pub fn GetEventTransmitter(&self) -> broadcast::Sender<RetryEvent> { (*self.EventTx).clone() }
175
176 pub fn CalculateRetryDelay(&self, Attempt:u32) -> Duration {
179 if Attempt == 0 {
180 return Duration::from_millis(0);
181 }
182
183 let BaseDelay = (self.Policy.InitialIntervalMs as f64 * self.Policy.BackoffMultiplier.powi(Attempt as i32 - 1))
184 .min(self.Policy.MaxIntervalMs as f64) as u64;
185
186 let Jitter = (BaseDelay as f64 * self.Policy.JitterFactor) as u64;
187
188 let RandomJitter = (rand::random::<f64>() * Jitter as f64) as u64;
189
190 Duration::from_millis(BaseDelay + RandomJitter)
191 }
192
193 pub fn CalculateAdaptiveRetryDelay(&self, ErrorType:&str, attempt:u32) -> Duration {
195 let Class = self
196 .Policy
197 .ErrorClassification
198 .get(ErrorType)
199 .copied()
200 .unwrap_or(ErrorClass::Unknown);
201
202 match Class {
203 ErrorClass::RateLimited => Duration::from_millis(((attempt + 1) * 5000) as u64),
204
205 ErrorClass::ServerError => {
206 let BaseDelay = self.Policy.InitialIntervalMs * 2_u64.pow(attempt);
207
208 Duration::from_millis(BaseDelay.min(self.Policy.MaxIntervalMs))
209 },
210
211 ErrorClass::Transient => self.CalculateRetryDelay(attempt),
212
213 ErrorClass::NonRetryable | ErrorClass::Unknown => Duration::from_millis(100),
214 }
215 }
216
217 pub fn ClassifyError(&self, ErrorMessage:&str) -> ErrorClass {
218 let Lower = ErrorMessage.to_lowercase();
219
220 for (pattern, class) in &self.Policy.ErrorClassification {
221 if Lower.contains(pattern) {
222 return *class;
223 }
224 }
225
226 ErrorClass::Unknown
227 }
228
229 pub async fn CanRetry(&self, service:&str) -> bool {
230 let mut budgets = self.Budgets.lock().await;
231
232 let budget = budgets
233 .entry(service.to_string())
234 .or_insert_with(|| RetryBudget::new(self.Policy.BudgetPerMinute));
235
236 budget.can_retry()
237 }
238
239 pub fn PublishRetryEvent(&self, event:RetryEvent) { let _ = self.EventTx.send(event); }
240
241 pub fn ValidatePolicy(&self) -> Result<(), String> {
242 if self.Policy.MaxRetries == 0 {
243 return Err("MaxRetries must be greater than 0".to_string());
244 }
245
246 if self.Policy.InitialIntervalMs == 0 {
247 return Err("InitialIntervalMs must be greater than 0".to_string());
248 }
249
250 if self.Policy.InitialIntervalMs > self.Policy.MaxIntervalMs {
251 return Err("InitialIntervalMs cannot be greater than MaxIntervalMs".to_string());
252 }
253
254 if self.Policy.BackoffMultiplier <= 1.0 {
255 return Err("BackoffMultiplier must be greater than 1.0".to_string());
256 }
257
258 if self.Policy.JitterFactor < 0.0 || self.Policy.JitterFactor > 1.0 {
259 return Err("JitterFactor must be between 0 and 1".to_string());
260 }
261
262 if self.Policy.BudgetPerMinute == 0 {
263 return Err("BudgetPerMinute must be greater than 0".to_string());
264 }
265
266 Ok(())
267 }
268}