-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_rerankers.py
More file actions
761 lines (546 loc) · 27.4 KB
/
test_rerankers.py
File metadata and controls
761 lines (546 loc) · 27.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
"""Comprehensive tests for the reranker modules.
These tests verify:
1. HttpReranker initialization and HTTP communication
2. HttpReranker error handling (connection failures, timeouts)
3. CrossEncoderReranker initialization with different models
4. CrossEncoderReranker model loading and caching
5. CrossEncoderReranker reranking logic (scoring, sorting)
6. CrossEncoderReranker idle timeout behavior
7. Empty inputs handling for both rerankers
8. Thread safety for CrossEncoderReranker
All tests mock external dependencies to avoid real HTTP requests and model loading.
"""
import gc
import threading
import time
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import MagicMock, Mock, PropertyMock, patch
import numpy as np
import pytest
import requests
# ============================================================================
# HTTPReranker Tests
# ============================================================================
class TestHttpRerankerInit:
    """Constructor behaviour of ``HttpReranker``."""

    def test_init_with_default_port(self):
        """Without arguments the reranker uses the default port and no client id."""
        from code_rag.reranker.http_reranker import DEFAULT_PORT, HttpReranker

        client = HttpReranker()

        expected_url = f"http://127.0.0.1:{DEFAULT_PORT}"
        assert client.port == DEFAULT_PORT
        assert client.base_url == expected_url
        assert client.client_id == ""

    def test_init_with_custom_port(self):
        """An explicit port is stored and reflected in the base URL."""
        from code_rag.reranker.http_reranker import HttpReranker

        client = HttpReranker(port=9999)

        assert client.port == 9999
        assert client.base_url == "http://127.0.0.1:9999"

    def test_init_with_client_id(self):
        """The supplied client id is kept verbatim on the instance."""
        from code_rag.reranker.http_reranker import HttpReranker

        client = HttpReranker(port=8199, client_id="test-client-123")

        assert client.client_id == "test-client-123"
class TestHttpRerankerRerank:
    """Tests for ``HttpReranker.rerank`` with ``requests.post`` mocked out."""

    @staticmethod
    def _ok_response(payload):
        """Build a mocked HTTP 200 response whose ``json()`` returns *payload*."""
        response = Mock()
        response.status_code = 200
        response.json.return_value = payload
        response.raise_for_status = Mock()
        return response

    @pytest.fixture
    def reranker(self):
        """Provide an HttpReranker bound to port 8199 with a fixed client id."""
        from code_rag.reranker.http_reranker import HttpReranker

        return HttpReranker(port=8199, client_id="test-client")

    def test_rerank_empty_documents_returns_empty_list(self, reranker):
        """Reranking an empty document list returns [] without any HTTP call."""
        with patch("requests.post") as mock_post:
            result = reranker.rerank(query="test query", documents=[])

        assert result == []
        mock_post.assert_not_called()

    def test_rerank_successful_response(self, reranker):
        """A successful response is converted to tuples and the payload is complete."""
        response = self._ok_response(
            {"results": [[0, 0.95], [2, 0.85], [1, 0.75]]}
        )

        with patch("requests.post", return_value=response) as mock_post:
            result = reranker.rerank(
                query="test query", documents=["doc1", "doc2", "doc3"], top_k=3
            )

        assert result == [(0, 0.95), (2, 0.85), (1, 0.75)]
        mock_post.assert_called_once()

        # Verify the request URL and payload.
        args, kwargs = mock_post.call_args
        assert args[0] == "http://127.0.0.1:8199/rerank"
        payload = kwargs["json"]
        assert payload["query"] == "test query"
        assert payload["documents"] == ["doc1", "doc2", "doc3"]
        assert payload["top_k"] == 3
        assert payload["client_id"] == "test-client"

    def test_rerank_with_metadatas(self, reranker):
        """Metadata passed to ``rerank`` is forwarded unchanged in the payload."""
        response = self._ok_response({"results": [[0, 0.9]]})
        metadatas = [{"file_path": "test.py", "function_name": "foo"}]

        with patch("requests.post", return_value=response) as mock_post:
            result = reranker.rerank(
                query="test query", documents=["doc1"], metadatas=metadatas, top_k=1
            )

        assert result == [(0, 0.9)]
        _, kwargs = mock_post.call_args
        assert kwargs["json"]["metadatas"] == metadatas

    def test_rerank_with_custom_model(self, reranker):
        """A custom ``model`` argument is forwarded in the request payload."""
        response = self._ok_response({"results": [[0, 0.9]]})

        with patch("requests.post", return_value=response) as mock_post:
            result = reranker.rerank(
                query="test query", documents=["doc1"], model="custom-model"
            )

        # The result was previously bound but never checked; the same mocked
        # response yields this value in test_rerank_with_metadatas.
        assert result == [(0, 0.9)]
        _, kwargs = mock_post.call_args
        assert kwargs["json"]["model"] == "custom-model"

    def test_rerank_error_in_response_returns_empty(self, reranker):
        """A JSON body carrying an ``error`` key results in an empty list."""
        response = self._ok_response({"error": "Model not loaded"})

        with patch("requests.post", return_value=response):
            result = reranker.rerank(query="test query", documents=["doc1", "doc2"])

        assert result == []
class TestHttpRerankerErrorHandling:
    """Failure-path behaviour of ``HttpReranker.rerank``."""

    @pytest.fixture
    def reranker(self):
        """Provide an HttpReranker bound to port 8199."""
        from code_rag.reranker.http_reranker import HttpReranker

        return HttpReranker(port=8199)

    def test_rerank_connection_error_returns_empty(self, reranker):
        """A ``requests.ConnectionError`` from ``post`` yields an empty list."""
        failing_post = patch("requests.post", side_effect=requests.ConnectionError())
        with failing_post:
            outcome = reranker.rerank(query="test query", documents=["doc1", "doc2"])

        assert outcome == []

    def test_rerank_timeout_returns_empty(self, reranker):
        """A ``requests.Timeout`` from ``post`` yields an empty list."""
        failing_post = patch("requests.post", side_effect=requests.Timeout())
        with failing_post:
            outcome = reranker.rerank(query="test query", documents=["doc1", "doc2"])

        assert outcome == []

    def test_rerank_http_error_propagates(self, reranker):
        """An ``HTTPError`` raised by ``raise_for_status`` reaches the caller."""
        failing_response = Mock()
        failing_response.raise_for_status.side_effect = requests.HTTPError(
            "500 Server Error"
        )

        with patch("requests.post", return_value=failing_response), pytest.raises(
            requests.HTTPError
        ):
            reranker.rerank(query="test query", documents=["doc1", "doc2"])
class TestHttpRerankerNoOps:
    """Lifecycle hooks on ``HttpReranker`` that must be callable without raising."""

    def test_start_background_loading_is_noop(self):
        """``start_background_loading`` can be called without raising."""
        from code_rag.reranker.http_reranker import HttpReranker

        # The test passes as long as the call below does not raise.
        HttpReranker().start_background_loading()

    def test_unload_model_is_noop(self):
        """``unload_model`` can be called without raising."""
        from code_rag.reranker.http_reranker import HttpReranker

        # The test passes as long as the call below does not raise.
        HttpReranker().unload_model()
# ============================================================================
# CrossEncoderReranker Tests
# ============================================================================
class TestCrossEncoderRerankerInit:
    """Tests for CrossEncoderReranker initialization.

    Tests that may start a cleanup thread stop it in a ``finally`` clause so a
    failing assertion does not leave the thread running for later tests.
    """

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, mock_instance)``."""
        with patch("code_rag.reranker.cross_encoder_reranker.CrossEncoder") as mock:
            mock_instance = MagicMock()
            mock.return_value = mock_instance
            yield mock, mock_instance

    def test_init_with_default_model_lazy_load(self, mock_cross_encoder):
        """Default construction records the model name but does not load it."""
        mock_class, _ = mock_cross_encoder
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker()
        try:
            assert reranker.model_name == "jinaai/jina-reranker-v3"
            assert reranker.model is None  # Not loaded yet due to lazy loading
            mock_class.assert_not_called()
        finally:
            reranker.stop_cleanup_thread()

    def test_init_with_custom_model(self, mock_cross_encoder):
        """A custom model name is stored on the instance."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(model_name="custom/reranker-model")
        try:
            assert reranker.model_name == "custom/reranker-model"
        finally:
            reranker.stop_cleanup_thread()

    def test_init_with_lazy_load_false_loads_model(self, mock_cross_encoder):
        """``lazy_load=False`` constructs the CrossEncoder immediately."""
        mock_class, mock_instance = mock_cross_encoder
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(lazy_load=False)
        try:
            mock_class.assert_called_once_with("jinaai/jina-reranker-v3")
            assert reranker.model == mock_instance
        finally:
            reranker.stop_cleanup_thread()

    def test_init_with_idle_timeout(self, mock_cross_encoder):
        """A custom idle timeout is stored in ``_idle_timeout``."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=600)
        try:
            assert reranker._idle_timeout == 600
        finally:
            reranker.stop_cleanup_thread()

    def test_init_with_zero_idle_timeout_no_cleanup_thread(self, mock_cross_encoder):
        """A zero idle timeout does not start a cleanup thread."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=0)

        assert reranker._idle_timeout == 0
        # Cleanup thread should not have started, so there is nothing to stop.
        assert reranker._cleanup_thread is None
class TestCrossEncoderRerankerModelLoading:
    """Tests for CrossEncoderReranker model loading behavior."""

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, mock_instance)``."""
        with patch("code_rag.reranker.cross_encoder_reranker.CrossEncoder") as mock:
            mock_instance = MagicMock()
            mock.return_value = mock_instance
            yield mock, mock_instance

    def test_load_model_creates_cross_encoder(self, mock_cross_encoder):
        """``_load_model`` constructs the CrossEncoder with the default model name."""
        mock_class, mock_instance = mock_cross_encoder
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=0)
        reranker._load_model()

        mock_class.assert_called_once_with("jinaai/jina-reranker-v3")
        assert reranker.model == mock_instance

    def test_load_model_only_loads_once(self, mock_cross_encoder):
        """Calling ``_load_model`` twice constructs the CrossEncoder only once."""
        mock_class, _ = mock_cross_encoder
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=0)
        reranker._load_model()
        reranker._load_model()  # Second call should be a no-op

        mock_class.assert_called_once()

    def test_start_background_loading_creates_thread(self, mock_cross_encoder):
        """``start_background_loading`` starts a daemon thread that loads the model."""
        _, mock_instance = mock_cross_encoder
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=0)
        assert reranker._loading_thread is None

        reranker.start_background_loading()

        # Hold a local reference so the checks below do not depend on the
        # attribute still pointing at the thread after loading finishes.
        loader = reranker._loading_thread
        assert loader is not None
        assert loader.daemon is True

        # join() returns silently on timeout, so verify the thread really ended.
        loader.join(timeout=2)
        assert not loader.is_alive(), "background loading did not finish in time"
        assert reranker.model == mock_instance

    def test_ensure_model_loaded_waits_for_background_thread(self, mock_cross_encoder):
        """``_ensure_model_loaded`` blocks until background loading has completed."""
        mock_class, mock_instance = mock_cross_encoder

        loading_started = threading.Event()
        loading_complete = threading.Event()

        def delayed_cross_encoder(*args, **kwargs):
            # Simulate a slow model load so the caller has something to wait for.
            loading_started.set()
            time.sleep(0.1)
            loading_complete.set()
            return mock_instance

        mock_class.side_effect = delayed_cross_encoder

        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=0)
        reranker.start_background_loading()

        # Fail fast if the loader never started, instead of silently continuing
        # and testing a synchronous load by accident.
        assert loading_started.wait(timeout=1), "background loading never started"

        # This should block until loading completes.
        reranker._ensure_model_loaded()

        assert loading_complete.is_set()
        assert reranker.model == mock_instance
class TestCrossEncoderRerankerRerank:
    """Tests for ``CrossEncoderReranker.rerank`` against a mocked CrossEncoder."""

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, mock_instance)``.

        ``predict`` returns three scores by default; tests override it as needed.
        """
        with patch("code_rag.reranker.cross_encoder_reranker.CrossEncoder") as mock:
            mock_instance = MagicMock()
            mock_instance.predict.return_value = np.array([0.8, 0.6, 0.9])
            mock.return_value = mock_instance
            yield mock, mock_instance

    @pytest.fixture
    def reranker(self, mock_cross_encoder):
        """Provide an eagerly loaded CrossEncoderReranker with idle timeout 0."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        yield CrossEncoderReranker(lazy_load=False, idle_timeout=0)

    def test_rerank_empty_documents_returns_empty(self, reranker, mock_cross_encoder):
        """Reranking no documents returns [] and never calls ``predict``."""
        _, mock_instance = mock_cross_encoder

        result = reranker.rerank(query="test query", documents=[])

        assert result == []
        mock_instance.predict.assert_not_called()

    def test_rerank_returns_sorted_results(self, reranker, mock_cross_encoder):
        """Results are ``(index, score)`` tuples sorted by score descending."""
        _, mock_instance = mock_cross_encoder
        mock_instance.predict.return_value = np.array([0.5, 0.9, 0.3])

        result = reranker.rerank(
            query="test query", documents=["doc1", "doc2", "doc3"], top_k=3
        )

        # doc2 (idx=1) has highest score, then doc1 (idx=0), then doc3 (idx=2)
        assert result == [(1, 0.9), (0, 0.5), (2, 0.3)]

    def test_rerank_respects_top_k(self, reranker, mock_cross_encoder):
        """Only the ``top_k`` best-scoring documents are returned."""
        _, mock_instance = mock_cross_encoder
        mock_instance.predict.return_value = np.array([0.5, 0.9, 0.3, 0.7, 0.8])

        result = reranker.rerank(
            query="test query",
            documents=["doc1", "doc2", "doc3", "doc4", "doc5"],
            top_k=2,
        )

        assert len(result) == 2
        assert result[0] == (1, 0.9)  # Highest score
        assert result[1] == (4, 0.8)  # Second highest

    def test_rerank_creates_query_document_pairs(self, reranker, mock_cross_encoder):
        """``predict`` receives one ``[query, document]`` pair per document."""
        _, mock_instance = mock_cross_encoder

        reranker.rerank(
            query="search query", documents=["first doc", "second doc"], top_k=2
        )

        # The pairs are the first positional argument of predict().
        args, _ = mock_instance.predict.call_args
        assert args[0] == [
            ["search query", "first doc"],
            ["search query", "second doc"],
        ]

    def test_rerank_with_metadata_adds_context(self, reranker, mock_cross_encoder):
        """Metadata fields are embedded in the document text passed to ``predict``."""
        _, mock_instance = mock_cross_encoder
        mock_instance.predict.return_value = np.array([0.8, 0.6])
        metadatas = [
            {"file_path": "test.py", "function_name": "foo", "class_name": "Bar"},
            {"file_path": "other.py"},
        ]

        reranker.rerank(
            query="test query", documents=["doc1", "doc2"], metadatas=metadatas, top_k=2
        )

        args, _ = mock_instance.predict.call_args
        pairs = args[0]

        # First doc should have full context
        assert "File: test.py" in pairs[0][1]
        assert "Class: Bar" in pairs[0][1]
        assert "Function: foo" in pairs[0][1]
        assert "Code: doc1" in pairs[0][1]

        # Second doc has minimal context
        assert "File: other.py" in pairs[1][1]
        assert "Code: doc2" in pairs[1][1]

    def test_rerank_with_partial_metadata_uses_documents_directly(
        self, reranker, mock_cross_encoder
    ):
        """When metadata and document counts differ, raw documents are scored."""
        _, mock_instance = mock_cross_encoder
        mock_instance.predict.return_value = np.array([0.8, 0.6])

        # Metadata length doesn't match documents length
        metadatas = [{"file_path": "test.py"}]  # Only 1 metadata for 2 docs

        reranker.rerank(
            query="test query", documents=["doc1", "doc2"], metadatas=metadatas, top_k=2
        )

        args, _ = mock_instance.predict.call_args
        pairs = args[0]

        # Should use raw documents (no metadata processing)
        assert pairs[0][1] == "doc1"
        assert pairs[1][1] == "doc2"

    def test_rerank_updates_last_used_timestamp(self, mock_cross_encoder):
        """``rerank`` refreshes ``_last_used``.

        The stored timestamp is backdated before the call rather than sleeping,
        so the test does not depend on clock resolution or scheduling.
        """
        _, mock_instance = mock_cross_encoder
        mock_instance.predict.return_value = np.array([0.8])

        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(lazy_load=False, idle_timeout=0)

        # Push the timestamp clearly into the past; any refresh must exceed it.
        stale = reranker._last_used - 10
        reranker._last_used = stale

        reranker.rerank(query="test", documents=["doc"])

        assert reranker._last_used > stale
class TestCrossEncoderRerankerIdleTimeout:
    """Tests for CrossEncoderReranker idle timeout behavior.

    Every test here starts a cleanup thread, so each one stops it in a
    ``finally`` clause (or as the action under test) to avoid leaking threads
    when an assertion fails.
    """

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, mock_instance)``."""
        with patch("code_rag.reranker.cross_encoder_reranker.CrossEncoder") as mock:
            mock_instance = MagicMock()
            mock_instance.predict.return_value = np.array([0.8])
            mock.return_value = mock_instance
            yield mock, mock_instance

    def test_cleanup_thread_starts_with_positive_timeout(self, mock_cross_encoder):
        """A positive ``idle_timeout`` starts a live cleanup thread."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=1800)
        try:
            assert reranker._cleanup_thread is not None
            assert reranker._cleanup_thread.is_alive()
        finally:
            reranker.stop_cleanup_thread()

    def test_stop_cleanup_thread_stops_thread(self, mock_cross_encoder):
        """``stop_cleanup_thread`` makes the cleanup thread terminate."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=1800)
        cleanup_thread = reranker._cleanup_thread
        assert cleanup_thread is not None

        reranker.stop_cleanup_thread()

        # Wait for the thread to exit instead of sleeping a fixed interval:
        # join() returns as soon as it ends and tolerates a slow machine.
        cleanup_thread.join(timeout=2)
        assert not cleanup_thread.is_alive()

    def test_cleanup_loop_unloads_model_after_timeout(self, mock_cross_encoder):
        """Check that an expired idle period leads to an unloaded model.

        NOTE(review): this test replays the unload steps by hand under
        ``_loading_lock`` instead of waiting for the background loop, so it
        does not prove the real loop performs them. Consider driving the
        actual cleanup routine instead.
        """
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        # Very short timeout for testing
        reranker = CrossEncoderReranker(lazy_load=False, idle_timeout=1)
        try:
            assert reranker.model is not None

            # Pretend the model was last used well before the timeout.
            reranker._last_used = time.time() - 10

            # Trigger cleanup manually (simulating what the loop does)
            with reranker._loading_lock:
                idle_time = time.time() - reranker._last_used
                if idle_time >= reranker._idle_timeout:
                    del reranker.model
                    reranker.model = None
                    gc.collect()

            assert reranker.model is None
        finally:
            reranker.stop_cleanup_thread()
class TestCrossEncoderRerankerUnloadModel:
    """Behaviour of ``CrossEncoderReranker.unload_model``."""

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, mock_instance)``."""
        target = "code_rag.reranker.cross_encoder_reranker.CrossEncoder"
        with patch(target) as encoder_cls:
            encoder = MagicMock()
            encoder_cls.return_value = encoder
            yield encoder_cls, encoder

    def test_unload_model_sets_model_to_none(self, mock_cross_encoder):
        """A loaded model is replaced by ``None`` after unloading."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        subject = CrossEncoderReranker(lazy_load=False, idle_timeout=0)
        assert subject.model is not None

        subject.unload_model()

        assert subject.model is None

    def test_unload_model_when_model_is_none_is_safe(self, mock_cross_encoder):
        """Unloading before any model was loaded neither raises nor loads one."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        subject = CrossEncoderReranker(idle_timeout=0)
        assert subject.model is None

        # The call itself must not raise.
        subject.unload_model()

        assert subject.model is None

    def test_unload_model_with_torch_clears_cuda_cache(self, mock_cross_encoder):
        """With a CUDA-capable ``torch`` module present, its cache is emptied once."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        subject = CrossEncoderReranker(lazy_load=False, idle_timeout=0)

        fake_torch = MagicMock()
        fake_torch.cuda.is_available.return_value = True

        # Substitute the stand-in for whatever ``torch`` import unload_model does.
        with patch.dict("sys.modules", {"torch": fake_torch}):
            subject.unload_model()

        fake_torch.cuda.is_available.assert_called_once()
        fake_torch.cuda.empty_cache.assert_called_once()
class TestCrossEncoderRerankerThreadSafety:
    """Tests for CrossEncoderReranker thread safety."""

    @staticmethod
    def _run_concurrently(target, count, timeout=5.0):
        """Run *target* in *count* daemon threads and wait for all of them.

        Joins are bounded so that a deadlock fails the test instead of hanging
        the whole suite; daemon threads keep a stuck worker from blocking
        interpreter exit.
        """
        workers = [threading.Thread(target=target, daemon=True) for _ in range(count)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=timeout)

        stuck = [worker.name for worker in workers if worker.is_alive()]
        assert not stuck, f"threads did not finish within {timeout}s: {stuck}"

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, call_count)``.

        ``call_count["count"]`` records how many times the class was
        constructed, guarded by a lock so concurrent constructions are counted
        reliably.
        """
        with patch("code_rag.reranker.cross_encoder_reranker.CrossEncoder") as mock:
            call_count = {"count": 0}
            lock = threading.Lock()

            def score_pairs(pairs, *args, **kwargs):
                # One score per query-document pair, so the mocked model output
                # always matches the number of documents being reranked.
                return np.full(len(pairs), 0.8)

            def create_instance(*args, **kwargs):
                with lock:
                    call_count["count"] += 1
                mock_instance = MagicMock()
                mock_instance.predict.side_effect = score_pairs
                return mock_instance

            mock.side_effect = create_instance
            yield mock, call_count

    def test_concurrent_load_model_only_loads_once(self, mock_cross_encoder):
        """Concurrent ``_load_model`` calls construct the model exactly once."""
        _, call_count = mock_cross_encoder
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=0)

        self._run_concurrently(reranker._load_model, count=5)

        # Model should only be loaded once
        assert call_count["count"] == 1
        assert reranker.model is not None

    def test_concurrent_rerank_is_thread_safe(self, mock_cross_encoder):
        """Ten concurrent ``rerank`` calls all complete without raising."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(lazy_load=False, idle_timeout=0)
        results = []
        errors = []

        def do_rerank():
            # Collect exceptions: one raised in a worker thread would otherwise
            # be invisible to the test.
            try:
                results.append(
                    reranker.rerank(query="test", documents=["doc1", "doc2", "doc3"])
                )
            except Exception as exc:
                errors.append(exc)

        self._run_concurrently(do_rerank, count=10)

        assert not errors, f"rerank raised in worker threads: {errors!r}"
        assert len(results) == 10
class TestCrossEncoderRerankerDestructor:
    """Tests for CrossEncoderReranker destructor behavior."""

    @pytest.fixture
    def mock_cross_encoder(self):
        """Patch ``CrossEncoder``; yield ``(mock_class, mock_instance)``."""
        with patch("code_rag.reranker.cross_encoder_reranker.CrossEncoder") as mock:
            mock_instance = MagicMock()
            mock.return_value = mock_instance
            yield mock, mock_instance

    def test_destructor_stops_cleanup_thread(self, mock_cross_encoder):
        """Calling ``__del__`` makes the cleanup thread terminate."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(idle_timeout=1800)
        cleanup_thread = reranker._cleanup_thread
        assert cleanup_thread is not None
        assert cleanup_thread.is_alive()

        # Explicitly call the destructor rather than relying on GC timing.
        reranker.__del__()

        # Wait for the thread to exit instead of sleeping a fixed interval:
        # join() returns as soon as it ends and tolerates a slow machine.
        cleanup_thread.join(timeout=2)
        assert not cleanup_thread.is_alive()

    def test_destructor_unloads_model(self, mock_cross_encoder):
        """Calling ``__del__`` leaves ``model`` set to ``None``."""
        from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

        reranker = CrossEncoderReranker(lazy_load=False, idle_timeout=0)
        assert reranker.model is not None

        reranker.__del__()

        assert reranker.model is None
# ============================================================================
# RerankerInterface Tests
# ============================================================================
class TestRerankerInterface:
    """Contract checks for the ``RerankerInterface`` base class."""

    def test_interface_cannot_be_instantiated(self):
        """Instantiating the interface directly raises ``TypeError``."""
        from code_rag.reranker.reranker_interface import RerankerInterface

        with pytest.raises(TypeError):
            RerankerInterface()

    def test_implementations_implement_interface(self):
        """Both concrete rerankers are instances of ``RerankerInterface``."""
        from code_rag.reranker.http_reranker import HttpReranker
        from code_rag.reranker.reranker_interface import RerankerInterface

        http_impl = HttpReranker()
        assert isinstance(http_impl, RerankerInterface)

        # CrossEncoder is patched so constructing the reranker uses the mock
        # in place of the real class.
        encoder_target = "code_rag.reranker.cross_encoder_reranker.CrossEncoder"
        with patch(encoder_target):
            from code_rag.reranker.cross_encoder_reranker import CrossEncoderReranker

            cross_encoder_impl = CrossEncoderReranker(idle_timeout=0)
            assert isinstance(cross_encoder_impl, RerankerInterface)