Backends

Backend classes implement the VectorStorageBackend ABC. You generally do not instantiate them directly — Medha creates the correct backend based on Settings.backend_type. These docs are provided for contributors and advanced users who need to extend or inspect backend behaviour.

See the Backends guide for installation instructions and configuration examples.
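
For orientation, here is a minimal sketch of configuration-driven backend selection. Only Settings.backend_type is documented above; the entry-point class name (Medha) and the accepted string values are assumptions.

from medha.config import Settings

# Hedged sketch: the backend_type values below are assumed from the three
# backends documented on this page, not confirmed by it.
settings = Settings(backend_type="memory")  # or "qdrant" / "pgvector" (assumed)
# Medha(settings) would then construct and manage the matching
# VectorStorageBackend internally; user code never instantiates one directly.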


InMemoryBackend

Bases: VectorStorageBackend

Pure-Python in-process vector backend. No external dependencies required.

Source code in src/medha/backends/memory.py
class InMemoryBackend(VectorStorageBackend):
    """Pure-Python in-process vector backend. No external dependencies required."""

    def __init__(self) -> None:
        # _store: collection_name -> {"dimension": int, "entries": {id: stored_point}}
        self._store: dict[str, dict[str, Any]] = {}
        self._lock = asyncio.Lock()

    async def connect(self) -> None:
        """No-op — backend is always connected."""

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if dimension <= 0:
            raise StorageError(f"dimension must be > 0, got {dimension}")
        async with self._lock:
            if collection_name not in self._store:
                self._store[collection_name] = {"dimension": dimension, "entries": {}}

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")

        entries = self._store[collection_name]["entries"]
        if not entries:
            return []

        now = datetime.now(timezone.utc)
        scored: list[tuple[float, dict[str, Any]]] = []
        for point in entries.values():
            ea_raw = point["payload"].get("expires_at")
            if ea_raw is not None:
                ea = _parse_dt(ea_raw)
                if ea is not None and ea <= now:
                    continue
            score = _cosine_similarity(vector, point["vector"])
            score = max(0.0, min(1.0, score))
            if score >= score_threshold:
                scored.append((score, point))

        scored.sort(key=lambda t: t[0], reverse=True)

        return [
            _point_to_cache_result(point, score)
            for score, point in scored[:limit]
        ]

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")
        if not entries:
            return
        async with self._lock:
            store_entries = self._store[collection_name]["entries"]
            for entry in entries:
                store_entries[entry.id] = {
                    "id": entry.id,
                    "vector": entry.vector,
                    "payload": {
                        "original_question": entry.original_question,
                        "normalized_question": entry.normalized_question,
                        "generated_query": entry.generated_query,
                        "query_hash": entry.query_hash,
                        "response_summary": entry.response_summary,
                        "template_id": entry.template_id,
                        "usage_count": entry.usage_count,
                        "created_at": entry.created_at.isoformat(),
                        "expires_at": entry.expires_at.isoformat() if entry.expires_at else None,
                    },
                }

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")

        entries = self._store[collection_name]["entries"]
        ids = list(entries.keys())
        start = int(offset) if offset is not None else 0
        page_ids = ids[start : start + limit]
        next_offset = str(start + limit) if start + limit < len(ids) else None

        results = [
            _point_to_cache_result(entries[id_], score=1.0)
            for id_ in page_ids
        ]
        return results, next_offset

    async def count(self, collection_name: str) -> int:
        return len(self._store.get(collection_name, {}).get("entries", {}))

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if not ids:
            return
        async with self._lock:
            entries = self._store.get(collection_name, {}).get("entries", {})
            for id_ in ids:
                entries.pop(id_, None)

    async def close(self) -> None:
        self._store.clear()

    async def find_expired(self, collection_name: str) -> list[str]:
        if collection_name not in self._store:
            return []
        now = datetime.now(timezone.utc)
        return [
            id_
            for id_, point in self._store[collection_name]["entries"].items()
            if (ea_raw := point["payload"].get("expires_at"))
            and (ea := _parse_dt(ea_raw)) is not None
            and ea < now
        ]

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")
        for point in self._store[collection_name]["entries"].values():
            if point["payload"]["normalized_question"] == normalized_question:
                return _point_to_cache_result(point, score=1.0)
        return None

    async def find_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> list[str]:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")
        return [
            point["id"]
            for point in self._store[collection_name]["entries"].values()
            if point["payload"].get("query_hash") == query_hash
        ]

    async def find_by_template_id(
        self, collection_name: str, template_id: str
    ) -> list[str]:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")
        return [
            point["id"]
            for point in self._store[collection_name]["entries"].values()
            if point["payload"].get("template_id") == template_id
        ]

    async def drop_collection(self, collection_name: str) -> None:
        async with self._lock:
            self._store.pop(collection_name, None)

    async def search_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> CacheResult | None:
        if collection_name not in self._store:
            raise StorageError(f"Collection '{collection_name}' does not exist.")

        for point in self._store[collection_name]["entries"].values():
            if point["payload"]["query_hash"] == query_hash:
                return _point_to_cache_result(point, score=1.0)
        return None

    async def update_usage_count(self, collection_name: str, point_id: str) -> None:
        async with self._lock:
            entries = self._store.get(collection_name, {}).get("entries", {})
            if point_id not in entries:
                logger.warning(
                    "update_usage_count: id '%s' not found in collection '%s'",
                    point_id,
                    collection_name,
                )
                return
            entries[point_id]["payload"]["usage_count"] += 1

connect() async

No-op — backend is always connected.

Source code in src/medha/backends/memory.py
async def connect(self) -> None:
    """No-op — backend is always connected."""

QdrantBackend

Conditional Import

QdrantBackend is only importable when the qdrant extra is installed (pip install "medha-archai[qdrant]"). Importing it without the extra raises ImportError with a helpful install hint.
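
A guarded import, as a sketch of how a caller might degrade gracefully when the extra is missing:

try:
    from medha.backends.qdrant import QdrantBackend
except ImportError:  # raised with an install hint when the extra is absent
    QdrantBackend = None  # fall back, or: pip install "medha-archai[qdrant]"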

Bases: VectorStorageBackend

Qdrant-based vector storage backend.

Supports three deployment modes:
  • "memory": In-process, no persistence. Best for testing.
  • "docker": Connect to a local/remote Qdrant Docker instance.
  • "cloud": Connect to Qdrant Cloud with API key.

Parameters:

settings (Settings | None): Medha Settings instance. If None, loads from environment. Default: None
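
A hedged sketch of the three modes, using the Settings fields that appear in _build_client() further down; hosts, ports, and keys are illustrative placeholders.

from medha.config import Settings

mem = Settings(qdrant_mode="memory")  # in-process, no persistence
docker = Settings(qdrant_mode="docker", qdrant_host="localhost", qdrant_port=6333)
cloud = Settings(
    qdrant_mode="cloud",
    qdrant_url="https://example.cloud.qdrant.io",  # placeholder
    qdrant_api_key="...",  # held as a secret; see get_secret_value() below
)
backend = QdrantBackend(settings=docker)
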
Source code in src/medha/backends/qdrant.py
class QdrantBackend(VectorStorageBackend):
    """Qdrant-based vector storage backend.

    Supports three deployment modes:
        - "memory": In-process, no persistence. Best for testing.
        - "docker": Connect to a local/remote Qdrant Docker instance.
        - "cloud": Connect to Qdrant Cloud with API key.

    Args:
        settings: Medha Settings instance. If None, loads from environment.
    """

    def __init__(self, settings: Settings | None = None):
        self._settings = settings or Settings()
        self._client: AsyncQdrantClient | None = None
        self._initialized_collections: set[str] = set()

    @property
    def client(self) -> AsyncQdrantClient:
        if self._client is None:
            raise StorageError("Backend not connected. Call connect() first.")
        return self._client

    async def connect(self) -> None:
        """Establish connection to Qdrant based on settings.

        Must be called before any other operation.

        Raises:
            StorageInitializationError: If connection fails.
        """
        try:
            self._client = self._build_client()
            logger.info(
                "Connected to Qdrant in '%s' mode", self._settings.qdrant_mode
            )
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to connect to Qdrant in '{self._settings.qdrant_mode}' mode: {e}"
            ) from e

    async def initialize(
        self, collection_name: str, dimension: int, **kwargs: Any
    ) -> None:
        """Create and configure a Qdrant collection.

        Idempotent: skips creation if collection already exists.

        Configuration includes:
            - Vector params (dimension, cosine distance)
            - Quantization (scalar INT8 by default, binary for dim >= 512)
            - HNSW parameters (m=16, ef_construct=100)
            - Optimizer config (indexing threshold, memmap threshold)
            - Payload indexes on frequently queried fields

        Args:
            collection_name: Name of the collection.
            dimension: Vector dimensionality.
            **kwargs: Override settings (hnsw_m, enable_quantization, etc.)
        """
        if collection_name in self._initialized_collections:
            return

        try:
            logger.debug("Initializing collection '%s' (dim=%d)", collection_name, dimension)
            collections = await self.client.get_collections()
            existing = {c.name for c in collections.collections}

            if collection_name not in existing:
                quantization = self._build_quantization_config(dimension, **kwargs)
                hnsw = self._build_hnsw_config(**kwargs)

                await self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dimension,
                        distance=Distance.COSINE,
                        on_disk=self._settings.on_disk,
                    ),
                    quantization_config=quantization,
                    hnsw_config=hnsw,
                    optimizers_config=OptimizersConfigDiff(
                        indexing_threshold=20000,
                        memmap_threshold=50000,
                    ),
                )
                logger.info(
                    "Created collection '%s' (dim=%d)", collection_name, dimension
                )

                # Create payload indexes only on main collection (not template collections)
                # and only when not in memory mode (indexes have no effect in memory mode
                # and Qdrant emits a UserWarning when they are created there).
                if (
                    not collection_name.startswith("__medha_templates_")
                    and self._settings.qdrant_mode != "memory"
                ):
                    try:
                        await self._create_payload_indexes(collection_name)
                    except Exception as idx_err:
                        logger.warning(
                            "Payload indexes not created on '%s' (non-critical): %s",
                            collection_name,
                            idx_err,
                        )
            else:
                logger.info(
                    "Collection '%s' already exists, skipping creation",
                    collection_name,
                )

            self._initialized_collections.add(collection_name)
        except (StorageError, StorageInitializationError):
            raise
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to initialize collection '{collection_name}': {e}"
            ) from e

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        """Search for similar vectors using query_points.

        Args:
            collection_name: Collection to search.
            vector: Query vector.
            limit: Max number of results.
            score_threshold: Minimum similarity score (0.0 - 1.0).

        Returns:
            List of CacheResult, sorted by descending score.

        Raises:
            StorageError: If the search fails.
        """
        try:
            logger.debug(
                "Searching '%s': limit=%d, threshold=%.3f",
                collection_name,
                limit,
                score_threshold,
            )
            search_params = self._build_search_params()
            now_iso = datetime.now(timezone.utc).isoformat()
            ttl_filter = Filter(
                must_not=[
                    FieldCondition(
                        key="expires_at",
                        range=DatetimeRange(lte=now_iso),
                    )
                ]
            )
            response = await self.client.query_points(
                collection_name=collection_name,
                query=vector,
                limit=limit,
                score_threshold=score_threshold if score_threshold > 0.0 else None,
                search_params=search_params,
                query_filter=ttl_filter,
                with_payload=True,
            )
            results = [self._point_to_cache_result(point) for point in response.points]
            if results:
                logger.debug(
                    "Search '%s' returned %d results (top score=%.4f)",
                    collection_name,
                    len(results),
                    results[0].score,
                )
            else:
                logger.debug("Search '%s' returned 0 results", collection_name)
            return results
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant search failed on '{collection_name}': {e}"
            ) from e

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        """Insert or update cache entries in batches.

        Args:
            collection_name: Target collection.
            entries: List of CacheEntry objects to upsert.

        Raises:
            StorageError: If the upsert fails.
        """
        try:
            points = [self._entry_to_point(e) for e in entries]
            batch_size = self._settings.batch_size

            for i in range(0, len(points), batch_size):
                batch = points[i : i + batch_size]
                await self.client.upsert(
                    collection_name=collection_name,
                    wait=True,
                    points=batch,
                )
                logger.info(
                    "Upserted batch %d: %d points", i // batch_size + 1, len(batch)
                )
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant upsert failed on '{collection_name}': {e}"
            ) from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        """Iterate over all points in a collection.

        Args:
            collection_name: Collection to scroll.
            limit: Batch size per scroll.
            offset: Pagination token from a previous scroll.
            with_vectors: Whether to include vectors in results.

        Returns:
            Tuple of (results, next_offset). next_offset is None when done.

        Raises:
            StorageError: If the scroll fails.
        """
        try:
            records, next_offset = await self.client.scroll(
                collection_name=collection_name,
                limit=limit,
                offset=offset,
                with_vectors=with_vectors,
                with_payload=True,
            )

            results = []
            for record in records:
                payload = record.payload or {}
                results.append(
                    CacheResult(
                        id=str(record.id),
                        score=0.0,
                        original_question=payload.get("original_question", ""),
                        normalized_question=payload.get("normalized_question", ""),
                        generated_query=payload.get("generated_query", ""),
                        query_hash=payload.get("query_hash", ""),
                        response_summary=payload.get("response_summary"),
                        template_id=payload.get("template_id"),
                        usage_count=payload.get("usage_count", 0),
                        created_at=payload.get("created_at"),
                    )
                )

            next_offset_str = str(next_offset) if next_offset is not None else None
            logger.debug(
                "Scroll '%s': returned %d records, has_more=%s",
                collection_name,
                len(results),
                next_offset_str is not None,
            )
            return results, next_offset_str
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant scroll failed on '{collection_name}': {e}"
            ) from e

    async def count(self, collection_name: str) -> int:
        """Return the number of points in a collection.

        Raises:
            StorageError: If the count fails.
        """
        try:
            result = await self.client.count(collection_name=collection_name)
            return result.count
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant count failed on '{collection_name}': {e}"
            ) from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        """Delete points by ID.

        Args:
            collection_name: Target collection.
            ids: List of point IDs to delete.

        Raises:
            StorageError: If the delete fails.
        """
        try:
            await self.client.delete(
                collection_name=collection_name,
                points_selector=PointIdsList(points=cast("list[int | str | uuid.UUID]", ids)),
                wait=True,
            )
            logger.info("Deleted %d points from '%s'", len(ids), collection_name)
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant delete failed on '{collection_name}': {e}"
            ) from e

    async def find_expired(self, collection_name: str) -> list[str]:
        try:
            now_iso = datetime.now(timezone.utc).isoformat()
            results, _ = await self.client.scroll(
                collection_name=collection_name,
                scroll_filter=Filter(
                    must=[
                        FieldCondition(
                            key="expires_at",
                            range=DatetimeRange(lt=now_iso),
                        )
                    ]
                ),
                with_payload=False,
                with_vectors=False,
                limit=1000,
            )
            return [str(p.id) for p in results]
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant find_expired failed on '{collection_name}': {e}"
            ) from e

    async def close(self) -> None:
        """Close the Qdrant client."""
        if self._client is not None:
            await self._client.close()
            self._client = None
            self._initialized_collections.clear()
            logger.info("Qdrant client closed")

    # --- Collection-specific operations ---

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        try:
            results, _ = await self.client.scroll(
                collection_name=collection_name,
                scroll_filter=Filter(
                    must=[
                        FieldCondition(
                            key="normalized_question",
                            match=MatchValue(value=normalized_question),
                        )
                    ]
                ),
                limit=1,
                with_payload=True,
            )
            if not results:
                return None
            record = results[0]
            payload = record.payload or {}
            return CacheResult(
                id=str(record.id),
                score=1.0,
                original_question=payload.get("original_question", ""),
                normalized_question=payload.get("normalized_question", ""),
                generated_query=payload.get("generated_query", ""),
                query_hash=payload.get("query_hash", ""),
                response_summary=payload.get("response_summary"),
                template_id=payload.get("template_id"),
                usage_count=payload.get("usage_count", 0),
                created_at=payload.get("created_at"),
            )
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e

    async def find_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> list[str]:
        try:
            ids: list[str] = []
            offset: str | None = None
            while True:
                results, next_offset = await self.client.scroll(
                    collection_name=collection_name,
                    scroll_filter=Filter(
                        must=[
                            FieldCondition(
                                key="query_hash",
                                match=MatchValue(value=query_hash),
                            )
                        ]
                    ),
                    limit=1000,
                    offset=offset,
                    with_payload=False,
                )
                ids.extend(str(r.id) for r in results)
                if next_offset is None:
                    break
                offset = str(next_offset)
            return ids
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant find_by_query_hash failed on '{collection_name}': {e}"
            ) from e

    async def find_by_template_id(
        self, collection_name: str, template_id: str
    ) -> list[str]:
        try:
            ids: list[str] = []
            offset: str | None = None
            while True:
                results, next_offset = await self.client.scroll(
                    collection_name=collection_name,
                    scroll_filter=Filter(
                        must=[
                            FieldCondition(
                                key="template_id",
                                match=MatchValue(value=template_id),
                            )
                        ]
                    ),
                    limit=1000,
                    offset=offset,
                    with_payload=False,
                )
                ids.extend(str(r.id) for r in results)
                if next_offset is None:
                    break
                offset = str(next_offset)
            return ids
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant find_by_template_id failed on '{collection_name}': {e}"
            ) from e

    async def drop_collection(self, collection_name: str) -> None:
        try:
            await self.client.delete_collection(collection_name)
            self._initialized_collections.discard(collection_name)
            logger.info("Dropped collection '%s'", collection_name)
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant drop_collection failed on '{collection_name}': {e}"
            ) from e

    async def search_by_query_hash(
        self,
        collection_name: str,
        query_hash: str,
    ) -> CacheResult | None:
        """Find a cache entry by its query hash (exact payload filter).

        Used to check if a template-generated query already has a cached response.

        Args:
            collection_name: Collection to search.
            query_hash: MD5 hash of the generated query.

        Returns:
            CacheResult if found, None otherwise.
        """
        try:
            results, _ = await self.client.scroll(
                collection_name=collection_name,
                scroll_filter=Filter(
                    must=[
                        FieldCondition(
                            key="query_hash",
                            match=MatchValue(value=query_hash),
                        )
                    ]
                ),
                limit=1,
                with_payload=True,
            )

            if not results:
                return None

            record = results[0]
            payload = record.payload or {}
            return CacheResult(
                id=str(record.id),
                score=1.0,
                original_question=payload.get("original_question", ""),
                normalized_question=payload.get("normalized_question", ""),
                generated_query=payload.get("generated_query", ""),
                query_hash=payload.get("query_hash", ""),
                response_summary=payload.get("response_summary"),
                template_id=payload.get("template_id"),
                usage_count=payload.get("usage_count", 0),
                created_at=payload.get("created_at"),
            )
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant search_by_query_hash failed on '{collection_name}': {e}"
            ) from e

    async def update_usage_count(
        self, collection_name: str, point_id: str
    ) -> None:
        """Increment the usage_count for a specific point.

        Used for cache analytics and potential eviction policies.
        """
        try:
            points = await self.client.retrieve(
                collection_name=collection_name,
                ids=[point_id],
                with_payload=True,
            )

            if not points:
                logger.warning(
                    "Point '%s' not found in '%s'", point_id, collection_name
                )
                return

            current_count = (points[0].payload or {}).get("usage_count", 0)

            await self.client.set_payload(
                collection_name=collection_name,
                payload={"usage_count": current_count + 1},
                points=[point_id],
                wait=True,
            )
            logger.debug(
                "Updated usage_count for '%s' to %d", point_id, current_count + 1
            )
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(
                f"Qdrant update_usage_count failed on '{collection_name}': {e}"
            ) from e

    # --- Private methods ---

    def _build_client(self) -> AsyncQdrantClient:
        """Create the Qdrant client based on settings.mode."""
        mode = self._settings.qdrant_mode
        logger.debug("Building Qdrant client for mode='%s'", mode)

        if mode == "memory":
            return AsyncQdrantClient(":memory:")
        elif mode == "docker":
            url = self._settings.qdrant_url or (
                f"http://{self._settings.qdrant_host}:{self._settings.qdrant_port}"
            )
            logger.debug("Qdrant Docker URL: %s", url)
            return AsyncQdrantClient(url=url)
        elif mode == "cloud":
            logger.debug("Qdrant Cloud URL: %s", self._settings.qdrant_url)
            return AsyncQdrantClient(
                url=self._settings.qdrant_url,
                api_key=self._settings.qdrant_api_key.get_secret_value() if self._settings.qdrant_api_key else None,
            )
        else:
            raise StorageInitializationError(f"Unknown qdrant_mode: '{mode}'")

    def _build_quantization_config(self, dimension: int, **kwargs: Any) -> Any:
        """Choose quantization config based on dimension and settings.

        When ``on_disk=True`` and ``quantization_always_ram=True`` (the defaults
        for hybrid storage), original vectors live on disk while quantized
        vectors stay in RAM — giving a good balance of speed and memory.
        """
        enable = kwargs.get("enable_quantization", self._settings.enable_quantization)
        if not enable:
            return None

        q_type = kwargs.get("quantization_type", self._settings.quantization_type)
        always_ram = kwargs.get(
            "quantization_always_ram", self._settings.quantization_always_ram
        )

        if q_type == "binary" and dimension >= 512:
            return BinaryQuantization(
                binary=BinaryQuantizationConfig(always_ram=always_ram)
            )

        # Default: Scalar INT8
        return ScalarQuantization(
            scalar=ScalarQuantizationConfig(
                type=ScalarType.INT8,
                quantile=0.99,
                always_ram=always_ram,
            )
        )

    def _build_hnsw_config(self, **kwargs: Any) -> HnswConfigDiff:
        """Build HNSW config from settings."""
        return HnswConfigDiff(
            m=kwargs.get("hnsw_m", self._settings.hnsw_m),
            ef_construct=kwargs.get(
                "hnsw_ef_construct", self._settings.hnsw_ef_construct
            ),
            full_scan_threshold=10000,
        )

    def _build_search_params(self) -> SearchParams:
        """Build search params with quantization-aware settings.

        Controls how quantized vectors are used at query time:
        - **ignore**: bypass quantization entirely (use original vectors).
        - **rescore**: re-evaluate top candidates with original vectors for
          better accuracy.  Disable when originals are on slow storage.
        - **oversampling**: fetch ``limit * oversampling`` candidates from the
          quantized index before re-scoring.  Higher values improve recall at
          the cost of latency.
        """
        return SearchParams(
            quantization=QuantizationSearchParams(
                ignore=self._settings.quantization_ignore,
                rescore=self._settings.quantization_rescore,
                oversampling=self._settings.quantization_oversampling,
            )
        )

    async def _create_payload_indexes(self, collection_name: str) -> None:
        """Create payload indexes for fast filtering."""
        # KEYWORD indexes
        for field in ("template_id", "query_hash"):
            await self.client.create_payload_index(
                collection_name=collection_name,
                field_name=field,
                field_schema=PayloadSchemaType.KEYWORD,
            )

        # INTEGER index
        await self.client.create_payload_index(
            collection_name=collection_name,
            field_name="usage_count",
            field_schema=PayloadSchemaType.INTEGER,
        )

        # TEXT index with word tokenizer
        await self.client.create_payload_index(
            collection_name=collection_name,
            field_name="normalized_question",
            field_schema=TextIndexParams(
                type=TextIndexType.TEXT,
                tokenizer=TokenizerType.WORD,
                min_token_len=2,
                max_token_len=20,
                lowercase=True,
            ),
        )

        logger.info("Created payload indexes on '%s'", collection_name)

    @staticmethod
    def _point_to_cache_result(point: Any) -> CacheResult:
        """Convert a Qdrant ScoredPoint to a CacheResult."""
        payload = point.payload or {}
        score = point.score if point.score is not None else 0.0
        score = max(0.0, min(1.0, score))

        expires_at: datetime | None = None
        if raw_ea := payload.get("expires_at"):
            with contextlib.suppress(ValueError, TypeError):
                expires_at = datetime.fromisoformat(str(raw_ea))
                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=timezone.utc)

        return CacheResult(
            id=str(point.id),
            score=score,
            original_question=payload.get("original_question", ""),
            normalized_question=payload.get("normalized_question", ""),
            generated_query=payload.get("generated_query", ""),
            query_hash=payload.get("query_hash", ""),
            response_summary=payload.get("response_summary"),
            template_id=payload.get("template_id"),
            usage_count=payload.get("usage_count", 0),
            created_at=payload.get("created_at"),
            expires_at=expires_at,
        )

    @staticmethod
    def _entry_to_point(entry: CacheEntry) -> PointStruct:
        """Convert a CacheEntry to a Qdrant PointStruct."""
        return PointStruct(
            id=entry.id,
            vector=entry.vector,
            payload={
                "original_question": entry.original_question,
                "normalized_question": entry.normalized_question,
                "generated_query": entry.generated_query,
                "query_hash": entry.query_hash,
                "response_summary": entry.response_summary,
                "template_id": entry.template_id,
                "usage_count": entry.usage_count,
                "created_at": entry.created_at.isoformat(),
                "expires_at": entry.expires_at.isoformat() if entry.expires_at else None,
            },
        )

connect() async

Establish connection to Qdrant based on settings.

Must be called before any other operation.

Raises:

StorageInitializationError: If connection fails.

Source code in src/medha/backends/qdrant.py
async def connect(self) -> None:
    """Establish connection to Qdrant based on settings.

    Must be called before any other operation.

    Raises:
        StorageInitializationError: If connection fails.
    """
    try:
        self._client = self._build_client()
        logger.info(
            "Connected to Qdrant in '%s' mode", self._settings.qdrant_mode
        )
    except Exception as e:
        raise StorageInitializationError(
            f"Failed to connect to Qdrant in '{self._settings.qdrant_mode}' mode: {e}"
        ) from e

initialize(collection_name, dimension, **kwargs) async

Create and configure a Qdrant collection.

Idempotent: skips creation if collection already exists.

Configuration includes:
  • Vector params (dimension, cosine distance)
  • Quantization (scalar INT8 by default, binary for dim >= 512)
  • HNSW parameters (m=16, ef_construct=100)
  • Optimizer config (indexing threshold, memmap threshold)
  • Payload indexes on frequently queried fields

Parameters:

collection_name (str): Name of the collection.
dimension (int): Vector dimensionality.
**kwargs (Any): Override settings (hnsw_m, enable_quantization, etc.). Default: {}
Source code in src/medha/backends/qdrant.py
async def initialize(
    self, collection_name: str, dimension: int, **kwargs: Any
) -> None:
    """Create and configure a Qdrant collection.

    Idempotent: skips creation if collection already exists.

    Configuration includes:
        - Vector params (dimension, cosine distance)
        - Quantization (scalar INT8 by default, binary for dim >= 512)
        - HNSW parameters (m=16, ef_construct=100)
        - Optimizer config (indexing threshold, memmap threshold)
        - Payload indexes on frequently queried fields

    Args:
        collection_name: Name of the collection.
        dimension: Vector dimensionality.
        **kwargs: Override settings (hnsw_m, enable_quantization, etc.)
    """
    if collection_name in self._initialized_collections:
        return

    try:
        logger.debug("Initializing collection '%s' (dim=%d)", collection_name, dimension)
        collections = await self.client.get_collections()
        existing = {c.name for c in collections.collections}

        if collection_name not in existing:
            quantization = self._build_quantization_config(dimension, **kwargs)
            hnsw = self._build_hnsw_config(**kwargs)

            await self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=dimension,
                    distance=Distance.COSINE,
                    on_disk=self._settings.on_disk,
                ),
                quantization_config=quantization,
                hnsw_config=hnsw,
                optimizers_config=OptimizersConfigDiff(
                    indexing_threshold=20000,
                    memmap_threshold=50000,
                ),
            )
            logger.info(
                "Created collection '%s' (dim=%d)", collection_name, dimension
            )

            # Create payload indexes only on main collection (not template collections)
            # and only when not in memory mode (indexes have no effect in memory mode
            # and Qdrant emits a UserWarning when they are created there).
            if (
                not collection_name.startswith("__medha_templates_")
                and self._settings.qdrant_mode != "memory"
            ):
                try:
                    await self._create_payload_indexes(collection_name)
                except Exception as idx_err:
                    logger.warning(
                        "Payload indexes not created on '%s' (non-critical): %s",
                        collection_name,
                        idx_err,
                    )
        else:
            logger.info(
                "Collection '%s' already exists, skipping creation",
                collection_name,
            )

        self._initialized_collections.add(collection_name)
    except (StorageError, StorageInitializationError):
        raise
    except Exception as e:
        raise StorageInitializationError(
            f"Failed to initialize collection '{collection_name}': {e}"
        ) from e
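
A sketch of per-call overrides for the kwargs named above; the collection name and values are illustrative.

# Inside an async function, with a connected QdrantBackend `backend`.
await backend.initialize(
    "medha_cache",
    dimension=768,
    hnsw_m=32,                  # overrides settings.hnsw_m for this collection
    enable_quantization=False,  # overrides settings.enable_quantization
)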

search(collection_name, vector, limit=5, score_threshold=0.0) async

Search for similar vectors using query_points.

Parameters:

collection_name (str): Collection to search.
vector (list[float]): Query vector.
limit (int): Max number of results. Default: 5
score_threshold (float): Minimum similarity score (0.0 - 1.0). Default: 0.0

Returns:

list[CacheResult]: List of CacheResult, sorted by descending score.

Raises:

StorageError: If the search fails.

Source code in src/medha/backends/qdrant.py
async def search(
    self,
    collection_name: str,
    vector: list[float],
    limit: int = 5,
    score_threshold: float = 0.0,
) -> list[CacheResult]:
    """Search for similar vectors using query_points.

    Args:
        collection_name: Collection to search.
        vector: Query vector.
        limit: Max number of results.
        score_threshold: Minimum similarity score (0.0 - 1.0).

    Returns:
        List of CacheResult, sorted by descending score.

    Raises:
        StorageError: If the search fails.
    """
    try:
        logger.debug(
            "Searching '%s': limit=%d, threshold=%.3f",
            collection_name,
            limit,
            score_threshold,
        )
        search_params = self._build_search_params()
        now_iso = datetime.now(timezone.utc).isoformat()
        ttl_filter = Filter(
            must_not=[
                FieldCondition(
                    key="expires_at",
                    range=DatetimeRange(lte=now_iso),
                )
            ]
        )
        response = await self.client.query_points(
            collection_name=collection_name,
            query=vector,
            limit=limit,
            score_threshold=score_threshold if score_threshold > 0.0 else None,
            search_params=search_params,
            query_filter=ttl_filter,
            with_payload=True,
        )
        results = [self._point_to_cache_result(point) for point in response.points]
        if results:
            logger.debug(
                "Search '%s' returned %d results (top score=%.4f)",
                collection_name,
                len(results),
                results[0].score,
            )
        else:
            logger.debug("Search '%s' returned 0 results", collection_name)
        return results
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant search failed on '{collection_name}': {e}"
        ) from e
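
A usage sketch; expired entries are already excluded by the TTL filter built inside search().

# Inside an async function; `query_embedding` is a list[float] from your
# embedding model (illustrative name).
hits = await backend.search(
    "medha_cache",
    vector=query_embedding,
    limit=3,
    score_threshold=0.85,  # keep only near-duplicates
)
for hit in hits:
    print(hit.score, hit.generated_query)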

upsert(collection_name, entries) async

Insert or update cache entries in batches.

Parameters:

collection_name (str): Target collection.
entries (list[CacheEntry]): List of CacheEntry objects to upsert.

Raises:

StorageError: If the upsert fails.

Source code in src/medha/backends/qdrant.py
async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
    """Insert or update cache entries in batches.

    Args:
        collection_name: Target collection.
        entries: List of CacheEntry objects to upsert.

    Raises:
        StorageError: If the upsert fails.
    """
    try:
        points = [self._entry_to_point(e) for e in entries]
        batch_size = self._settings.batch_size

        for i in range(0, len(points), batch_size):
            batch = points[i : i + batch_size]
            await self.client.upsert(
                collection_name=collection_name,
                wait=True,
                points=batch,
            )
            logger.info(
                "Upserted batch %d: %d points", i // batch_size + 1, len(batch)
            )
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant upsert failed on '{collection_name}': {e}"
        ) from e

scroll(collection_name, limit=100, offset=None, with_vectors=False) async

Iterate over all points in a collection.

Parameters:

collection_name (str): Collection to scroll.
limit (int): Batch size per scroll. Default: 100
offset (str | None): Pagination token from a previous scroll. Default: None
with_vectors (bool): Whether to include vectors in results. Default: False

Returns:

tuple[list[CacheResult], str | None]: Tuple of (results, next_offset). next_offset is None when done.

Raises:

StorageError: If the scroll fails.

Source code in src/medha/backends/qdrant.py
async def scroll(
    self,
    collection_name: str,
    limit: int = 100,
    offset: str | None = None,
    with_vectors: bool = False,
) -> tuple[list[CacheResult], str | None]:
    """Iterate over all points in a collection.

    Args:
        collection_name: Collection to scroll.
        limit: Batch size per scroll.
        offset: Pagination token from a previous scroll.
        with_vectors: Whether to include vectors in results.

    Returns:
        Tuple of (results, next_offset). next_offset is None when done.

    Raises:
        StorageError: If the scroll fails.
    """
    try:
        records, next_offset = await self.client.scroll(
            collection_name=collection_name,
            limit=limit,
            offset=offset,
            with_vectors=with_vectors,
            with_payload=True,
        )

        results = []
        for record in records:
            payload = record.payload or {}
            results.append(
                CacheResult(
                    id=str(record.id),
                    score=0.0,
                    original_question=payload.get("original_question", ""),
                    normalized_question=payload.get("normalized_question", ""),
                    generated_query=payload.get("generated_query", ""),
                    query_hash=payload.get("query_hash", ""),
                    response_summary=payload.get("response_summary"),
                    template_id=payload.get("template_id"),
                    usage_count=payload.get("usage_count", 0),
                    created_at=payload.get("created_at"),
                )
            )

        next_offset_str = str(next_offset) if next_offset is not None else None
        logger.debug(
            "Scroll '%s': returned %d records, has_more=%s",
            collection_name,
            len(results),
            next_offset_str is not None,
        )
        return results, next_offset_str
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant scroll failed on '{collection_name}': {e}"
        ) from e
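
A sketch of draining a collection page by page; the loop ends when next_offset comes back as None.

# Inside an async function, with a connected backend.
offset: str | None = None
while True:
    page, offset = await backend.scroll("medha_cache", limit=500, offset=offset)
    for result in page:
        ...  # process result
    if offset is None:
        break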

count(collection_name) async

Return the number of points in a collection.

Raises:

StorageError: If the count fails.

Source code in src/medha/backends/qdrant.py
async def count(self, collection_name: str) -> int:
    """Return the number of points in a collection.

    Raises:
        StorageError: If the count fails.
    """
    try:
        result = await self.client.count(collection_name=collection_name)
        return result.count
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant count failed on '{collection_name}': {e}"
        ) from e

delete(collection_name, ids) async

Delete points by ID.

Parameters:

collection_name (str): Target collection.
ids (list[str]): List of point IDs to delete.

Raises:

StorageError: If the delete fails.

Source code in src/medha/backends/qdrant.py
async def delete(self, collection_name: str, ids: list[str]) -> None:
    """Delete points by ID.

    Args:
        collection_name: Target collection.
        ids: List of point IDs to delete.

    Raises:
        StorageError: If the delete fails.
    """
    try:
        await self.client.delete(
            collection_name=collection_name,
            points_selector=PointIdsList(points=cast("list[int | str | uuid.UUID]", ids)),
            wait=True,
        )
        logger.info("Deleted %d points from '%s'", len(ids), collection_name)
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant delete failed on '{collection_name}': {e}"
        ) from e
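
A sketch of TTL cleanup pairing find_expired() (defined in the class source above) with delete(). Looping is reasonable because find_expired() returns at most 1000 ids per call.

# Inside an async function, with a connected backend.
while expired := await backend.find_expired("medha_cache"):
    await backend.delete("medha_cache", expired)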

close() async

Close the Qdrant client.

Source code in src/medha/backends/qdrant.py
async def close(self) -> None:
    """Close the Qdrant client."""
    if self._client is not None:
        await self._client.close()
        self._client = None
        self._initialized_collections.clear()
        logger.info("Qdrant client closed")

search_by_query_hash(collection_name, query_hash) async

Find a cache entry by its query hash (exact payload filter).

Used to check if a template-generated query already has a cached response.

Parameters:

collection_name (str): Collection to search.
query_hash (str): MD5 hash of the generated query.

Returns:

CacheResult | None: CacheResult if found, None otherwise.

Source code in src/medha/backends/qdrant.py
async def search_by_query_hash(
    self,
    collection_name: str,
    query_hash: str,
) -> CacheResult | None:
    """Find a cache entry by its query hash (exact payload filter).

    Used to check if a template-generated query already has a cached response.

    Args:
        collection_name: Collection to search.
        query_hash: MD5 hash of the generated query.

    Returns:
        CacheResult if found, None otherwise.
    """
    try:
        results, _ = await self.client.scroll(
            collection_name=collection_name,
            scroll_filter=Filter(
                must=[
                    FieldCondition(
                        key="query_hash",
                        match=MatchValue(value=query_hash),
                    )
                ]
            ),
            limit=1,
            with_payload=True,
        )

        if not results:
            return None

        record = results[0]
        payload = record.payload or {}
        return CacheResult(
            id=str(record.id),
            score=1.0,
            original_question=payload.get("original_question", ""),
            normalized_question=payload.get("normalized_question", ""),
            generated_query=payload.get("generated_query", ""),
            query_hash=payload.get("query_hash", ""),
            response_summary=payload.get("response_summary"),
            template_id=payload.get("template_id"),
            usage_count=payload.get("usage_count", 0),
            created_at=payload.get("created_at"),
        )
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant search_by_query_hash failed on '{collection_name}': {e}"
        ) from e
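
A sketch of the exact-hash lookup described above, checking for a cached response before re-executing a template-generated query; the hash value is illustrative.

# Inside an async function, with a connected backend.
cached = await backend.search_by_query_hash("medha_cache", query_hash="abc123")
if cached is not None:
    print(cached.response_summary)  # reuse instead of re-running the query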

update_usage_count(collection_name, point_id) async

Increment the usage_count for a specific point.

Used for cache analytics and potential eviction policies.

Source code in src/medha/backends/qdrant.py
async def update_usage_count(
    self, collection_name: str, point_id: str
) -> None:
    """Increment the usage_count for a specific point.

    Used for cache analytics and potential eviction policies.
    """
    try:
        points = await self.client.retrieve(
            collection_name=collection_name,
            ids=[point_id],
            with_payload=True,
        )

        if not points:
            logger.warning(
                "Point '%s' not found in '%s'", point_id, collection_name
            )
            return

        current_count = (points[0].payload or {}).get("usage_count", 0)

        await self.client.set_payload(
            collection_name=collection_name,
            payload={"usage_count": current_count + 1},
            points=[point_id],
            wait=True,
        )
        logger.debug(
            "Updated usage_count for '%s' to %d", point_id, current_count + 1
        )
    except StorageError:
        raise
    except Exception as e:
        raise StorageError(
            f"Qdrant update_usage_count failed on '{collection_name}': {e}"
        ) from e
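
A sketch of bumping analytics after serving a hit. Note the method does a retrieve followed by set_payload, so concurrent increments can race; treat the counter as approximate.

# Inside an async function, after a successful lookup.
if cached is not None:
    await backend.update_usage_count("medha_cache", cached.id)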

PgVectorBackend

Bases: _AsyncpgBackendMixin, VectorStorageBackend

PostgreSQL + pgvector backend. Requires asyncpg and pgvector packages.
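
A hedged connection sketch using the Settings fields referenced in connect() below; the DSN and pool sizes are illustrative.

from medha.config import Settings
from medha.backends.pgvector import PgVectorBackend

settings = Settings(
    pg_dsn="postgresql://user:pass@localhost:5432/medha",  # or the pg_host/pg_port/pg_database/pg_user/pg_password fields
    pg_pool_min_size=1,
    pg_pool_max_size=10,
)
backend = PgVectorBackend(settings=settings)
# await backend.connect()  # inside an async function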

Source code in src/medha/backends/pgvector.py
class PgVectorBackend(_AsyncpgBackendMixin, VectorStorageBackend):
    """PostgreSQL + pgvector backend. Requires asyncpg and pgvector packages."""

    def __init__(self, settings: Any = None) -> None:
        if not HAS_PGVECTOR:
            raise ConfigurationError(
                "pgvector backend requires 'asyncpg' and 'pgvector'. "
                "Install with: pip install medha-archai[pgvector]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._pool: asyncpg.Pool | None = None
        self._initialized_tables: set[str] = set()

    async def connect(self) -> None:
        try:
            kwargs = dict(
                min_size=self._settings.pg_pool_min_size,
                max_size=self._settings.pg_pool_max_size,
                init=pgvector.asyncpg.register_vector,
            )
            if self._settings.pg_dsn:
                self._pool = await asyncpg.create_pool(dsn=self._settings.pg_dsn, **kwargs)
            else:
                self._pool = await asyncpg.create_pool(
                    host=self._settings.pg_host,
                    port=self._settings.pg_port,
                    database=self._settings.pg_database,
                    user=self._settings.pg_user,
                    password=self._settings.pg_password.get_secret_value(),
                    **kwargs,
                )
        except Exception as e:
            raise StorageInitializationError(f"Failed to connect to PostgreSQL: {e}") from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")
        if collection_name in self._initialized_tables:
            return

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        try:
            async with self._pool.acquire() as conn:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")

                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {schema}.{table} (
                        id                   UUID        PRIMARY KEY,
                        vector               vector({dimension}) NOT NULL,
                        original_question    TEXT NOT NULL DEFAULT '',
                        normalized_question  TEXT NOT NULL DEFAULT '',
                        generated_query      TEXT NOT NULL DEFAULT '',
                        query_hash           TEXT NOT NULL DEFAULT '',
                        response_summary     TEXT,
                        template_id          TEXT,
                        usage_count          INTEGER NOT NULL DEFAULT 1,
                        created_at           TIMESTAMPTZ NOT NULL DEFAULT NOW()
                    )
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_vector_hnsw_idx
                        ON {schema}.{table}
                        USING hnsw (vector vector_cosine_ops)
                        WITH (m = 16, ef_construction = 64)
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_query_hash_idx
                        ON {schema}.{table} (query_hash)
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_template_id_idx
                        ON {schema}.{table} (template_id)
                        WHERE template_id IS NOT NULL
                """)

                await conn.execute(f"""
                    ALTER TABLE {schema}.{table}
                        ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_expires_at_idx
                        ON {schema}.{table} (expires_at)
                        WHERE expires_at IS NOT NULL
                """)

        except asyncpg.PostgresError as e:
            raise StorageInitializationError(
                f"Failed to initialize collection '{collection_name}': {e}"
            ) from e

        self._initialized_tables.add(collection_name)

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

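        # <=> is pgvector's cosine-distance operator; 1 - distance is cosine similarity.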
        sql = f"""
            SELECT
                id::text,
                original_question,
                normalized_question,
                generated_query,
                query_hash,
                response_summary,
                template_id,
                usage_count,
                created_at,
                expires_at,
                (1 - (vector <=> $1::vector))::float AS score
            FROM {schema}.{table}
            WHERE (1 - (vector <=> $1::vector)) >= $2
              AND (expires_at IS NULL OR expires_at > NOW())
            ORDER BY vector <=> $1::vector
            LIMIT $3
        """
        try:
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(sql, vector, score_threshold, limit)
        except asyncpg.PostgresError as e:
            raise StorageError(f"PgVector operation failed on '{collection_name}': {e}") from e

        return [_row_to_cache_result(row) for row in rows]

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")
        if not entries:
            return

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        sql = f"""
            INSERT INTO {schema}.{table} (
                id, vector, original_question, normalized_question,
                generated_query, query_hash, response_summary,
                template_id, usage_count, created_at, expires_at
            )
            VALUES ($1, $2::vector, $3, $4, $5, $6, $7, $8, $9, $10, $11)
            ON CONFLICT (id) DO UPDATE SET
                vector              = EXCLUDED.vector,
                original_question   = EXCLUDED.original_question,
                normalized_question = EXCLUDED.normalized_question,
                generated_query     = EXCLUDED.generated_query,
                query_hash          = EXCLUDED.query_hash,
                response_summary    = EXCLUDED.response_summary,
                template_id         = EXCLUDED.template_id,
                usage_count         = EXCLUDED.usage_count,
                created_at          = EXCLUDED.created_at,
                expires_at          = EXCLUDED.expires_at
        """
        try:
            async with self._pool.acquire() as conn:
                await conn.executemany(sql, [
                    (
                        uuid.UUID(entry.id),
                        entry.vector,
                        entry.original_question,
                        entry.normalized_question,
                        entry.generated_query,
                        entry.query_hash,
                        entry.response_summary,
                        entry.template_id,
                        entry.usage_count,
                        entry.created_at,
                        entry.expires_at,
                    )
                    for entry in entries
                ])
        except asyncpg.PostgresError as e:
            raise StorageError(f"PgVector operation failed on '{collection_name}': {e}") from e

    async def find_expired(self, collection_name: str) -> list[str]:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        sql = f"""
            SELECT id::text FROM {schema}.{table}
            WHERE expires_at IS NOT NULL AND expires_at < NOW()
            LIMIT 10000
        """
        try:
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(sql)
            return [row["id"] for row in rows]
        except asyncpg.PostgresError as e:
            raise StorageError(f"PgVector find_expired failed on '{collection_name}': {e}") from e

ElasticsearchBackend

Bases: VectorStorageBackend

Elasticsearch 8.x backend. Requires elasticsearch[async]>=8.12.

Source code in src/medha/backends/elasticsearch.py
class ElasticsearchBackend(VectorStorageBackend):
    """Elasticsearch 8.x backend. Requires elasticsearch[async]>=8.12."""

    def __init__(self, settings: Any = None) -> None:
        if not HAS_ELASTICSEARCH:
            raise ConfigurationError(
                "elasticsearch backend requires 'elasticsearch[async]>=8.12'. "
                "Install with: pip install medha-archai[elasticsearch]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._client: Any = None

    def _index_name(self, collection_name: str) -> str:
        safe = _INDEX_UNSAFE_RE.sub("_", collection_name.lower())
        prefix = self._settings.es_index_prefix
        return f"{prefix}_{safe}"[:255]

    async def connect(self) -> None:
        kwargs: dict[str, Any] = {
            "hosts": self._settings.es_hosts,
            "request_timeout": self._settings.es_timeout,
        }
        if self._settings.es_api_key is not None:
            kwargs["api_key"] = self._settings.es_api_key.get_secret_value()
        elif self._settings.es_username is not None and self._settings.es_password is not None:
            kwargs["basic_auth"] = (
                self._settings.es_username,
                self._settings.es_password.get_secret_value(),
            )
        try:
            self._client = AsyncElasticsearch(**kwargs)
            await self._client.info()
        except Exception as e:
            raise StorageInitializationError(f"Failed to connect to Elasticsearch: {e}") from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        try:
            exists = await self._client.indices.exists(index=index)
            if exists:
                return
            mapping = {
                "mappings": {
                    "properties": {
                        "vector": {
                            "type": "dense_vector",
                            "dims": dimension,
                            "index": True,
                            "similarity": "cosine",
                        },
                        "original_question": {"type": "text"},
                        "normalized_question": {"type": "keyword"},
                        "generated_query": {"type": "text"},
                        "query_hash": {"type": "keyword"},
                        "response_summary": {"type": "text"},
                        "template_id": {"type": "keyword"},
                        "usage_count": {"type": "integer"},
                        "created_at": {"type": "date"},
                        "expires_at": {"type": "date"},
                    }
                },
                "settings": {
                    "number_of_shards": 1,
                    "number_of_replicas": 0,
                },
            }
            await self._client.indices.create(index=index, body=mapping)
            logger.info("Created Elasticsearch index '%s'", index)
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to initialize Elasticsearch index '{index}': {e}"
            ) from e

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        ttl_filter = {
            "bool": {
                "should": [
                    {"bool": {"must_not": {"exists": {"field": "expires_at"}}}},
                    {"range": {"expires_at": {"gt": "now"}}},
                ]
            }
        }
        query = {
            "knn": {
                "field": "vector",
                "query_vector": vector,
                "k": limit,
                "num_candidates": self._settings.es_num_candidates,
                "filter": ttl_filter,
            },
            "size": limit,
            "_source": {
                "excludes": ["vector"]
            },
        }
        try:
            resp = await self._client.search(index=index, body=query)
        except NotFoundError as e:
            raise StorageError(f"Elasticsearch index '{index}' not found: {e}") from e
        except TransportError as e:
            raise StorageError(f"Elasticsearch transport error on '{collection_name}': {e}") from e

        results = []
        for hit in resp["hits"]["hits"]:
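            # For "similarity": "cosine", ES reports _score = (1 + cosine) / 2
            # in [0, 1]; map back to raw cosine in [-1, 1] before thresholding.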
            score = (hit["_score"] * 2) - 1
            if score < score_threshold:
                continue
            src = hit["_source"]
            results.append(_hit_to_cache_result(hit["_id"], src, score))
        return results

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        if not entries:
            return
        index = self._index_name(collection_name)

        def _actions() -> Any:
            for entry in entries:
                doc: dict[str, Any] = {
                    "original_question": entry.original_question,
                    "normalized_question": entry.normalized_question,
                    "generated_query": entry.generated_query,
                    "query_hash": entry.query_hash,
                    "response_summary": entry.response_summary,
                    "template_id": entry.template_id,
                    "usage_count": entry.usage_count,
                    "created_at": _dt_to_str(entry.created_at),
                    "vector": entry.vector,
                }
                if entry.expires_at is not None:
                    doc["expires_at"] = _dt_to_str(entry.expires_at)
                yield {
                    "_op_type": "index",
                    "_index": index,
                    "_id": entry.id,
                    "_source": doc,
                }

        try:
            await async_bulk(self._client, _actions())
        except Exception as e:
            raise StorageError(f"Elasticsearch upsert failed on '{collection_name}': {e}") from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)

        search_after = json.loads(offset) if offset is not None else None
        source_excludes = [] if with_vectors else ["vector"]

        query: dict[str, Any] = {
            "query": {"match_all": {}},
            "size": limit,
            "sort": [{"created_at": "asc"}, {"_id": "asc"}],
            "_source": {"excludes": source_excludes},
        }
        if search_after is not None:
            query["search_after"] = search_after

        try:
            resp = await self._client.search(index=index, body=query)
        except NotFoundError as e:
            raise StorageError(f"Elasticsearch index '{index}' not found: {e}") from e
        except TransportError as e:
            raise StorageError(f"Elasticsearch transport error on '{collection_name}': {e}") from e

        hits = resp["hits"]["hits"]
        results = [_hit_to_cache_result(h["_id"], h["_source"], 1.0) for h in hits]
        next_offset: str | None = None
        if len(hits) == limit:
            next_offset = json.dumps(hits[-1]["sort"])
        return results, next_offset

    async def count(self, collection_name: str) -> int:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        try:
            resp = await self._client.count(index=index)
            return int(resp["count"])
        except NotFoundError:
            return 0
        except TransportError as e:
            raise StorageError(f"Elasticsearch count failed on '{collection_name}': {e}") from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        if not ids:
            return
        index = self._index_name(collection_name)

        actions = [
            {"_op_type": "delete", "_index": index, "_id": id_}
            for id_ in ids
        ]
        try:
            await async_bulk(self._client, actions, ignore_status=[404])
        except Exception as e:
            raise StorageError(f"Elasticsearch delete failed on '{collection_name}': {e}") from e

    async def search_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> CacheResult | None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        query = {
            "query": {"term": {"query_hash": query_hash}},
            "size": 1,
            "_source": {"excludes": ["vector"]},
        }
        try:
            resp = await self._client.search(index=index, body=query)
        except NotFoundError:
            return None
        except TransportError as e:
            raise StorageError(f"Elasticsearch search_by_query_hash failed on '{collection_name}': {e}") from e
        hits = resp["hits"]["hits"]
        if not hits:
            return None
        return _hit_to_cache_result(hits[0]["_id"], hits[0]["_source"], 1.0)

    async def update_usage_count(self, collection_name: str, point_id: str) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        try:
            await self._client.update(
                index=index,
                id=point_id,
                body={"script": {"source": "ctx._source.usage_count += 1", "lang": "painless"}},
            )
        except NotFoundError:
            logger.warning(
                "update_usage_count: id '%s' not found in collection '%s'",
                point_id,
                collection_name,
            )
        except TransportError as e:
            raise StorageError(f"Elasticsearch update_usage_count failed on '{collection_name}': {e}") from e

    async def find_expired(self, collection_name: str) -> list[str]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        query = {
            "query": {
                "bool": {
                    "must": [
                        {"exists": {"field": "expires_at"}},
                        {"range": {"expires_at": {"lt": "now"}}},
                    ]
                }
            },
            "size": 10000,
            "_source": False,
        }
        try:
            resp = await self._client.search(index=index, body=query)
            return [hit["_id"] for hit in resp["hits"]["hits"]]
        except NotFoundError:
            return []
        except TransportError as e:
            raise StorageError(f"Elasticsearch find_expired failed on '{collection_name}': {e}") from e

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        query = {
            "query": {"term": {"normalized_question": normalized_question}},
            "size": 1,
            "_source": {"excludes": ["vector"]},
        }
        try:
            resp = await self._client.search(index=index, body=query)
        except NotFoundError:
            return None
        except TransportError as e:
            raise StorageError(
                f"Elasticsearch search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e
        hits = resp["hits"]["hits"]
        if not hits:
            return None
        return _hit_to_cache_result(hits[0]["_id"], hits[0]["_source"], 1.0)

    async def find_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> list[str]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        query = {
            "query": {"term": {"query_hash": query_hash}},
            "size": 10000,
            "_source": False,
        }
        try:
            resp = await self._client.search(index=index, body=query)
            return [hit["_id"] for hit in resp["hits"]["hits"]]
        except NotFoundError:
            return []
        except TransportError as e:
            raise StorageError(f"Elasticsearch find_by_query_hash failed on '{collection_name}': {e}") from e

    async def find_by_template_id(
        self, collection_name: str, template_id: str
    ) -> list[str]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        query = {
            "query": {"term": {"template_id": template_id}},
            "size": 10000,
            "_source": False,
        }
        try:
            resp = await self._client.search(index=index, body=query)
            return [hit["_id"] for hit in resp["hits"]["hits"]]
        except NotFoundError:
            return []
        except TransportError as e:
            raise StorageError(f"Elasticsearch find_by_template_id failed on '{collection_name}': {e}") from e

    async def drop_collection(self, collection_name: str) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        index = self._index_name(collection_name)
        try:
            await self._client.indices.delete(index=index, ignore_unavailable=True)
            logger.info("Dropped Elasticsearch index '%s'", index)
        except TransportError as e:
            raise StorageError(f"Elasticsearch drop_collection failed on '{collection_name}': {e}") from e

    async def close(self) -> None:
        if self._client is not None:
            await self._client.close()
            self._client = None
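
A minimal connection sketch, assuming Settings accepts keyword overrides for the es_* fields read in connect() above:

from medha.backends.elasticsearch import ElasticsearchBackend
from medha.config import Settings

async def open_backend() -> ElasticsearchBackend:
    # Illustrative host; es_api_key or es_username + es_password enable auth.
    backend = ElasticsearchBackend(Settings(es_hosts=["http://localhost:9200"]))
    await backend.connect()
    await backend.initialize("qa_cache", dimension=384)
    return backend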

VectorChordBackend

Bases: _AsyncpgBackendMixin, VectorStorageBackend

PostgreSQL + VectorChord backend using vchordrq index with RaBitQ quantization.

Drop-in replacement for PgVectorBackend. Requires asyncpg (no pgvector Python package). The vectorchord PostgreSQL extension must be installed in the database.

Source code in src/medha/backends/vectorchord.py
class VectorChordBackend(_AsyncpgBackendMixin, VectorStorageBackend):
    """PostgreSQL + VectorChord backend using vchordrq index with RaBitQ quantization.

    Drop-in replacement for PgVectorBackend. Requires asyncpg (no pgvector Python package).
    The vectorchord PostgreSQL extension must be installed in the database.
    """

    def __init__(self, settings: Any = None) -> None:
        if not HAS_VECTORCHORD:
            raise ConfigurationError(
                "vectorchord backend requires 'asyncpg'. "
                "Install with: pip install medha-archai[vectorchord]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._pool: asyncpg.Pool | None = None
        self._initialized_tables: set[str] = set()

    async def _register_codecs(self, conn: Any) -> None:
        await conn.set_type_codec(
            "vector",
            encoder=_encode_vector,
            decoder=_decode_vector,
            schema="public",
            format="text",
        )

    async def connect(self) -> None:
        try:
            kwargs = dict(
                min_size=self._settings.pg_pool_min_size,
                max_size=self._settings.pg_pool_max_size,
                init=self._register_codecs,
            )
            if self._settings.pg_dsn:
                self._pool = await asyncpg.create_pool(dsn=self._settings.pg_dsn, **kwargs)
            else:
                self._pool = await asyncpg.create_pool(
                    host=self._settings.pg_host,
                    port=self._settings.pg_port,
                    database=self._settings.pg_database,
                    user=self._settings.pg_user,
                    password=self._settings.pg_password.get_secret_value(),
                    **kwargs,
                )
        except Exception as e:
            raise StorageInitializationError(f"Failed to connect to PostgreSQL: {e}") from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")
        if collection_name in self._initialized_tables:
            return

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        vc_lists = kwargs.get("vc_lists", self._settings.vc_lists)
        vc_residual = kwargs.get("vc_residual_quantization", self._settings.vc_residual_quantization)

        lists_sql = json.dumps(vc_lists)

        try:
            async with self._pool.acquire() as conn:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vectorchord")

                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {schema}.{table} (
                        id                   UUID        PRIMARY KEY,
                        vector               vector({dimension}) NOT NULL,
                        original_question    TEXT NOT NULL DEFAULT '',
                        normalized_question  TEXT NOT NULL DEFAULT '',
                        generated_query      TEXT NOT NULL DEFAULT '',
                        query_hash           TEXT NOT NULL DEFAULT '',
                        response_summary     TEXT,
                        template_id          TEXT,
                        usage_count          INTEGER NOT NULL DEFAULT 1,
                        created_at           TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        expires_at           TIMESTAMPTZ
                    )
                """)

                residual_str = "true" if vc_residual else "false"
                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_vector_vchordrq_idx
                        ON {schema}.{table}
                        USING vchordrq (vector vector_cosine_ops)
                        WITH (residual_quantization = {residual_str}, lists = '{lists_sql}')
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_query_hash_idx
                        ON {schema}.{table} (query_hash)
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_template_id_idx
                        ON {schema}.{table} (template_id)
                        WHERE template_id IS NOT NULL
                """)

                await conn.execute(f"""
                    CREATE INDEX IF NOT EXISTS {table}_expires_at_idx
                        ON {schema}.{table} (expires_at)
                        WHERE expires_at IS NOT NULL
                """)

        except asyncpg.PostgresError as e:
            raise StorageInitializationError(
                f"Failed to initialize collection '{collection_name}': {e}"
            ) from e

        self._initialized_tables.add(collection_name)

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        sql = f"""
            SELECT
                id::text,
                original_question,
                normalized_question,
                generated_query,
                query_hash,
                response_summary,
                template_id,
                usage_count,
                created_at,
                expires_at,
                (1 - (vector <=> $1::vector))::float AS score
            FROM {schema}.{table}
            WHERE (1 - (vector <=> $1::vector)) >= $2
              AND (expires_at IS NULL OR expires_at > NOW())
            ORDER BY vector <=> $1::vector
            LIMIT $3
        """
        try:
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(sql, vector, score_threshold, limit)
        except asyncpg.PostgresError as e:
            raise StorageError(f"VectorChord operation failed on '{collection_name}': {e}") from e

        return [_row_to_cache_result(row) for row in rows]

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")
        if not entries:
            return

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        sql = f"""
            INSERT INTO {schema}.{table} (
                id, vector, original_question, normalized_question,
                generated_query, query_hash, response_summary,
                template_id, usage_count, created_at, expires_at
            )
            VALUES ($1, $2::vector, $3, $4, $5, $6, $7, $8, $9, $10, $11)
            ON CONFLICT (id) DO UPDATE SET
                vector              = EXCLUDED.vector,
                original_question   = EXCLUDED.original_question,
                normalized_question = EXCLUDED.normalized_question,
                generated_query     = EXCLUDED.generated_query,
                query_hash          = EXCLUDED.query_hash,
                response_summary    = EXCLUDED.response_summary,
                template_id         = EXCLUDED.template_id,
                usage_count         = EXCLUDED.usage_count,
                created_at          = EXCLUDED.created_at,
                expires_at          = EXCLUDED.expires_at
        """
        try:
            async with self._pool.acquire() as conn:
                await conn.executemany(sql, [
                    (
                        uuid.UUID(entry.id),
                        entry.vector,
                        entry.original_question,
                        entry.normalized_question,
                        entry.generated_query,
                        entry.query_hash,
                        entry.response_summary,
                        entry.template_id,
                        entry.usage_count,
                        entry.created_at,
                        entry.expires_at,
                    )
                    for entry in entries
                ])
        except asyncpg.PostgresError as e:
            raise StorageError(f"VectorChord operation failed on '{collection_name}': {e}") from e

    async def find_expired(self, collection_name: str) -> list[str]:
        if self._pool is None:
            raise StorageError("Not connected. Call connect() first.")

        schema = self._settings.pg_schema
        table = self._table_name(collection_name)

        sql = f"""
            SELECT id::text FROM {schema}.{table}
            WHERE expires_at IS NOT NULL AND expires_at < NOW()
            LIMIT 10000
        """
        try:
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(sql)
            return [row["id"] for row in rows]
        except asyncpg.PostgresError as e:
            raise StorageError(f"VectorChord find_expired failed on '{collection_name}': {e}") from e

ChromaBackend

Bases: VectorStorageBackend

Chroma vector backend. Supports ephemeral, persistent, and http modes.

Only 'http' uses the native async client; the other two wrap sync calls with asyncio.to_thread.

Source code in src/medha/backends/chroma.py
class ChromaBackend(VectorStorageBackend):
    """Chroma vector backend. Supports ephemeral, persistent, and http modes.

    Only 'http' uses the native async client; the other two wrap sync calls with
    asyncio.to_thread.
    """

    def __init__(self, settings: Any = None) -> None:
        if not HAS_CHROMA:
            raise ConfigurationError(
                "chroma backend requires 'chromadb>=0.5'. "
                "Install with: pip install medha-archai[chroma]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._client: Any = None
        self._is_async: bool = False
        self._collections: dict[str, Any] = {}

    async def _run(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        if self._is_async:
            return await fn(*args, **kwargs)
        return await asyncio.to_thread(fn, *args, **kwargs)

    async def connect(self) -> None:
        mode = self._settings.chroma_mode
        try:
            if mode == "http":
                self._is_async = True
                extra: dict[str, Any] = {}
                if self._settings.chroma_auth_token:
                    extra["headers"] = {
                        "Authorization": f"Bearer {self._settings.chroma_auth_token.get_secret_value()}"
                    }
                self._client = await chromadb.AsyncHttpClient(
                    host=self._settings.chroma_host,
                    port=self._settings.chroma_port,
                    ssl=self._settings.chroma_ssl,
                    **extra,
                )
            elif mode == "persistent":
                self._is_async = False
                path = self._settings.chroma_persist_path or "./chroma_data"
                self._client = await asyncio.to_thread(chromadb.PersistentClient, path=path)
            else:
                self._is_async = False
                self._client = await asyncio.to_thread(chromadb.EphemeralClient)
        except Exception as e:
            raise StorageInitializationError(f"Failed to connect to Chroma ({mode}): {e}") from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        if collection_name in self._collections:
            return
        chroma_name = _chroma_collection_name(collection_name)
        try:
            collection = await self._run(
                self._client.get_or_create_collection,
                name=chroma_name,
                metadata={"hnsw:space": "cosine"},
            )
            self._collections[collection_name] = collection
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to initialize Chroma collection '{chroma_name}': {e}"
            ) from e

    def _get_collection(self, collection_name: str) -> Any:
        col = self._collections.get(collection_name)
        if col is None:
            raise StorageError(
                f"Collection '{collection_name}' not initialized. Call initialize() first."
            )
        return col

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        collection = self._get_collection(collection_name)
        cnt = await self._run(collection.count)
        if cnt == 0:
            return []
        now_iso = _now_iso()
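        # Chroma metadata cannot hold None, so entries without a TTL are assumed
        # to store expires_at as "" (find_expired relies on the same convention).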
        where = {"$or": [{"expires_at": {"$eq": ""}}, {"expires_at": {"$gt": now_iso}}]}
        try:
            result = await self._run(
                collection.query,
                query_embeddings=[vector],
                n_results=min(limit, cnt),
                where=where,
                include=["metadatas", "distances"],
            )
        except Exception as e:
            raise StorageError(f"Chroma search failed on '{collection_name}': {e}") from e

        ids = result["ids"][0]
        distances = result["distances"][0]
        metadatas = result["metadatas"][0]
        out = []
        for id_, dist, meta in zip(ids, distances, metadatas):
            score = 1.0 - dist
            if score >= score_threshold:
                out.append(_meta_to_result(id_, score, meta))
        return out

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if not entries:
            return
        collection = self._get_collection(collection_name)
        ids = [e.id for e in entries]
        embeddings = [e.vector for e in entries]
        metadatas = [_entry_to_metadata(e) for e in entries]
        try:
            await self._run(collection.upsert, ids=ids, embeddings=embeddings, metadatas=metadatas)
        except Exception as e:
            raise StorageError(f"Chroma upsert failed on '{collection_name}': {e}") from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        collection = self._get_collection(collection_name)
        int_offset = int(offset) if offset else 0
        include = ["metadatas", "embeddings"] if with_vectors else ["metadatas"]
        try:
            result = await self._run(
                collection.get,
                limit=limit,
                offset=int_offset,
                include=include,
            )
        except Exception as e:
            raise StorageError(f"Chroma scroll failed on '{collection_name}': {e}") from e

        ids: list[str] = result.get("ids", [])
        metadatas: list[dict[str, Any]] = result.get("metadatas", [])
        cache_results = [_meta_to_result(id_, 1.0, meta) for id_, meta in zip(ids, metadatas)]
        next_offset = str(int_offset + len(ids)) if len(ids) == limit else None
        return cache_results, next_offset

    async def count(self, collection_name: str) -> int:
        collection = self._get_collection(collection_name)
        try:
            return await self._run(collection.count)
        except Exception as e:
            raise StorageError(f"Chroma count failed on '{collection_name}': {e}") from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if not ids:
            return
        collection = self._get_collection(collection_name)
        try:
            await self._run(collection.delete, ids=ids)
        except Exception as e:
            raise StorageError(f"Chroma delete failed on '{collection_name}': {e}") from e

    async def search_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> CacheResult | None:
        collection = self._get_collection(collection_name)
        try:
            result = await self._run(
                collection.get,
                where={"query_hash": {"$eq": query_hash}},
                limit=1,
                include=["metadatas"],
            )
        except Exception as e:
            raise StorageError(
                f"Chroma search_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        ids: list[str] = result.get("ids", [])
        metadatas: list[dict[str, Any]] = result.get("metadatas", [])
        if not ids:
            return None
        return _meta_to_result(ids[0], 1.0, metadatas[0])

    async def update_usage_count(self, collection_name: str, id_: str) -> None:
        collection = self._get_collection(collection_name)
        try:
            result = await self._run(collection.get, ids=[id_], include=["metadatas"])
            ids: list[str] = result.get("ids", [])
            if not ids:
                return
            meta = dict(result["metadatas"][0])
            meta["usage_count"] = int(meta.get("usage_count", 0)) + 1
            await self._run(collection.upsert, ids=[id_], metadatas=[meta])
        except Exception as e:
            raise StorageError(
                f"Chroma update_usage_count failed on '{collection_name}': {e}"
            ) from e

    async def find_expired(self, collection_name: str) -> list[str]:
        collection = self._get_collection(collection_name)
        now_iso = _now_iso()
        try:
            result = await self._run(
                collection.get,
                where={"$and": [{"expires_at": {"$ne": ""}}, {"expires_at": {"$lt": now_iso}}]},
                include=["metadatas"],
            )
        except Exception as e:
            raise StorageError(f"Chroma find_expired failed on '{collection_name}': {e}") from e
        return result.get("ids", [])

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        collection = self._get_collection(collection_name)
        try:
            result = await self._run(
                collection.get,
                where={"normalized_question": {"$eq": normalized_question}},
                limit=1,
                include=["metadatas"],
            )
        except Exception as e:
            raise StorageError(
                f"Chroma search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e
        ids: list[str] = result.get("ids", [])
        metadatas: list[dict[str, Any]] = result.get("metadatas", [])
        if not ids:
            return None
        return _meta_to_result(ids[0], 1.0, metadatas[0])

    async def find_by_query_hash(self, collection_name: str, query_hash: str) -> list[str]:
        collection = self._get_collection(collection_name)
        try:
            result = await self._run(
                collection.get,
                where={"query_hash": {"$eq": query_hash}},
                include=["metadatas"],
            )
        except Exception as e:
            raise StorageError(
                f"Chroma find_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        return result.get("ids", [])

    async def find_by_template_id(self, collection_name: str, template_id: str) -> list[str]:
        collection = self._get_collection(collection_name)
        try:
            result = await self._run(
                collection.get,
                where={"template_id": {"$eq": template_id}},
                include=["metadatas"],
            )
        except Exception as e:
            raise StorageError(
                f"Chroma find_by_template_id failed on '{collection_name}': {e}"
            ) from e
        return result.get("ids", [])

    async def drop_collection(self, collection_name: str) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        chroma_name = _chroma_collection_name(collection_name)
        try:
            if self._is_async:
                await self._client.delete_collection(name=chroma_name)
            else:
                await asyncio.to_thread(self._client.delete_collection, name=chroma_name)
            self._collections.pop(collection_name, None)
        except Exception as e:
            raise StorageError(
                f"Chroma drop_collection failed on '{collection_name}': {e}"
            ) from e

    async def close(self) -> None:
        if self._is_async and self._client is not None:
            try:
                await self._client.aclose()
            except Exception:
                pass
        self._client = None
        self._collections.clear()
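
A minimal mode-selection sketch, assuming Settings accepts keyword overrides for the chroma_* fields read in connect() above:

from medha.backends.chroma import ChromaBackend
from medha.config import Settings

# "ephemeral" (the default branch), "persistent", or "http"
settings = Settings(chroma_mode="persistent", chroma_persist_path="./chroma_data")
backend = ChromaBackend(settings)
# inside an async context: await backend.connect()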

WeaviateBackend

Bases: VectorStorageBackend

Weaviate v4 vector backend. Supports local and cloud modes.

Source code in src/medha/backends/weaviate.py
class WeaviateBackend(VectorStorageBackend):
    """Weaviate v4 vector backend. Supports local and cloud modes."""

    def __init__(self, settings: Any = None) -> None:
        if not HAS_WEAVIATE:
            raise ConfigurationError(
                "weaviate backend requires 'weaviate-client>=4.6'. "
                "Install with: pip install medha-archai[weaviate]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._client: Any = None
        self._collections: dict[str, Any] = {}

    async def connect(self) -> None:
        mode = self._settings.weaviate_mode
        try:
            auth = None
            if self._settings.weaviate_api_key:
                auth = wvc.init.Auth.api_key(self._settings.weaviate_api_key.get_secret_value())

            if mode == "cloud":
                self._client = weaviate.use_async_with_weaviate_cloud(
                    cluster_url=self._settings.weaviate_cloud_url,
                    auth_credentials=auth,
                )
            else:
                self._client = weaviate.use_async_with_local(
                    host=self._settings.weaviate_host,
                    port=self._settings.weaviate_http_port,
                    grpc_port=self._settings.weaviate_grpc_port,
                    auth_credentials=auth,
                )
            await self._client.connect()
        except Exception as e:
            raise StorageInitializationError(f"Failed to connect to Weaviate ({mode}): {e}") from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        if collection_name in self._collections:
            return
        wv_name = _wv_collection_name(self._settings.weaviate_collection_prefix, collection_name)
        try:
            if await self._client.collections.exists(wv_name):
                # index_null_state is immutable — drop and recreate if it is not set.
                # Data loss is acceptable: this is a cache.
                col = self._client.collections.get(wv_name)
                try:
                    cfg = await col.config.get()
                    needs_recreate = not getattr(cfg.inverted_index_config, "index_null_state", False)
                except Exception:
                    needs_recreate = False
                if needs_recreate:
                    logger.warning(
                        "Collection '%s' was created without indexNullState=True "
                        "(required for TTL filters). Dropping and recreating — cached entries will be lost.",
                        wv_name,
                    )
                    await self._client.collections.delete(wv_name)
                else:
                    self._collections[collection_name] = col
                    return

            await self._client.collections.create(
                name=wv_name,
                properties=[
                    wvc.config.Property(name="original_question", data_type=wvc.config.DataType.TEXT),
                    wvc.config.Property(name="normalized_question", data_type=wvc.config.DataType.TEXT),
                    wvc.config.Property(name="generated_query", data_type=wvc.config.DataType.TEXT),
                    wvc.config.Property(name="query_hash", data_type=wvc.config.DataType.TEXT),
                    wvc.config.Property(name="response_summary", data_type=wvc.config.DataType.TEXT),
                    wvc.config.Property(name="template_id", data_type=wvc.config.DataType.TEXT),
                    wvc.config.Property(name="usage_count", data_type=wvc.config.DataType.INT),
                    wvc.config.Property(name="created_at", data_type=wvc.config.DataType.DATE),
                    wvc.config.Property(name="expires_at", data_type=wvc.config.DataType.DATE),
                ],
                inverted_index_config=wvc.config.Configure.inverted_index(
                    index_null_state=True,
                ),
                vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                vector_index_config=wvc.config.Configure.VectorIndex.hnsw(
                    distance_metric=wvc.config.VectorDistances.COSINE
                ),
            )
            self._collections[collection_name] = self._client.collections.get(wv_name)
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to initialize Weaviate collection '{wv_name}': {e}"
            ) from e

    def _get_collection(self, collection_name: str) -> Any:
        col = self._collections.get(collection_name)
        if col is None:
            raise StorageError(
                f"Collection '{collection_name}' not initialized. Call initialize() first."
            )
        return col

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        collection = self._get_collection(collection_name)
        try:
            result = await collection.query.near_vector(
                near_vector=vector,
                limit=limit,
                filters=_ttl_filter(),
                return_metadata=MetadataQuery(distance=True),
            )
        except Exception as e:
            raise StorageError(f"Weaviate search failed on '{collection_name}': {e}") from e

        out = []
        for obj in result.objects:
            score = 1.0 - (obj.metadata.distance or 0.0)
            if score >= score_threshold:
                out.append(_obj_to_result(obj, score))
        return out

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if not entries:
            return
        collection = self._get_collection(collection_name)
        objects = [
            DataObject(
                uuid=entry.id,
                properties=_entry_to_properties(entry),
                vector=entry.vector,
            )
            for entry in entries
        ]
        try:
            result = await collection.data.insert_many(objects)
            if result.has_errors:
                errors = [str(e) for e in result.errors.values()]
                raise StorageError(f"Weaviate upsert errors on '{collection_name}': {errors}")
        except StorageError:
            raise
        except Exception as e:
            raise StorageError(f"Weaviate upsert failed on '{collection_name}': {e}") from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        collection = self._get_collection(collection_name)
        after = UUID(offset) if offset else None
        try:
            result = await collection.query.fetch_objects(
                limit=limit,
                after=after,
                include_vector=with_vectors,
            )
        except Exception as e:
            raise StorageError(f"Weaviate scroll failed on '{collection_name}': {e}") from e

        objects = result.objects
        cache_results = [_obj_to_result(obj, 1.0) for obj in objects]
        next_offset = str(objects[-1].uuid) if len(objects) == limit else None
        return cache_results, next_offset

    async def count(self, collection_name: str) -> int:
        collection = self._get_collection(collection_name)
        try:
            result = await collection.aggregate.over_all(total_count=True)
            return result.total_count or 0
        except Exception as e:
            raise StorageError(f"Weaviate count failed on '{collection_name}': {e}") from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if not ids:
            return
        collection = self._get_collection(collection_name)
        try:
            if len(ids) <= 10:
                for id_ in ids:
                    await collection.data.delete_by_id(id_)
            else:
                await collection.data.delete_many(
                    where=Filter.by_id().contains_any(ids)
                )
        except Exception as e:
            raise StorageError(f"Weaviate delete failed on '{collection_name}': {e}") from e

    async def search_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> CacheResult | None:
        collection = self._get_collection(collection_name)
        try:
            result = await collection.query.fetch_objects(
                filters=Filter.by_property("query_hash").equal(query_hash),
                limit=1,
            )
        except Exception as e:
            raise StorageError(
                f"Weaviate search_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        if not result.objects:
            return None
        return _obj_to_result(result.objects[0], 1.0)

    async def update_usage_count(self, collection_name: str, id_: str) -> None:
        collection = self._get_collection(collection_name)
        try:
            obj = await collection.query.fetch_object_by_id(id_)
            if obj is None:
                return
            new_count = int(obj.properties.get("usage_count", 0)) + 1
            await collection.data.update(uuid=id_, properties={"usage_count": new_count})
        except Exception as e:
            raise StorageError(
                f"Weaviate update_usage_count failed on '{collection_name}': {e}"
            ) from e

    async def find_expired(self, collection_name: str) -> list[str]:
        collection = self._get_collection(collection_name)
        now = _now_utc()
        try:
            result = await collection.query.fetch_objects(
                filters=(
                    Filter.by_property("expires_at").is_none(False)
                    & Filter.by_property("expires_at").less_than(now)
                ),
            )
        except Exception as e:
            raise StorageError(f"Weaviate find_expired failed on '{collection_name}': {e}") from e
        return [str(obj.uuid) for obj in result.objects]

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        collection = self._get_collection(collection_name)
        try:
            result = await collection.query.fetch_objects(
                filters=Filter.by_property("normalized_question").equal(normalized_question),
                limit=1,
            )
        except Exception as e:
            raise StorageError(
                f"Weaviate search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e
        if not result.objects:
            return None
        return _obj_to_result(result.objects[0], 1.0)

    async def find_by_query_hash(self, collection_name: str, query_hash: str) -> list[str]:
        collection = self._get_collection(collection_name)
        try:
            result = await collection.query.fetch_objects(
                filters=Filter.by_property("query_hash").equal(query_hash),
            )
        except Exception as e:
            raise StorageError(
                f"Weaviate find_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        return [str(obj.uuid) for obj in result.objects]

    async def find_by_template_id(self, collection_name: str, template_id: str) -> list[str]:
        collection = self._get_collection(collection_name)
        try:
            result = await collection.query.fetch_objects(
                filters=Filter.by_property("template_id").equal(template_id),
            )
        except Exception as e:
            raise StorageError(
                f"Weaviate find_by_template_id failed on '{collection_name}': {e}"
            ) from e
        return [str(obj.uuid) for obj in result.objects]

    async def drop_collection(self, collection_name: str) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        wv_name = _wv_collection_name(self._settings.weaviate_collection_prefix, collection_name)
        try:
            await self._client.collections.delete(wv_name)
            self._collections.pop(collection_name, None)
        except Exception as e:
            raise StorageError(
                f"Weaviate drop_collection failed on '{collection_name}': {e}"
            ) from e

    async def close(self) -> None:
        if self._client is not None:
            try:
                await self._client.close()
            except Exception:
                pass
        self._client = None
        self._collections.clear()
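
The `_ttl_filter()` helper used by search() is not shown in this listing. Judging from the explicit filters in find_expired(), it plausibly keeps objects whose expires_at is unset or still in the future; a sketch under that assumption (not the actual source):

from datetime import datetime, timezone

from weaviate.classes.query import Filter

def _ttl_filter():
    now = datetime.now(timezone.utc)
    return (
        Filter.by_property("expires_at").is_none(True)
        | Filter.by_property("expires_at").greater_than(now)
    )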

RedisVectorBackend

Bases: VectorStorageBackend

Redis Stack (RediSearch) vector backend. Supports standalone and sentinel modes.

Source code in src/medha/backends/redis_vector.py
class RedisVectorBackend(VectorStorageBackend):
    """Redis Stack (RediSearch) vector backend. Supports standalone and sentinel modes."""

    def __init__(self, settings: Any = None) -> None:
        if not HAS_REDIS:
            raise ConfigurationError(
                "redis backend requires 'redis[hiredis]>=4.6'. "
                "Install with: pip install medha-archai[redis]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._client: Any = None

    async def connect(self) -> None:
        mode = self._settings.redis_mode
        s = self._settings
        try:
            if mode == "sentinel":
                hosts = []
                for h in s.redis_sentinel_hosts:
                    if ":" in h:
                        host, port_str = h.rsplit(":", 1)
                        hosts.append((host, int(port_str)))
                    else:
                        hosts.append((h, 26379))
                sentinel = Sentinel(hosts)
                self._client = sentinel.master_for(s.redis_sentinel_master)
            else:
                ssl_params: dict[str, Any] = {}
                if s.redis_ssl:
                    ssl_params = {
                        "ssl": True,
                        "ssl_certfile": s.redis_ssl_certfile,
                        "ssl_keyfile": s.redis_ssl_keyfile,
                        "ssl_ca_certs": s.redis_ssl_ca_certs,
                    }
                common = {
                    "socket_timeout": s.redis_socket_timeout,
                    "socket_connect_timeout": s.redis_socket_connect_timeout,
                    **ssl_params,
                }
                if s.redis_url:
                    self._client = aioredis.from_url(s.redis_url, **common)
                else:
                    kwargs: dict[str, Any] = {
                        "host": s.redis_host,
                        "port": s.redis_port,
                        "db": s.redis_db,
                        **common,
                    }
                    if s.redis_username:
                        kwargs["username"] = s.redis_username
                    if s.redis_password:
                        kwargs["password"] = s.redis_password.get_secret_value()
                    self._client = Redis(**kwargs)
            await self._client.ping()
        except Exception as e:
            raise StorageInitializationError(f"Failed to connect to Redis ({mode}): {e}") from e

    def _build_schema(self, dimension: int) -> list[Any]:
        s = self._settings
        algo = s.redis_index_algorithm
        if algo == "HNSW":
            vec_attrs: dict[str, Any] = {
                "TYPE": "FLOAT32",
                "DIM": dimension,
                "DISTANCE_METRIC": "COSINE",
                "M": s.redis_hnsw_m,
                "EF_CONSTRUCTION": s.redis_hnsw_ef_construction,
                "EF_RUNTIME": s.redis_hnsw_ef_runtime,
            }
        else:
            vec_attrs = {
                "TYPE": "FLOAT32",
                "DIM": dimension,
                "DISTANCE_METRIC": "COSINE",
            }
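        # Tag fields back the exact-match lookups (search_by_query_hash and
        # friends); the '|' separator keeps values containing commas intact.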
        return [
            TextField("original_question"),
            TextField("generated_query"),
            TagField("normalized_question", separator="|"),
            TagField("query_hash", separator="|"),
            TagField("template_id", separator="|"),
            NumericField("usage_count"),
            NumericField("created_at"),
            NumericField("expires_at"),
            VectorField("vector", algo, vec_attrs),
        ]

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        col_key = _key_prefix(self._settings.redis_key_prefix, collection_name)
        try:
            await self._client.ft(idx).info()
            return  # index already exists
        except Exception as e:
            # RediSearch wording varies by version ("Unknown Index name",
            # "no such index"), so match the known index-not-found messages
            # case-insensitively and re-raise anything else.
            err_lower = str(e).lower()
            if "unknown index name" not in err_lower and "no such index" not in err_lower:
                raise StorageInitializationError(
                    f"Redis initialize failed on '{collection_name}': {e}"
                ) from e
        try:
            schema = self._build_schema(dimension)
            definition = IndexDefinition(prefix=[f"{col_key}:"], index_type=IndexType.HASH)
            await self._client.ft(idx).create_index(schema, definition=definition)
            logger.debug("Created Redis index '%s' for collection '%s'", idx, collection_name)
        except Exception as e:
            err_lower = str(e).lower()
            if "index already exists" in err_lower:
                return
            raise StorageInitializationError(
                f"Redis initialize failed on '{collection_name}': {e}"
            ) from e

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        try:
            import numpy as np
            vec_bytes = np.array(vector, dtype=np.float32).tobytes()
            now_ts = int(time.time())
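            # expires_at == 0 means "never expires"; '(' marks an exclusive bound,
            # so this keeps entries with no expiry or an expiry strictly after now.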
            ttl_filter = f"(@expires_at:[0 0] | @expires_at:[({now_ts} +inf])"
            q = (
                Query(f"({ttl_filter})=>[KNN {limit} @vector $vec AS __score]")
                .sort_by("__score", asc=True)
                .return_fields(
                    "original_question", "normalized_question", "generated_query",
                    "query_hash", "response_summary", "template_id",
                    "usage_count", "created_at", "expires_at", "__score",
                )
                .paging(0, limit)
                .dialect(2)
            )
            result = await self._client.ft(idx).search(q, query_params={"vec": vec_bytes})
        except Exception as e:
            import redis as redis_lib
            if isinstance(e, redis_lib.exceptions.ResponseError):
                raise StorageError(
                    f"Redis search failed on '{collection_name}': {e}. "
                    "Did you call initialize()?"
                ) from e
            raise StorageError(f"Redis search failed on '{collection_name}': {e}") from e

        out = []
        for doc in result.docs:
            raw_score = getattr(doc, "__score", "1.0")
            try:
                dist = float(raw_score)
            except (ValueError, TypeError):
                dist = 1.0
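            # RediSearch reports cosine distance; convert to similarity and clamp.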
            score = max(0.0, min(1.0, 1.0 - dist))
            if score >= score_threshold:
                out.append(_doc_to_result(doc, score))
        return out

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if not entries:
            return
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        col_key = _key_prefix(self._settings.redis_key_prefix, collection_name)
        try:
            import numpy as np
            pipe = self._client.pipeline(transaction=False)
            for entry in entries:
                vec_bytes = np.array(entry.vector, dtype=np.float32).tobytes()
                expires_at = (
                    str(int(entry.expires_at.timestamp())) if entry.expires_at else "0"
                )
                created_at = (
                    str(int(entry.created_at.timestamp())) if entry.created_at else "0"
                )
                mapping: dict[str, Any] = {
                    "original_question": entry.original_question,
                    "normalized_question": entry.normalized_question,
                    "generated_query": entry.generated_query,
                    "query_hash": entry.query_hash,
                    "response_summary": entry.response_summary or "",
                    "template_id": entry.template_id or "",
                    "usage_count": entry.usage_count,
                    "created_at": created_at,
                    "expires_at": expires_at,
                    "vector": vec_bytes,
                }
                pipe.hset(f"{col_key}:{entry.id}", mapping=mapping)
            await pipe.execute()
        except Exception as e:
            raise StorageError(f"Redis upsert failed on '{collection_name}': {e}") from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        int_offset = int(offset) if offset else 0
        return_fields = [
            "original_question", "normalized_question", "generated_query",
            "query_hash", "response_summary", "template_id",
            "usage_count", "created_at", "expires_at",
        ]
        if with_vectors:
            return_fields.append("vector")
        try:
            q = (
                Query("*")
                .sort_by("created_at", asc=True)
                .return_fields(*return_fields)
                .paging(int_offset, limit)
            )
            result = await self._client.ft(idx).search(q)
        except Exception as e:
            raise StorageError(f"Redis scroll failed on '{collection_name}': {e}") from e

        items = result.docs
        cache_results = [_doc_to_result(doc, 1.0) for doc in items]
        next_offset = str(int_offset + len(items)) if len(items) == limit else None
        return cache_results, next_offset

    async def count(self, collection_name: str) -> int:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        try:
            info = await self._client.ft(idx).info()
            return int(info.get("num_docs", 0))
        except Exception as e:
            raise StorageError(f"Redis count failed on '{collection_name}': {e}") from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if not ids:
            return
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        col_key = _key_prefix(self._settings.redis_key_prefix, collection_name)
        try:
            pipe = self._client.pipeline(transaction=False)
            for id_ in ids:
                pipe.delete(f"{col_key}:{id_}")
            await pipe.execute()
        except Exception as e:
            raise StorageError(f"Redis delete failed on '{collection_name}': {e}") from e

    async def search_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> CacheResult | None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        try:
            escaped = _escape_tag(query_hash)
            q = (
                Query(f"@query_hash:{{{escaped}}}")
                .return_fields(
                    "original_question", "normalized_question", "generated_query",
                    "query_hash", "response_summary", "template_id",
                    "usage_count", "created_at", "expires_at",
                )
                .paging(0, 1)
                .dialect(2)
            )
            result = await self._client.ft(idx).search(q)
        except Exception as e:
            raise StorageError(
                f"Redis search_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        if not result.docs:
            return None
        return _doc_to_result(result.docs[0], 1.0)

    async def update_usage_count(self, collection_name: str, id_: str) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        col_key = _key_prefix(self._settings.redis_key_prefix, collection_name)
        key = f"{col_key}:{id_}"
        try:
            exists = await self._client.hexists(key, "original_question")
            if not exists:
                return
            await self._client.hincrby(key, "usage_count", 1)
        except Exception as e:
            raise StorageError(
                f"Redis update_usage_count failed on '{collection_name}': {e}"
            ) from e

    async def find_expired(self, collection_name: str) -> list[str]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        now_ts = int(time.time())
        try:
            q = (
                Query(f"@expires_at:[(0 ({now_ts}]")
                .return_fields("expires_at")
                .paging(0, 10000)
            )
            result = await self._client.ft(idx).search(q)
        except Exception as e:
            raise StorageError(f"Redis find_expired failed on '{collection_name}': {e}") from e
        return [doc.id.rsplit(":", 1)[-1] for doc in result.docs]

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        try:
            escaped = _escape_tag(normalized_question)
            q = (
                Query(f"@normalized_question:{{{escaped}}}")
                .return_fields(
                    "original_question", "normalized_question", "generated_query",
                    "query_hash", "response_summary", "template_id",
                    "usage_count", "created_at", "expires_at",
                )
                .paging(0, 1)
                .dialect(2)
            )
            result = await self._client.ft(idx).search(q)
        except Exception as e:
            raise StorageError(
                f"Redis search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e
        if not result.docs:
            return None
        return _doc_to_result(result.docs[0], 1.0)

    async def find_by_query_hash(self, collection_name: str, query_hash: str) -> list[str]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        try:
            escaped = _escape_tag(query_hash)
            q = (
                Query(f"@query_hash:{{{escaped}}}")
                .return_fields("query_hash")
                .paging(0, 10000)
                .dialect(2)
            )
            result = await self._client.ft(idx).search(q)
        except Exception as e:
            raise StorageError(
                f"Redis find_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        return [doc.id.rsplit(":", 1)[-1] for doc in result.docs]

    async def find_by_template_id(self, collection_name: str, template_id: str) -> list[str]:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        try:
            escaped = _escape_tag(template_id)
            q = (
                Query(f"@template_id:{{{escaped}}}")
                .return_fields("template_id")
                .paging(0, 10000)
                .dialect(2)
            )
            result = await self._client.ft(idx).search(q)
        except Exception as e:
            raise StorageError(
                f"Redis find_by_template_id failed on '{collection_name}': {e}"
            ) from e
        return [doc.id.rsplit(":", 1)[-1] for doc in result.docs]

    async def drop_collection(self, collection_name: str) -> None:
        if self._client is None:
            raise StorageError("Not connected. Call connect() first.")
        idx = _index_name(self._settings.redis_key_prefix, collection_name)
        col_key = _key_prefix(self._settings.redis_key_prefix, collection_name)
        try:
            await self._client.ft(idx).dropindex(delete_documents=True)
        except Exception as e:
            if "unknown index name" not in str(e).lower() and "no such index" not in str(e).lower():
                logger.warning("Redis dropindex warning on '%s': %s", collection_name, e)
            # fallback: scan and delete orphan hashes
            try:
                async for key in self._client.scan_iter(match=f"{col_key}:*", count=100):
                    await self._client.delete(key)
            except Exception as scan_e:
                raise StorageError(
                    f"Redis drop_collection fallback scan failed on '{collection_name}': {scan_e}"
                ) from scan_e

    async def close(self) -> None:
        if self._client is not None:
            try:
                await self._client.aclose()
            except Exception:
                pass
        self._client = None
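
A hedged round-trip sketch for contributors (the CacheEntry import path, its constructor fields, and the result attribute names are assumptions inferred from the upsert() mapping above; the 4-dimension vector stands in for a real embedding):

import asyncio
from datetime import datetime, timezone

from medha.models import CacheEntry  # import path assumed
from medha.backends.redis_vector import RedisVectorBackend

async def demo() -> None:
    backend = RedisVectorBackend()
    await backend.connect()
    await backend.initialize("demo", dimension=4)
    entry = CacheEntry(
        id="entry-1",
        vector=[0.1, 0.2, 0.3, 0.4],
        original_question="What were total sales last month?",
        normalized_question="total sales last month",
        generated_query="SELECT SUM(amount) FROM sales",
        query_hash="hash-1",
        response_summary=None,
        template_id=None,
        usage_count=0,
        created_at=datetime.now(timezone.utc),
        expires_at=None,
    )
    await backend.upsert("demo", [entry])
    hits = await backend.search("demo", [0.1, 0.2, 0.3, 0.4], limit=3)
    print([(hit.original_question, hit.score) for hit in hits])
    await backend.close()

asyncio.run(demo())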

AzureSearchBackend

Bases: VectorStorageBackend

Azure AI Search backend. Requires azure-search-documents>=11.4,<12.
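
connect() prefers key-based auth and falls back to Azure AD. A minimal configuration sketch (the Settings field names come from the implementation below; values are illustrative):

from medha.config import Settings
from medha.backends.azure_search import AzureSearchBackend

# Key-based auth: the key is wrapped in an AzureKeyCredential.
settings = Settings(
    azure_search_endpoint="https://<service>.search.windows.net",
    azure_search_api_key="<admin-or-query-key>",
)

# AAD auth: leave azure_search_api_key unset; DefaultAzureCredential is used
# instead and requires 'pip install azure-identity'.
backend = AzureSearchBackend(settings)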

Source code in src/medha/backends/azure_search.py
class AzureSearchBackend(VectorStorageBackend):
    """Azure AI Search backend. Requires azure-search-documents>=11.4,<12."""

    def __init__(self, settings: Any = None) -> None:
        if not HAS_AZURE_SEARCH:
            raise ConfigurationError(
                "azure-search backend requires 'azure-search-documents>=11.4,<12'. "
                "Install with: pip install medha-archai[azure-search]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._search_clients: dict[str, AsyncSearchClient] = {}
        self._index_client: AsyncSearchIndexClient | None = None
        self._credential: Any = None

    async def connect(self) -> None:
        endpoint = self._settings.azure_search_endpoint
        api_version = self._settings.azure_search_api_version

        if self._settings.azure_search_api_key is not None:
            self._credential = AzureKeyCredential(
                self._settings.azure_search_api_key.get_secret_value()
            )
        else:
            try:
                from azure.identity.aio import DefaultAzureCredential
                self._credential = DefaultAzureCredential()
            except ImportError as e:
                raise ConfigurationError(
                    "AAD authentication requires 'azure-identity'. "
                    "Install it separately: pip install azure-identity"
                ) from e

        try:
            self._index_client = AsyncSearchIndexClient(
                endpoint, self._credential, api_version=api_version
            )
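            # One listing round-trip forces endpoint/credential errors to surface
            # at connect time rather than on first use.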
            [idx async for idx in self._index_client.list_index_names()]
        except ServiceRequestError as e:
            raise StorageInitializationError(
                f"Failed to connect to Azure AI Search (network error): {e}"
            ) from e
        except HttpResponseError as e:
            if e.status_code in (401, 403):
                raise StorageInitializationError(
                    f"Azure AI Search authentication failed (HTTP {e.status_code}): {e}"
                ) from e
            raise StorageInitializationError(
                f"Failed to connect to Azure AI Search: {e}"
            ) from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._index_client is None:
            raise StorageError("Not connected. Call connect() first.")

        index_name = _az_index_name(collection_name, self._settings.azure_search_index_name)
        api_version = self._settings.azure_search_api_version
        endpoint = self._settings.azure_search_endpoint

        try:
            await self._index_client.get_index(index_name)
            # Index already exists — create client and return
            if collection_name not in self._search_clients:
                self._search_clients[collection_name] = AsyncSearchClient(
                    endpoint, index_name, self._credential, api_version=api_version
                )
            return
        except ResourceNotFoundError:
            pass
        except HttpResponseError as e:
            raise StorageInitializationError(
                f"Failed to check Azure Search index '{index_name}': {e}"
            ) from e

        try:
            vector_search = VectorSearch(
                algorithms=[HnswAlgorithmConfiguration(name="medha-hnsw")],
                profiles=[VectorSearchProfile(
                    name="medha-hnsw-profile",
                    algorithm_configuration_name="medha-hnsw",
                )],
            )
            fields = [
                SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
                SearchableField(name="original_question", analyzer_name="standard.lucene"),
                SimpleField(name="normalized_question", type=SearchFieldDataType.String, filterable=True),
                SimpleField(name="generated_query", type=SearchFieldDataType.String),
                SimpleField(name="query_hash", type=SearchFieldDataType.String, filterable=True),
                SimpleField(name="response_summary", type=SearchFieldDataType.String, nullable=True),
                SimpleField(name="template_id", type=SearchFieldDataType.String, filterable=True, nullable=True),
                SimpleField(name="usage_count", type=SearchFieldDataType.Int32, filterable=True),
                SimpleField(name="created_at", type=SearchFieldDataType.DateTimeOffset, filterable=True, sortable=True),
                SimpleField(name="expires_at", type=SearchFieldDataType.DateTimeOffset, filterable=True, nullable=True),
                SearchField(
                    name="vector",
                    type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                    searchable=True,
                    vector_search_dimensions=dimension,
                    vector_search_profile_name="medha-hnsw-profile",
                ),
            ]
            index = SearchIndex(name=index_name, fields=fields, vector_search=vector_search)
            await self._index_client.create_index(index)
            logger.info("Created Azure Search index '%s'", index_name)
        except HttpResponseError as e:
            raise StorageInitializationError(
                f"Failed to create Azure Search index '{index_name}': {e}"
            ) from e

        self._search_clients[collection_name] = AsyncSearchClient(
            endpoint, index_name, self._credential, api_version=api_version
        )

    def _get_client(self, collection_name: str) -> AsyncSearchClient:
        if collection_name in self._search_clients:
            return self._search_clients[collection_name]
        if self._index_client is not None:
            index_name = _az_index_name(collection_name, self._settings.azure_search_index_name)
            client = AsyncSearchClient(
                self._settings.azure_search_endpoint,
                index_name,
                self._credential,
                api_version=self._settings.azure_search_api_version,
            )
            self._search_clients[collection_name] = client
            return client
        raise StorageError("Not connected.")

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        client = self._get_client(collection_name)
        now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        ttl_filter = f"(expires_at eq null) or (expires_at gt {now_iso})"
        vq = VectorizedQuery(
            vector=vector,
            k_nearest_neighbors=limit + self._settings.azure_search_top_k_candidates,
            fields="vector",
        )
        try:
            items: list[CacheResult] = []
            async for result in await client.search(
                search_text=None,
                vector_queries=[vq],
                filter=ttl_filter,
                top=limit,
                select=_SCALAR_FIELDS,
            ):
                score = result.get("@search.score", 0.0)
                if score < score_threshold:
                    continue
                items.append(_doc_to_result(dict(result), score))
            return items
        except HttpResponseError as e:
            raise StorageError(f"Azure Search search failed on '{collection_name}': {e}") from e

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if not entries:
            return
        client = self._get_client(collection_name)
        docs = [_entry_to_doc(e) for e in entries]
        try:
            await client.merge_or_upload_documents(docs)
        except HttpResponseError as e:
            raise StorageError(f"Azure Search upsert failed on '{collection_name}': {e}") from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        client = self._get_client(collection_name)
        skip = int(offset) if offset is not None else 0
        select = None if with_vectors else _SCALAR_FIELDS
        try:
            items: list[CacheResult] = []
            async for result in await client.search(
                search_text="*",
                skip=skip,
                top=limit,
                order_by=["created_at asc", "id asc"],
                select=select,
            ):
                items.append(_doc_to_result(dict(result), 1.0))
            next_offset = str(skip + len(items)) if len(items) == limit else None
            return items, next_offset
        except HttpResponseError as e:
            raise StorageError(f"Azure Search scroll failed on '{collection_name}': {e}") from e

    async def count(self, collection_name: str) -> int:
        client = self._get_client(collection_name)
        try:
            return await client.get_document_count()
        except HttpResponseError as e:
            raise StorageError(f"Azure Search count failed on '{collection_name}': {e}") from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if not ids:
            return
        client = self._get_client(collection_name)
        try:
            await client.delete_documents([{"id": id_} for id_ in ids])
        except HttpResponseError as e:
            raise StorageError(f"Azure Search delete failed on '{collection_name}': {e}") from e

    async def search_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> CacheResult | None:
        client = self._get_client(collection_name)
        filter_expr = f"query_hash eq '{_esc(query_hash)}'"
        try:
            async for result in await client.search(
                search_text=None,
                filter=filter_expr,
                top=1,
                select=_SCALAR_FIELDS,
            ):
                return _doc_to_result(dict(result), 1.0)
            return None
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search search_by_query_hash failed on '{collection_name}': {e}"
            ) from e

    async def update_usage_count(self, collection_name: str, point_id: str) -> None:
        """Increment usage_count for a document.

        Note: This is not atomic. A race condition exists if two callers update
        the same document concurrently — both may read the same value and write
        the same incremented result, causing one increment to be lost.
        """
        client = self._get_client(collection_name)
        try:
            doc = await client.get_document(key=point_id)
        except ResourceNotFoundError:
            logger.warning(
                "update_usage_count: id '%s' not found in collection '%s'",
                point_id,
                collection_name,
            )
            return
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search update_usage_count failed on '{collection_name}': {e}"
            ) from e
        try:
            await client.merge_documents([{"id": point_id, "usage_count": doc["usage_count"] + 1}])
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search update_usage_count merge failed on '{collection_name}': {e}"
            ) from e

    async def find_expired(self, collection_name: str) -> list[str]:
        client = self._get_client(collection_name)
        now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        filter_expr = f"expires_at ne null and expires_at lt {now_iso}"
        try:
            ids: list[str] = []
            async for result in await client.search(
                search_text=None,
                filter=filter_expr,
                select=["id"],
                top=10000,
            ):
                ids.append(result["id"])
            return ids
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search find_expired failed on '{collection_name}': {e}"
            ) from e

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        client = self._get_client(collection_name)
        filter_expr = f"normalized_question eq '{_esc(normalized_question)}'"
        try:
            async for result in await client.search(
                search_text=None,
                filter=filter_expr,
                top=1,
                select=_SCALAR_FIELDS,
            ):
                return _doc_to_result(dict(result), 1.0)
            return None
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e

    async def find_by_query_hash(
        self, collection_name: str, query_hash: str
    ) -> list[str]:
        client = self._get_client(collection_name)
        filter_expr = f"query_hash eq '{_esc(query_hash)}'"
        try:
            ids: list[str] = []
            async for result in await client.search(
                search_text=None,
                filter=filter_expr,
                select=["id"],
                top=10000,
            ):
                ids.append(result["id"])
            return ids
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search find_by_query_hash failed on '{collection_name}': {e}"
            ) from e

    async def find_by_template_id(
        self, collection_name: str, template_id: str
    ) -> list[str]:
        client = self._get_client(collection_name)
        filter_expr = f"template_id eq '{_esc(template_id)}'"
        try:
            ids: list[str] = []
            async for result in await client.search(
                search_text=None,
                filter=filter_expr,
                select=["id"],
                top=10000,
            ):
                ids.append(result["id"])
            return ids
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search find_by_template_id failed on '{collection_name}': {e}"
            ) from e

    async def drop_collection(self, collection_name: str) -> None:
        if self._index_client is None:
            raise StorageError("Not connected. Call connect() first.")
        index_name = _az_index_name(collection_name, self._settings.azure_search_index_name)
        try:
            await self._index_client.delete_index(index_name)
            logger.info("Dropped Azure Search index '%s'", index_name)
        except ResourceNotFoundError:
            logger.warning("drop_collection: index '%s' not found", index_name)
        except HttpResponseError as e:
            raise StorageError(
                f"Azure Search drop_collection failed on '{collection_name}': {e}"
            ) from e
        self._search_clients.pop(collection_name, None)

    async def close(self) -> None:
        for client in self._search_clients.values():
            await client.close()
        self._search_clients.clear()
        if self._index_client is not None:
            await self._index_client.close()
            self._index_client = None

update_usage_count(collection_name, point_id) async

Increment usage_count for a document.

Note: This is not atomic. A race condition exists if two callers update the same document concurrently — both may read the same value and write the same incremented result, causing one increment to be lost.

Source code in src/medha/backends/azure_search.py
async def update_usage_count(self, collection_name: str, point_id: str) -> None:
    """Increment usage_count for a document.

    Note: This is not atomic. A race condition exists if two callers update
    the same document concurrently — both may read the same value and write
    the same incremented result, causing one increment to be lost.
    """
    client = self._get_client(collection_name)
    try:
        doc = await client.get_document(key=point_id)
    except ResourceNotFoundError:
        logger.warning(
            "update_usage_count: id '%s' not found in collection '%s'",
            point_id,
            collection_name,
        )
        return
    except HttpResponseError as e:
        raise StorageError(
            f"Azure Search update_usage_count failed on '{collection_name}': {e}"
        ) from e
    try:
        await client.merge_documents([{"id": point_id, "usage_count": doc["usage_count"] + 1}])
    except HttpResponseError as e:
        raise StorageError(
            f"Azure Search update_usage_count merge failed on '{collection_name}': {e}"
        ) from e
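
If lost increments matter in your deployment, one caller-side mitigation is to serialize updates per document. A minimal sketch (the lock registry is illustrative and not part of Medha):

import asyncio
from collections import defaultdict

_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

async def safe_increment(backend, collection_name: str, point_id: str) -> None:
    # Serializes the read-modify-write per document within this process only;
    # writers in other processes can still race.
    async with _locks[point_id]:
        await backend.update_usage_count(collection_name, point_id)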

LanceDBBackend

Bases: VectorStorageBackend

LanceDB vector backend. Supports local, S3, GCS, and Azure storage.

Uses the native async API (lancedb.connect_async) for non-blocking I/O. Local mode requires no external services; cloud URIs (s3://, gs://, az://) require the appropriate credentials to be set in the environment.
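
A minimal configuration sketch (field names are taken from the implementation below; values are illustrative):

from medha.config import Settings
from medha.backends.lancedb import LanceDBBackend

# Local, embedded mode: data lives in a directory on disk.
settings = Settings(lancedb_uri="./medha_lancedb")

# Object-store mode: credentials are read from the environment
# (e.g. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY for s3:// URIs).
# settings = Settings(lancedb_uri="s3://my-bucket/medha")

backend = LanceDBBackend(settings)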

Source code in src/medha/backends/lancedb.py
class LanceDBBackend(VectorStorageBackend):
    """LanceDB vector backend. Supports local, S3, GCS, and Azure storage.

    Uses the native async API (lancedb.connect_async) for non-blocking I/O.
    Local mode requires no external services; cloud URIs (s3://, gs://, az://)
    require the appropriate credentials to be set in the environment.
    """

    def __init__(self, settings: Any = None) -> None:
        if not HAS_LANCEDB:
            raise ConfigurationError(
                "lancedb backend requires 'lancedb>=0.6'. "
                "Install with: pip install medha-archai[lancedb]"
            )
        from medha.config import Settings
        self._settings = settings or Settings()
        self._db: Any = None
        self._tables: dict[str, Any] = {}
        self._dimensions: dict[str, int] = {}

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def connect(self) -> None:
        uri = self._settings.lancedb_uri
        try:
            self._db = await lancedb.connect_async(uri)
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to connect to LanceDB at '{uri}': {e}"
            ) from e

    async def initialize(self, collection_name: str, dimension: int, **kwargs: Any) -> None:
        if self._db is None:
            raise StorageError("Not connected. Call connect() first.")
        if collection_name in self._tables:
            return
        table_name = self._table_name(collection_name)
        schema = _build_schema(dimension)
        self._dimensions[collection_name] = dimension
        try:
            table = await self._db.create_table(table_name, schema=schema, exist_ok=True)
            self._tables[collection_name] = table
        except Exception as e:
            raise StorageInitializationError(
                f"Failed to initialize LanceDB table '{table_name}': {e}"
            ) from e

    async def close(self) -> None:
        self._tables.clear()
        self._dimensions.clear()
        self._db = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _table_name(self, collection_name: str) -> str:
        import re
        prefix = self._settings.lancedb_table_prefix
        safe = re.sub(r"[^a-zA-Z0-9_]", "_", collection_name)
        return f"{prefix}_{safe}" if prefix else safe

    def _get_table(self, collection_name: str) -> Any:
        tbl = self._tables.get(collection_name)
        if tbl is None:
            raise StorageError(
                f"Collection '{collection_name}' not initialized. Call initialize() first."
            )
        return tbl

    # ------------------------------------------------------------------
    # Core operations
    # ------------------------------------------------------------------

    async def search(
        self,
        collection_name: str,
        vector: list[float],
        limit: int = 5,
        score_threshold: float = 0.0,
    ) -> list[CacheResult]:
        table = self._get_table(collection_name)
        now_iso = _now_iso()
        where = f"expires_at = '' OR expires_at > '{now_iso}'"
        metric: str = self._settings.lancedb_metric
        try:
            rows: list[dict[str, Any]] = await (
                table.vector_search(vector)
                .distance_type(metric)
                .where(where)
                .limit(limit)
                .to_list()
            )
        except Exception as e:
            raise StorageError(f"LanceDB search failed on '{collection_name}': {e}") from e

        out: list[CacheResult] = []
        for row in rows:
            score = _distance_to_score(float(row.get("_distance", 0.0)), metric)
            score = max(0.0, min(1.0, score))
            if score >= score_threshold:
                out.append(_row_to_result(row, score))
        return out

    async def upsert(self, collection_name: str, entries: list[CacheEntry]) -> None:
        if not entries:
            return
        table = self._get_table(collection_name)
        rows = [_entry_to_row(e) for e in entries]
        try:
            await (
                table.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
                .execute(rows)
            )
        except Exception as e:
            raise StorageError(f"LanceDB upsert failed on '{collection_name}': {e}") from e

    async def scroll(
        self,
        collection_name: str,
        limit: int = 100,
        offset: str | None = None,
        with_vectors: bool = False,
    ) -> tuple[list[CacheResult], str | None]:
        table = self._get_table(collection_name)
        int_offset = int(offset) if offset else 0
        columns = None if with_vectors else [
            "id", "original_question", "normalized_question", "generated_query",
            "query_hash", "response_summary", "template_id", "usage_count",
            "created_at", "expires_at",
        ]
        try:
            q = table.query().limit(limit).offset(int_offset)
            if columns is not None:
                q = q.select(columns)
            rows: list[dict[str, Any]] = await q.to_list()
        except Exception as e:
            raise StorageError(f"LanceDB scroll failed on '{collection_name}': {e}") from e

        next_offset = str(int_offset + len(rows)) if len(rows) == limit else None
        return [_row_to_result(row, 1.0) for row in rows], next_offset

    async def count(self, collection_name: str) -> int:
        table = self._get_table(collection_name)
        try:
            return await table.count_rows()
        except Exception as e:
            raise StorageError(f"LanceDB count failed on '{collection_name}': {e}") from e

    async def delete(self, collection_name: str, ids: list[str]) -> None:
        if not ids:
            return
        table = self._get_table(collection_name)
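        # chr(39) is a single quote; doubling it escapes quotes inside the IN list.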
        safe_ids = ", ".join(f"'{id_.replace(chr(39), chr(39) * 2)}'" for id_ in ids)
        try:
            await table.delete(f"id IN ({safe_ids})")
        except Exception as e:
            raise StorageError(f"LanceDB delete failed on '{collection_name}': {e}") from e

    async def find_expired(self, collection_name: str) -> list[str]:
        table = self._get_table(collection_name)
        now_iso = _now_iso()
        try:
            rows: list[dict[str, Any]] = await (
                table.query()
                .where(f"expires_at != '' AND expires_at < '{now_iso}'")
                .select(["id"])
                .to_list()
            )
        except Exception as e:
            raise StorageError(f"LanceDB find_expired failed on '{collection_name}': {e}") from e
        return [row["id"] for row in rows]

    async def search_by_normalized_question(
        self, collection_name: str, normalized_question: str
    ) -> CacheResult | None:
        table = self._get_table(collection_name)
        safe_q = normalized_question.replace("'", "''")
        try:
            rows: list[dict[str, Any]] = await (
                table.query()
                .where(f"normalized_question = '{safe_q}'")
                .limit(1)
                .to_list()
            )
        except Exception as e:
            raise StorageError(
                f"LanceDB search_by_normalized_question failed on '{collection_name}': {e}"
            ) from e
        return _row_to_result(rows[0], 1.0) if rows else None

    async def find_by_query_hash(self, collection_name: str, query_hash: str) -> list[str]:
        table = self._get_table(collection_name)
        safe_hash = query_hash.replace("'", "''")
        try:
            rows: list[dict[str, Any]] = await (
                table.query()
                .where(f"query_hash = '{safe_hash}'")
                .select(["id"])
                .to_list()
            )
        except Exception as e:
            raise StorageError(
                f"LanceDB find_by_query_hash failed on '{collection_name}': {e}"
            ) from e
        return [row["id"] for row in rows]

    async def find_by_template_id(self, collection_name: str, template_id: str) -> list[str]:
        table = self._get_table(collection_name)
        safe_tid = template_id.replace("'", "''")
        try:
            rows: list[dict[str, Any]] = await (
                table.query()
                .where(f"template_id = '{safe_tid}'")
                .select(["id"])
                .to_list()
            )
        except Exception as e:
            raise StorageError(
                f"LanceDB find_by_template_id failed on '{collection_name}': {e}"
            ) from e
        return [row["id"] for row in rows]

    async def drop_collection(self, collection_name: str) -> None:
        if self._db is None:
            raise StorageError("Not connected. Call connect() first.")
        table_name = self._table_name(collection_name)
        try:
            await self._db.drop_table(table_name)
            self._tables.pop(collection_name, None)
            self._dimensions.pop(collection_name, None)
        except Exception as e:
            raise StorageError(
                f"LanceDB drop_collection failed on '{collection_name}': {e}"
            ) from e
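
The scroll() implementations above share a (results, next_offset) cursor contract, so pagination looks the same against any backend. A hedged iteration sketch (backend and collection setup elided; see the configuration examples in the sections above):

from typing import AsyncIterator

async def iter_entries(backend, collection_name: str) -> AsyncIterator:
    # Page through the collection until the backend signals exhaustion
    # by returning next_offset=None.
    offset: str | None = None
    while True:
        results, offset = await backend.scroll(collection_name, limit=100, offset=offset)
        for result in results:
            yield result
        if offset is None:
            break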