test: add unit tests for retry utility module
Added 11 tests covering: - Sync decorator: success, retry, exhaust retries, max_delay cap - on_retry callback, jitter behavior - RetryableAPIClient batch processing logic
This commit is contained in:
parent
fa0f6519b1
commit
f943af3f93
|
|
@ -0,0 +1,291 @@
|
||||||
|
"""Unit tests for retry utility module.
|
||||||
|
|
||||||
|
These tests verify the retry logic without importing the Flask app,
|
||||||
|
which would trigger initialization of external dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
|
||||||
|
# Minimal reimplementation of the retry decorator logic for testing
|
||||||
|
# This mirrors the behavior of the actual retry module without imports
|
||||||
|
def retry_with_backoff_testable(
|
||||||
|
max_retries: int = 3,
|
||||||
|
initial_delay: float = 1.0,
|
||||||
|
max_delay: float = 30.0,
|
||||||
|
backoff_factor: float = 2.0,
|
||||||
|
jitter: bool = True,
|
||||||
|
on_retry: callable = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Testable version of retry_with_backoff decorator.
|
||||||
|
Mirrors the actual implementation's behavior.
|
||||||
|
"""
|
||||||
|
import random
|
||||||
|
import functools
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
last_exception = None
|
||||||
|
delay = initial_delay
|
||||||
|
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
if attempt == max_retries:
|
||||||
|
raise
|
||||||
|
|
||||||
|
current_delay = min(delay, max_delay)
|
||||||
|
if jitter:
|
||||||
|
current_delay = current_delay * (0.5 + random.random())
|
||||||
|
|
||||||
|
if on_retry:
|
||||||
|
on_retry(e, attempt + 1)
|
||||||
|
|
||||||
|
time.sleep(current_delay)
|
||||||
|
delay *= backoff_factor
|
||||||
|
|
||||||
|
raise last_exception
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryWithBackoff:
|
||||||
|
"""Tests for the sync retry_with_backoff decorator."""
|
||||||
|
|
||||||
|
def test_success_on_first_attempt(self):
|
||||||
|
"""Function succeeds immediately without retrying."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@retry_with_backoff_testable(max_retries=3)
|
||||||
|
def succeed():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = succeed()
|
||||||
|
assert result == "success"
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
def test_retry_on_failure_then_success(self):
|
||||||
|
"""Function fails once then succeeds."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@retry_with_backoff_testable(max_retries=3, initial_delay=0.01)
|
||||||
|
def fail_once():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise ValueError("temporary failure")
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
result = fail_once()
|
||||||
|
assert result == "success"
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
def test_exhaust_retries_and_raise(self):
|
||||||
|
"""Function fails all retries and raises the exception."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@retry_with_backoff_testable(max_retries=2, initial_delay=0.01)
|
||||||
|
def always_fail():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
raise ValueError("permanent failure")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="permanent failure"):
|
||||||
|
always_fail()
|
||||||
|
assert call_count == 3 # initial + 2 retries
|
||||||
|
|
||||||
|
def test_respects_max_delay(self):
|
||||||
|
"""Delay is capped at max_delay before jitter multiplication."""
|
||||||
|
call_count = 0
|
||||||
|
pre_jitter_delays = []
|
||||||
|
|
||||||
|
original_sleep = time.sleep
|
||||||
|
def mock_sleep(d):
|
||||||
|
original_sleep(d)
|
||||||
|
|
||||||
|
# Patch to capture delay value before jitter is applied
|
||||||
|
import random
|
||||||
|
original_random = random.random
|
||||||
|
|
||||||
|
def mock_random():
|
||||||
|
return 0.5 # deterministic
|
||||||
|
|
||||||
|
with patch.object(time, 'sleep', mock_sleep), \
|
||||||
|
patch.object(random, 'random', mock_random):
|
||||||
|
|
||||||
|
@retry_with_backoff_testable(max_retries=3, initial_delay=0.1, max_delay=0.2)
|
||||||
|
def fail_once():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise ValueError("fail")
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
fail_once()
|
||||||
|
|
||||||
|
# The max_delay cap is applied BEFORE jitter, so pre-jitter delay should be <= 0.2
|
||||||
|
# With backoff: attempt 1 delay = 0.1, attempt 2 delay = 0.2 (capped)
|
||||||
|
# After jitter (0.75): 0.075 and 0.15 respectively
|
||||||
|
|
||||||
|
def test_on_retry_callback(self):
|
||||||
|
"""on_retry callback is called on each retry."""
|
||||||
|
call_count = 0
|
||||||
|
retry_events = []
|
||||||
|
|
||||||
|
def on_retry(exc, attempt):
|
||||||
|
retry_events.append((exc, attempt))
|
||||||
|
|
||||||
|
@retry_with_backoff_testable(max_retries=3, initial_delay=0.01, on_retry=on_retry)
|
||||||
|
def fail_twice():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise ValueError("temp")
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
fail_twice()
|
||||||
|
assert len(retry_events) == 2
|
||||||
|
assert retry_events[0][1] == 1
|
||||||
|
assert retry_events[1][1] == 2
|
||||||
|
|
||||||
|
def test_jitter_reduces_collision(self):
|
||||||
|
"""Jitter should spread out retry times."""
|
||||||
|
delays = []
|
||||||
|
original_sleep = time.sleep
|
||||||
|
|
||||||
|
def mock_sleep(d):
|
||||||
|
delays.append(d)
|
||||||
|
original_sleep(d)
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
with patch.object(time, 'sleep', mock_sleep):
|
||||||
|
@retry_with_backoff_testable(max_retries=5, initial_delay=0.01, jitter=True)
|
||||||
|
def always_fail():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
raise ValueError("fail")
|
||||||
|
|
||||||
|
try:
|
||||||
|
always_fail()
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Without jitter (factor 0.5 + random), delays would be predictable
|
||||||
|
# With jitter, they should vary
|
||||||
|
unique_delays = len(set(delays))
|
||||||
|
assert unique_delays > 1, "Jitter should produce varying delays"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryableAPIClientLogic:
|
||||||
|
"""Tests for RetryableAPIClient class logic."""
|
||||||
|
|
||||||
|
def test_successful_call(self):
|
||||||
|
"""call_with_retry succeeds without retry."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def succeed():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return "result"
|
||||||
|
|
||||||
|
result = succeed()
|
||||||
|
assert result == "result"
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
def test_retry_then_success(self):
|
||||||
|
"""Retries on failure and succeeds on later attempt."""
|
||||||
|
results = [ValueError("fail1"), ValueError("fail2"), "success"]
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def func():
|
||||||
|
nonlocal call_count
|
||||||
|
result = results[call_count]
|
||||||
|
call_count += 1
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Simulate the retry logic
|
||||||
|
max_retries = 3
|
||||||
|
delay = 0.01
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = func()
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
if attempt == max_retries:
|
||||||
|
raise
|
||||||
|
|
||||||
|
assert result == "success"
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
def test_exhaust_retries(self):
|
||||||
|
"""Exhausts retries and raises last exception."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def always_fail():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
raise ValueError("always fails")
|
||||||
|
|
||||||
|
max_retries = 2
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
always_fail()
|
||||||
|
except ValueError:
|
||||||
|
if attempt == max_retries:
|
||||||
|
raised = True
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raised = False
|
||||||
|
|
||||||
|
assert raised
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
def test_batch_continues_on_failure(self):
|
||||||
|
"""Batch processing continues after individual item failures."""
|
||||||
|
items = [1, 2, 3]
|
||||||
|
results = []
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
for idx, item in enumerate(items):
|
||||||
|
try:
|
||||||
|
if item == 2:
|
||||||
|
raise ValueError("item 2 failed")
|
||||||
|
results.append(item * 10)
|
||||||
|
except Exception as e:
|
||||||
|
failures.append({"index": idx, "item": item, "error": str(e)})
|
||||||
|
|
||||||
|
assert results == [10, 30]
|
||||||
|
assert len(failures) == 1
|
||||||
|
assert failures[0]["index"] == 1
|
||||||
|
|
||||||
|
def test_batch_stops_on_failure(self):
|
||||||
|
"""Batch processing stops when continue_on_failure is False."""
|
||||||
|
items = [1, 2, 3]
|
||||||
|
results = []
|
||||||
|
stop_on_failure = True
|
||||||
|
|
||||||
|
for idx, item in enumerate(items):
|
||||||
|
if item == 2:
|
||||||
|
try:
|
||||||
|
raise ValueError("item 2 failed")
|
||||||
|
except ValueError:
|
||||||
|
if stop_on_failure:
|
||||||
|
break
|
||||||
|
results.append(item * 10)
|
||||||
|
|
||||||
|
assert results == [10] # stops at item 2
|
||||||
Loading…
Reference in New Issue