数据场景检索

🔍 模块目标:掌握多模态数据场景检索技术,实现智能场景匹配和内容发现

🌟 场景检索概述

数据场景检索是指根据查询条件从大规模数据集中快速准确地找到相关场景或内容的技术。它结合了计算机视觉、自然语言处理、多媒体检索等多个领域的技术,在智能搜索、内容推荐、场景理解等应用中发挥重要作用。

🏗️ 检索系统架构

系统组件

import json
import time
from datetime import datetime

import cv2
import faiss
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

class MultiModalRetrievalSystem:
    """Multi-modal retrieval over flat inner-product FAISS indexes.

    Text is embedded with sentence-transformers; images with a ResNet-50
    backbone (classifier head removed). Per-id feature dicts preserve
    insertion order, so FAISS row i corresponds to the i-th added data_id.
    """

    # Embedding dimensionality per index.
    TEXT_DIM = 384    # all-MiniLM-L6-v2 output size
    IMAGE_DIM = 2048  # ResNet-50 pooled feature size
    HYBRID_DIM = 512

    def __init__(self):
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.image_features = {}
        self.text_features = {}
        self.metadata = {}
        self.indexes = {}
        # Lazily-created image backbone/transform so the ResNet is loaded
        # once instead of on every extract_image_features() call.
        self._image_backbone = None
        self._image_transform = None

    def initialize_indexes(self):
        """Create one flat inner-product FAISS index per modality."""
        self.indexes['text'] = faiss.IndexFlatIP(self.TEXT_DIM)
        self.indexes['image'] = faiss.IndexFlatIP(self.IMAGE_DIM)
        self.indexes['hybrid'] = faiss.IndexFlatIP(self.HYBRID_DIM)

    def add_data(self, data_id, text_content=None, image_path=None, metadata=None):
        """Register one item; any combination of text/image/metadata may be given.

        Returns:
            dict with the features that were actually computed ('text'/'image').
        """
        features = {}

        if text_content:
            text_embedding = self.text_encoder.encode([text_content])
            features['text'] = text_embedding[0]
            self.text_features[data_id] = text_embedding[0]

        if image_path:
            image_features = self.extract_image_features(image_path)
            features['image'] = image_features
            self.image_features[data_id] = image_features

        if metadata:
            self.metadata[data_id] = metadata

        return features

    def _ensure_image_backbone(self):
        """Load the ResNet-50 feature extractor and its preprocessing once."""
        if self._image_backbone is not None:
            return
        import torchvision.models as models
        import torchvision.transforms as transforms

        model = models.resnet50(pretrained=True)
        model.eval()
        # Drop the final FC classifier; keep the 2048-d pooled features.
        self._image_backbone = torch.nn.Sequential(*list(model.children())[:-1])
        self._image_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def extract_image_features(self, image_path):
        """Return the 2048-d ResNet-50 feature vector for one image file."""
        from PIL import Image

        self._ensure_image_backbone()
        image = Image.open(image_path).convert('RGB')
        image_tensor = self._image_transform(image).unsqueeze(0)

        with torch.no_grad():
            features = self._image_backbone(image_tensor)
        return features.squeeze().numpy()

    def build_indexes(self):
        """Bulk-add all stored feature vectors into their FAISS indexes."""
        if self.text_features:
            text_matrix = np.vstack(list(self.text_features.values()))
            self.indexes['text'].add(text_matrix.astype('float32'))

        if self.image_features:
            image_matrix = np.vstack(list(self.image_features.values()))
            self.indexes['image'].add(image_matrix.astype('float32'))

    def _collect_results(self, scores, indices, data_ids):
        """Convert one FAISS search result row into result dicts.

        FAISS pads with index -1 when fewer than k vectors exist; the original
        `idx < len(data_ids)` check accepted -1 (negative indexing silently
        returned the last item), so 0 <= idx is required explicitly.
        """
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if 0 <= idx < len(data_ids):
                data_id = data_ids[idx]
                results.append({
                    'data_id': data_id,
                    'score': float(score),
                    'metadata': self.metadata.get(data_id, {}),
                })
        return results

    def search_text(self, query, k=10):
        """Top-k text search; returns [{'data_id', 'score', 'metadata'}, ...]."""
        query_embedding = self.text_encoder.encode([query])
        query_vector = query_embedding[0].astype('float32').reshape(1, -1)

        scores, indices = self.indexes['text'].search(query_vector, k)
        return self._collect_results(scores, indices, list(self.text_features.keys()))

    def search_image(self, image_path, k=10):
        """Top-k image similarity search for a query image file."""
        query_features = self.extract_image_features(image_path)
        query_vector = query_features.astype('float32').reshape(1, -1)

        scores, indices = self.indexes['image'].search(query_vector, k)
        return self._collect_results(scores, indices, list(self.image_features.keys()))

    def hybrid_search(self, text_query=None, image_path=None, text_weight=0.5, k=10):
        """Weighted late fusion of text and image search scores.

        NOTE(review): text and image scores come from different feature spaces
        and are combined unnormalized — confirm the score scales are comparable.
        """
        combined_scores = {}

        if text_query:
            for result in self.search_text(text_query, k * 2):
                combined_scores[result['data_id']] = result['score'] * text_weight

        if image_path:
            for result in self.search_image(image_path, k * 2):
                data_id = result['data_id']
                contribution = result['score'] * (1 - text_weight)
                combined_scores[data_id] = combined_scores.get(data_id, 0.0) + contribution

        ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)
        return [
            {'data_id': data_id, 'score': score, 'metadata': self.metadata.get(data_id, {})}
            for data_id, score in ranked[:k]
        ]

场景理解与索引

class SceneUnderstandingModule:
    """Aggregates per-image analyses (classification, detection, layout,
    attributes, captioning) and flattens them into index-ready tags."""

    def __init__(self):
        # Model slots — to be wired up with real models by the caller.
        self.scene_classifier = None
        self.object_detector = None
        self.caption_generator = None

    def analyze_scene(self, image_path):
        """Run every analysis step on one image and collect results by name."""
        steps = {
            'scene_category': self.classify_scene,
            'objects': self.detect_objects,
            'spatial_layout': self.analyze_spatial_layout,
            'visual_attributes': self.extract_visual_attributes,
            'scene_caption': self.generate_caption,
        }
        return {name: step(image_path) for name, step in steps.items()}

    def classify_scene(self, image_path):
        """Classify the scene (stub — a pretrained classifier would go here)."""
        # Reference taxonomy for the classifier's label space.
        categories = {
            'indoor': ['bedroom', 'kitchen', 'living_room', 'office', 'bathroom'],
            'outdoor': ['street', 'park', 'beach', 'mountain', 'urban'],
            'nature': ['forest', 'lake', 'desert', 'field', 'sky']
        }

        # Example output in lieu of real inference.
        return {
            'main_category': 'outdoor',
            'sub_category': 'street',
            'confidence': 0.85
        }

    def detect_objects(self, image_path):
        """Detect objects in the image (stub for a YOLO-style detector)."""
        return [
            {'class': 'car', 'confidence': 0.95, 'bbox': [100, 150, 300, 400]},
            {'class': 'person', 'confidence': 0.88, 'bbox': [50, 100, 150, 350]},
            {'class': 'tree', 'confidence': 0.75, 'bbox': [400, 50, 600, 300]}
        ]

    def analyze_spatial_layout(self, image_path):
        """Describe the spatial arrangement of the scene (stub)."""
        return {
            'dominant_regions': ['sky', 'ground', 'buildings'],
            'perspective': 'street_level',
            'depth_information': {
                'foreground': ['person', 'car'],
                'background': ['buildings', 'sky']
            }
        }

    def extract_visual_attributes(self, image_path):
        """Extract global visual attributes of the scene (stub)."""
        return {
            'color_palette': ['blue', 'gray', 'green'],
            'lighting': 'daylight',
            'weather': 'clear',
            'season': 'summer',
            'time_of_day': 'afternoon'
        }

    def generate_caption(self, image_path):
        """Produce a natural-language scene description (stub)."""
        return "A busy street scene with cars and pedestrians during daytime"

    def create_scene_index(self, scene_analysis):
        """Flatten an analyze_scene() result into tag lists for indexing."""
        category = scene_analysis['scene_category']
        attrs = scene_analysis['visual_attributes']
        return {
            'scene_tags': [category['main_category'], category['sub_category']],
            'object_tags': [det['class'] for det in scene_analysis['objects']],
            'attribute_tags': [attrs['lighting'], attrs['weather'], attrs['time_of_day']],
            'text_description': scene_analysis['scene_caption'],
        }

🔍 高级检索技术

语义检索

from transformers import CLIPModel, CLIPProcessor
import torch

class SemanticRetrievalEngine:
    """CLIP-based joint text/image embedding space for cross-modal retrieval."""

    def __init__(self):
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model.to(self.device)

    def encode_text(self, text_list):
        """Encode strings into L2-normalized CLIP text embeddings (numpy array)."""
        inputs = self.clip_processor(text=text_list, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            text_features = self.clip_model.get_text_features(**inputs)
            # Normalize so that inner product equals cosine similarity.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy()

    def encode_image(self, image_list):
        """Encode images into L2-normalized CLIP image embeddings (numpy array)."""
        inputs = self.clip_processor(images=image_list, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            image_features = self.clip_model.get_image_features(**inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features.cpu().numpy()

    def compute_similarity(self, text_features, image_features):
        """Similarity matrix (texts x images) for normalized feature arrays.

        Both arguments are numpy arrays, so a plain matrix product suffices —
        the original copied them into torch tensors and straight back out.
        """
        return np.matmul(text_features, np.asarray(image_features).T)

    def cross_modal_search(self, query_text, image_database, top_k=10):
        """Rank `image_database` against a text query; return the top_k matches.

        NOTE(review): the whole image database is re-encoded on every call —
        precompute and cache the image embeddings for real workloads.
        """
        text_features = self.encode_text([query_text])
        image_features = self.encode_image(image_database)
        similarities = self.compute_similarity(text_features, image_features)

        top_indices = np.argsort(similarities[0])[::-1][:top_k]

        results = []
        for idx in top_indices:
            results.append({
                'image_index': int(idx),
                # Cast numpy scalar to a plain (JSON-friendly) float.
                'similarity_score': float(similarities[0][idx]),
                'image_path': image_database[idx] if isinstance(image_database[idx], str) else None
            })

        return results

基于图结构的检索

import networkx as nx
from sklearn.feature_extraction.text import TfidfVectorizer

class GraphBasedRetrieval:
    """Knowledge-graph assisted retrieval: expansion, path ranking, PageRank."""

    def __init__(self):
        self.knowledge_graph = nx.Graph()
        self.entity_embeddings = {}
        self.relation_types = ['similar_to', 'contains', 'located_in', 'related_to']

    def build_knowledge_graph(self, entities, relations):
        """Populate the graph from entity dicts (id/type/attributes) and
        relation dicts (source/target/type/weight)."""
        for entity in entities:
            self.knowledge_graph.add_node(
                entity['id'],
                type=entity['type'],
                attributes=entity.get('attributes', {})
            )

        for relation in relations:
            self.knowledge_graph.add_edge(
                relation['source'],
                relation['target'],
                relation_type=relation['type'],
                weight=relation.get('weight', 1.0)
            )

    def entity_expansion(self, query_entities, max_hops=2):
        """Return all entities within `max_hops` edges of the query set.

        Uses a proper BFS frontier: the original re-scanned the neighbors of
        every already-expanded node on each hop, doing redundant work that
        grows with the expansion size. The resulting set is identical.
        """
        expanded = set(query_entities)
        frontier = set(query_entities)

        for _ in range(max_hops):
            next_frontier = set()
            for entity in frontier:
                if entity in self.knowledge_graph:
                    next_frontier.update(self.knowledge_graph.neighbors(entity))
            next_frontier -= expanded  # only unseen nodes need another hop
            if not next_frontier:
                break
            expanded |= next_frontier
            frontier = next_frontier

        return list(expanded)

    def graph_based_ranking(self, query_entities, candidate_entities):
        """Rank candidates by summed inverse shortest-path distance to queries."""
        scores = {}

        for candidate in candidate_entities:
            score = 0
            for query_entity in query_entities:
                if query_entity in self.knowledge_graph and candidate in self.knowledge_graph:
                    try:
                        path_length = nx.shortest_path_length(
                            self.knowledge_graph, query_entity, candidate
                        )
                        # Closer candidates score higher (1.0 at distance 0).
                        score += 1.0 / (1 + path_length)
                    except nx.NetworkXNoPath:
                        # Disconnected pair contributes nothing.
                        continue
            scores[candidate] = score

        return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    def personalized_pagerank(self, query_entities, alpha=0.15):
        """Personalized PageRank biased toward the query entities.

        NOTE(review): if none of `query_entities` is present in the graph the
        personalization vector is all zeros and networkx raises — confirm
        callers guarantee at least one known entity.
        """
        personalization = {}
        for node in self.knowledge_graph.nodes():
            if node in query_entities:
                personalization[node] = 1.0 / len(query_entities)
            else:
                personalization[node] = 0.0

        return nx.pagerank(
            self.knowledge_graph,
            personalization=personalization,
            alpha=alpha
        )

深度学习检索

import torch
import torch.nn as nn
import torch.nn.functional as F

class DeepRetrievalModel(nn.Module):
    """Scores (text, image) feature pairs in [0, 1] using bidirectional
    cross-attention over projected embeddings."""

    def __init__(self, text_dim=768, image_dim=2048, hidden_dim=512):
        super(DeepRetrievalModel, self).__init__()

        def projector(input_dim):
            # Two-layer MLP projecting one modality into the shared hidden space.
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, hidden_dim)
            )

        # Module creation order kept stable for reproducible initialization.
        self.text_encoder = projector(text_dim)
        self.image_encoder = projector(image_dim)

        # One attention module shared by both attention directions.
        self.interaction_layer = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            batch_first=True
        )

        # Maps the concatenated attended features to a sigmoid score.
        self.similarity_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def _cross_attend(self, query, context):
        """Attend `query` over `context`; inputs and output are (batch, hidden)."""
        attended, _ = self.interaction_layer(
            query.unsqueeze(1),
            context.unsqueeze(1),
            context.unsqueeze(1)
        )
        return attended.squeeze(1)

    def forward(self, text_features, image_features):
        text_hidden = self.text_encoder(text_features)
        image_hidden = self.image_encoder(image_features)

        # Each modality attends over the other, then both views are fused.
        text_view = self._cross_attend(text_hidden, image_hidden)
        image_view = self._cross_attend(image_hidden, text_hidden)
        fused = torch.cat([text_view, image_view], dim=-1)

        return self.similarity_predictor(fused)

class DeepRetrievalTrainer:
    """Minimal train/eval loop for DeepRetrievalModel-style binary scorers."""

    def __init__(self, model, learning_rate=0.001):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.BCELoss()

    def train_step(self, text_features, image_features, labels):
        """Run one optimizer step; returns the batch loss as a float."""
        self.optimizer.zero_grad()

        predictions = self.model(text_features, image_features)
        # squeeze(-1), not squeeze(): a bare squeeze() collapses a batch of
        # size 1 to a 0-d tensor, and BCELoss then fails on the shape
        # mismatch against the (1,) label tensor.
        loss = self.criterion(predictions.squeeze(-1), labels.float())

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def evaluate(self, test_loader):
        """Average loss and accuracy (%) over (text, image, label) batches.

        NOTE(review): leaves the model in eval mode afterwards — callers that
        keep training must switch back with model.train() themselves.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for text_features, image_features, labels in test_loader:
                predictions = self.model(text_features, image_features)
                loss = self.criterion(predictions.squeeze(-1), labels.float())

                total_loss += loss.item()
                # Threshold the sigmoid output at 0.5 for the predicted class.
                predicted = (predictions.squeeze(-1) > 0.5).long()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        avg_loss = total_loss / len(test_loader)

        return avg_loss, accuracy

🎯 特定场景应用

智能监控检索

class SecurityRetrievalSystem:
    """Security-footage analysis and incident retrieval (demo implementation)."""

    def __init__(self):
        # Keyword patterns that drive the alert scoring below.
        self.alert_patterns = {
            'suspicious_behavior': ['loitering', 'running', 'fighting'],
            'object_detection': ['weapon', 'bag', 'vehicle'],
            'crowd_analysis': ['gathering', 'dispersal', 'density']
        }

    def analyze_security_footage(self, video_path, time_range=None):
        """Analyze a video roughly once per second of footage.

        The original loop called cap.read() twice per iteration (once in the
        while condition, once in the body), silently dropping every other
        frame and corrupting frame counts/timestamps; each frame is now read
        exactly once, and the capture is released even if analysis raises.

        NOTE(review): `time_range` is accepted but never applied — confirm
        whether time-window filtering was intended.
        """
        analysis_results = []
        cap = cv2.VideoCapture(video_path)
        frame_count = 0

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                # Sample every 30th frame (~1 second at 30 fps).
                if frame_count % 30 == 0:
                    timestamp = frame_count / 30

                    behaviors = self.detect_behaviors(frame)
                    objects = self.detect_suspicious_objects(frame)
                    crowd_info = self.analyze_crowd(frame)

                    analysis_results.append({
                        'timestamp': timestamp,
                        'behaviors': behaviors,
                        'objects': objects,
                        'crowd_info': crowd_info,
                        'alert_level': self.calculate_alert_level(behaviors, objects, crowd_info)
                    })
        finally:
            cap.release()

        return analysis_results

    def detect_behaviors(self, frame):
        """Detect behaviors in one frame (stub — returns fixed example data)."""
        behaviors = [
            {'type': 'walking', 'confidence': 0.9, 'person_id': 1},
            {'type': 'loitering', 'confidence': 0.7, 'person_id': 2}
        ]
        return behaviors

    def detect_suspicious_objects(self, frame):
        """Detect objects of interest in one frame (stub — fixed example data)."""
        objects = [
            {'type': 'bag', 'confidence': 0.8, 'bbox': [100, 100, 200, 200]},
            {'type': 'person', 'confidence': 0.95, 'bbox': [50, 50, 150, 300]}
        ]
        return objects

    def analyze_crowd(self, frame):
        """Summarize crowd statistics for one frame (stub — fixed example data)."""
        crowd_info = {
            'person_count': 15,
            'density': 'medium',
            'movement_direction': 'north',
            'activity_level': 'normal'
        }
        return crowd_info

    def calculate_alert_level(self, behaviors, objects, crowd_info):
        """Score detections against alert_patterns and bucket into low/medium/high.

        Weights: suspicious behaviors x2, flagged objects x3, +1 for a
        high-density crowd; thresholds at 1.5 (medium) and 3 (high).
        """
        alert_score = 0

        for behavior in behaviors:
            if behavior['type'] in self.alert_patterns['suspicious_behavior']:
                alert_score += behavior['confidence'] * 2

        for obj in objects:
            if obj['type'] in self.alert_patterns['object_detection']:
                alert_score += obj['confidence'] * 3

        if crowd_info['density'] == 'high':
            alert_score += 1

        if alert_score > 3:
            return 'high'
        elif alert_score > 1.5:
            return 'medium'
        else:
            return 'low'

    def query_incidents(self, query_type, time_range, location=None):
        """Query historical incidents (stub — returns fixed example records)."""
        incidents = [
            {
                'timestamp': '2024-01-15 14:30:00',
                'type': 'suspicious_behavior',
                'location': 'entrance_gate',
                'alert_level': 'medium',
                'description': 'Person loitering near entrance for extended period'
            },
            {
                'timestamp': '2024-01-15 16:45:00',
                'type': 'object_detection',
                'location': 'parking_lot',
                'alert_level': 'high',
                'description': 'Unattended bag detected in parking area'
            }
        ]

        return incidents

电商场景检索

class EcommerceSceneRetrieval:
    """Scene-aware product search and recommendation (demo implementation)."""

    def __init__(self):
        self.product_attributes = ['color', 'style', 'brand', 'category', 'price_range']
        self.scene_contexts = ['indoor', 'outdoor', 'casual', 'formal', 'seasonal']

    def visual_search(self, query_image_path, filters=None):
        """Find products matching a query image, optionally filtered."""
        query_features = self.extract_product_features(query_image_path)
        scene_context = self.understand_scene_context(query_image_path)
        return self.find_similar_products(query_features, scene_context, filters)

    def extract_product_features(self, image_path):
        """Extract product features from an image (stub with example values)."""
        return {
            'color_palette': ['red', 'black', 'white'],
            'style_tags': ['casual', 'modern'],
            'category': 'clothing',
            'visual_features': np.random.rand(512)  # placeholder for deep features
        }

    def understand_scene_context(self, image_path):
        """Infer the context the photo was taken in (stub with example values)."""
        return {
            'setting': 'outdoor',
            'occasion': 'casual',
            'season': 'summer',
            'weather': 'sunny'
        }

    def find_similar_products(self, query_features, scene_context, filters=None):
        """Return matching products, filtered and ordered by similarity."""
        # Stand-in for a product database query.
        products = [
            {
                'product_id': 'P001',
                'name': 'Summer Casual T-Shirt',
                'category': 'clothing',
                'price': 29.99,
                'colors': ['red', 'blue', 'white'],
                'style_tags': ['casual', 'summer'],
                'similarity_score': 0.95
            },
            {
                'product_id': 'P002',
                'name': 'Outdoor Adventure Jacket',
                'category': 'clothing',
                'price': 89.99,
                'colors': ['black', 'green'],
                'style_tags': ['outdoor', 'adventure'],
                'similarity_score': 0.82
            }
        ]

        if filters:
            products = [item for item in products if self.apply_filters(item, filters)]

        return sorted(products, key=lambda item: item['similarity_score'], reverse=True)

    def apply_filters(self, product, filters):
        """Return True iff `product` passes every filter in `filters`."""
        for name, value in filters.items():
            if name == 'price_range':
                low, high = value
                if not low <= product['price'] <= high:
                    return False
            elif name == 'category' and product['category'] != value:
                return False
            elif name == 'color' and value not in product['colors']:
                return False
        return True

    def recommendation_by_scene(self, user_profile, current_scene):
        """Recommend product keywords appropriate to the current scene."""
        seasonal_picks = {
            'summer': ['sunglasses', 'hat', 'sunscreen'],
            'winter': ['jacket', 'gloves', 'boots'],
        }

        recommendations = []
        if current_scene['setting'] == 'outdoor':
            recommendations.extend(seasonal_picks.get(current_scene['season'], []))

        if user_profile.get('preferred_brands'):
            # Brand-preference filtering intentionally left unimplemented.
            pass

        return recommendations

📊 性能优化

索引优化

class IndexOptimizer:
    """Builds and benchmarks alternative FAISS index types."""

    def __init__(self):
        # Supported values for build_index()'s index_type argument.
        self.index_types = ['flat', 'ivf', 'hnsw', 'pq']

    def benchmark_indexes(self, data, queries, index_configs):
        """Measure build time, search time, accuracy and memory per config.

        NOTE(review): `calculate_accuracy` and `get_memory_usage` are called
        but not defined on this class — they must be supplied (subclass or
        monkey-patched) before this method can run. Requires `time` to be
        imported at module level.
        """
        results = {}

        for config in index_configs:
            index_type = config['type']
            params = config.get('params', {})

            start_time = time.time()
            index = self.build_index(data, index_type, params)
            build_time = time.time() - start_time

            start_time = time.time()
            search_results = [index.search(query, k=10) for query in queries]
            search_time = time.time() - start_time

            accuracy = self.calculate_accuracy(search_results, ground_truth=None)

            results[index_type] = {
                'build_time': build_time,
                'search_time': search_time,
                'accuracy': accuracy,
                'memory_usage': self.get_memory_usage(index)
            }

        return results

    def build_index(self, data, index_type, params):
        """Build a FAISS index of the given type over `data` ((n, dim) floats).

        Raises:
            ValueError: on an unsupported `index_type` (the original fell
                through and crashed with UnboundLocalError instead).
        """
        dim = data.shape[1]
        vectors = data.astype('float32')

        if index_type == 'flat':
            index = faiss.IndexFlatIP(dim)
        elif index_type == 'ivf':
            nlist = params.get('nlist', 100)
            quantizer = faiss.IndexFlatIP(dim)
            index = faiss.IndexIVFFlat(quantizer, dim, nlist)
        elif index_type == 'hnsw':
            index = faiss.IndexHNSWFlat(dim, params.get('m', 16))
        elif index_type == 'pq':
            index = faiss.IndexPQ(dim, params.get('m', 8), params.get('nbits', 8))
        else:
            raise ValueError(f"unsupported index type: {index_type!r}")

        # IVF and PQ indexes must be trained before vectors can be added;
        # the original skipped this, and faiss raises on an untrained add().
        if not index.is_trained:
            index.train(vectors)
        index.add(vectors)
        return index

    def optimize_query_processing(self, queries):
        """Expand each query once, memoizing repeated queries.

        (A dead `batched_queries` computation in the original was removed —
        its result was never used.)
        """
        query_cache = {}

        optimized_queries = []
        for query in queries:
            if query in query_cache:
                optimized_queries.append(query_cache[query])
            else:
                expanded_query = self.expand_query(query)
                query_cache[query] = expanded_query
                optimized_queries.append(expanded_query)

        return optimized_queries

    def expand_query(self, query):
        """Bundle a query with synonym and related-term expansions.

        NOTE(review): `get_synonyms` / `get_related_terms` are not defined in
        this class — wire in a thesaurus/embedding backend before use.
        """
        synonyms = self.get_synonyms(query)
        related_terms = self.get_related_terms(query)

        return {
            'original': query,
            'synonyms': synonyms,
            'related_terms': related_terms
        }

缓存策略

import redis
import pickle
from functools import wraps

class RetrievalCache:
    """Redis-backed result cache for retrieval queries."""

    def __init__(self, redis_host='localhost', redis_port=6379):
        self.redis_client = redis.Redis(host=redis_host, port=redis_port, decode_responses=False)
        self.default_ttl = 3600  # one hour

    def cache_key(self, query, params=None):
        """Derive a stable hex key from the query plus optional params."""
        import hashlib
        raw = f"{query}_{params}"
        return hashlib.md5(raw.encode()).hexdigest()

    def get_cached_result(self, query, params=None):
        """Return the cached (unpickled) result, or None on a miss."""
        payload = self.redis_client.get(self.cache_key(query, params))
        if payload:
            return pickle.loads(payload)
        return None

    def cache_result(self, query, result, params=None, ttl=None):
        """Pickle `result` and store it under the query's key with a TTL."""
        self.redis_client.setex(
            self.cache_key(query, params),
            ttl or self.default_ttl,
            pickle.dumps(result),
        )

    def cache_decorator(self, ttl=None):
        """Decorator that memoizes a function's return value in Redis."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Key derived from the function name and its arguments' reprs.
                cache_key = f"{func.__name__}_{str(args)}_{str(kwargs)}"

                hit = self.redis_client.get(cache_key)
                if hit:
                    return pickle.loads(hit)

                result = func(*args, **kwargs)
                self.redis_client.setex(
                    cache_key,
                    ttl or self.default_ttl,
                    pickle.dumps(result)
                )
                return result
            return wrapper
        return decorator

# Usage example
cache = RetrievalCache()

@cache.cache_decorator(ttl=7200)  # cache results for 2 hours
def expensive_search_operation(query, filters):
    # Perform the slow search; only the first call pays the cost.
    time.sleep(2)  # simulate an expensive operation
    return f"Result for {query} with filters {filters}"

📈 评估指标

检索性能评估

class RetrievalEvaluator:
    """Standard ranking metrics (P@K, R@K, AP/MAP, NDCG) for retrieval output."""

    def __init__(self):
        pass

    def precision_at_k(self, retrieved_items, relevant_items, k):
        """Fraction of the top-k retrieved items that are relevant."""
        retrieved_k = retrieved_items[:k]
        hits = len(set(retrieved_k) & set(relevant_items))
        return hits / k if k > 0 else 0

    def recall_at_k(self, retrieved_items, relevant_items, k):
        """Fraction of all relevant items found within the top k."""
        retrieved_k = retrieved_items[:k]
        hits = len(set(retrieved_k) & set(relevant_items))
        return hits / len(relevant_items) if len(relevant_items) > 0 else 0

    def average_precision(self, retrieved_items, relevant_items):
        """AP: mean of precision values taken at each relevant item's rank."""
        if not relevant_items:
            return 0

        precision_sum = 0
        relevant_count = 0
        for i, item in enumerate(retrieved_items):
            if item in relevant_items:
                relevant_count += 1
                precision_sum += relevant_count / (i + 1)

        return precision_sum / len(relevant_items)

    def mean_average_precision(self, all_retrieved, all_relevant):
        """MAP over parallel lists of rankings and relevant-item collections."""
        ap_scores = [
            self.average_precision(retrieved, relevant)
            for retrieved, relevant in zip(all_retrieved, all_relevant)
        ]
        return np.mean(ap_scores)

    def normalized_dcg(self, retrieved_items, relevance_scores, k=None):
        """NDCG (optionally @k); `relevance_scores` maps item -> graded relevance.

        The original attempted `relevance_scores[:k]`, which raises TypeError
        for the dict it is indexed as elsewhere in the method; only the
        ranking list is truncated now, and IDCG uses the top-k ideal grades.
        """
        if k is not None:
            retrieved_items = retrieved_items[:k]

        # DCG over the actual ranking; items without a grade contribute 0.
        dcg = 0.0
        for i, item in enumerate(retrieved_items):
            gain = relevance_scores.get(item, 0)
            dcg += (2 ** gain - 1) / np.log2(i + 2)

        # IDCG: best achievable ordering of the (top-k) graded items.
        ideal = sorted(relevance_scores.values(), reverse=True)
        if k is not None:
            ideal = ideal[:k]
        idcg = sum((2 ** gain - 1) / np.log2(i + 2) for i, gain in enumerate(ideal))

        return dcg / idcg if idcg > 0 else 0

    def evaluate_retrieval_system(self, test_queries, ground_truth, system_results):
        """Aggregate ranking metrics over all queries present in `ground_truth`.

        `ground_truth` maps query_id -> list of relevant items;
        `system_results` maps query_id -> ranked retrieved items.
        The original declared 'ndcg_at_10' but never filled it, so np.mean([])
        produced NaN; it is now computed with binary relevance, and empty
        metric lists average to 0.0.
        """
        metrics = {
            'precision_at_1': [],
            'precision_at_5': [],
            'precision_at_10': [],
            'recall_at_10': [],
            'map': [],
            'ndcg_at_10': []
        }

        for query_id, retrieved in system_results.items():
            if query_id in ground_truth:
                relevant = ground_truth[query_id]

                metrics['precision_at_1'].append(self.precision_at_k(retrieved, relevant, 1))
                metrics['precision_at_5'].append(self.precision_at_k(retrieved, relevant, 5))
                metrics['precision_at_10'].append(self.precision_at_k(retrieved, relevant, 10))
                metrics['recall_at_10'].append(self.recall_at_k(retrieved, relevant, 10))
                metrics['map'].append(self.average_precision(retrieved, relevant))
                # Binary relevance: every ground-truth item counts as grade 1.
                binary_relevance = {item: 1 for item in relevant}
                metrics['ndcg_at_10'].append(
                    self.normalized_dcg(retrieved, binary_relevance, k=10)
                )

        final_metrics = {}
        for metric_name, values in metrics.items():
            final_metrics[metric_name] = float(np.mean(values)) if values else 0.0

        return final_metrics

🔗 导航链接