You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							174 lines
						
					
					
						
							6.5 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							174 lines
						
					
					
						
							6.5 KiB
						
					
					
				| # coding=utf-8 | |
| import asyncio | |
| import json | |
| import re | |
| from datetime import datetime | |
| from functools import partial, wraps | |
| from json import JSONDecodeError | |
| 
 | |
| from glom import glom | |
| from playwright.async_api import Playwright, Browser, async_playwright | |
| 
 | |
| from abs_spider import AbstractAiSeoSpider | |
| from domain.ai_seo import AiAnswer, AiSearchResult | |
| from utils import create_logger, parse_nested_json | |
| 
 | |
| 
 | |
# Module-level logger, created by the project's create_logger helper with this module's name.
logger = create_logger(__name__)
| 
 | |
class NanometerSpider(AbstractAiSeoSpider):
    """Spider that submits a prompt on https://www.n.cn/ and captures the answer.

    The prompt is typed into the chat textarea and submitted with Enter.  The
    answer is not scraped from the DOM; it is read from the network response
    whose URL contains ``/api/common/chat/v2``, which is parsed as an
    ``id:``-delimited event stream.  Once the answer is captured, a full-page
    screenshot is taken.

    NOTE(review): ``browser_page``, ``ai_answer``, ``completed_event``,
    ``fail_status``, ``fail_exception``, ``_init_data`` and
    ``_get_screenshot_path`` are not defined here; they are presumably
    provided by ``AbstractAiSeoSpider`` -- confirm against the base class.
    """

    def __init__(self, browser: Browser, prompt: str, keyword: str):
        """Initialise the spider and wrap the response callback with error capture."""
        super().__init__(browser, prompt, keyword)
        self.load_session = False
        # Shadow the method with a wrapped version on the instance so that any
        # exception raised inside the callback is recorded and unblocks
        # _do_spider instead of being lost inside the event emitter.
        self.__listen_response = self.handle_listen_response_error(self.__listen_response)

    def get_home_url(self) -> str:
        """Return the URL of the chat home page."""
        return 'https://www.n.cn/'

    async def _do_spider(self) -> AiAnswer:
        """Ask the prompt, wait for the captured answer, then take a screenshot.

        :return: ``self.ai_answer`` with ``screenshot_file`` set; the answer
            text and search results are filled in by the response listener.
        :raises Exception: re-raises whatever exception the response listener
            recorded in ``self.fail_exception``.
        """
        # Initialise per-run data.
        self._init_data()
        await self.browser_page.goto(self.get_home_url(), timeout=600000)
        chat_input_element = self.browser_page.locator("//textarea[@id='composition-input']")
        # Type the prompt.
        await chat_input_element.fill(self.prompt)
        # Attach the listener BEFORE submitting: if it were attached after the
        # Enter key press, a fast response could be missed and the wait below
        # would never finish.
        self.browser_page.on('response', self.__listen_response)
        await self.browser_page.keyboard.press('Enter')
        await asyncio.sleep(2)
        await self.completed_event.wait()

        # Surface any failure recorded by the response listener.
        if self.fail_status:
            raise self.fail_exception

        # Locate the last rendered answer element.
        answer_element = self.browser_page.locator("//div[@class='js-article-content']").nth(-1)
        box = await answer_element.bounding_box()
        logger.debug(f'answer_element: {box}')
        # bounding_box() returns None when the element is not visible; in that
        # case keep the current viewport rather than failing on box['height'].
        if box is not None:
            await self.browser_page.set_viewport_size({
                'width': 1920,
                'height': int(box['height'] + 500),
            })
        # Take the screenshot.
        screenshot_path = self._get_screenshot_path()
        await self.browser_page.screenshot(path=screenshot_path, full_page=True)
        self.ai_answer.screenshot_file = screenshot_path
        return self.ai_answer

    def __parse_event_data(self, data_str):
        """Parse an ``id:``-delimited event stream into a list of dicts.

        Each event starts with a line beginning with ``id:``; the following
        ``key: value`` lines become entries of that event's dict.  The value
        of a ``data`` key is JSON-decoded when possible and otherwise kept as
        the raw string.

        :param data_str: full text of the response body.
        :return: list of dicts, one per event, in stream order.
        """
        # Split only where "id:" opens a line, so an "id:" substring inside a
        # payload (e.g. "uuid:" or "valid:") does not cut an event in half.
        # The text before the first "id:" line is discarded.
        parts = re.split(r'^[ \t]*id:', data_str.strip(), flags=re.MULTILINE)[1:]

        result = []
        for part in parts:
            lines = part.strip().splitlines()
            # The remainder of the "id:" line is always the id, even when the
            # id itself contains a colon.
            item = {'id': lines[0].strip() if lines else ''}
            for line in lines[1:]:
                if ':' not in line:
                    # Blank or malformed line: it carries no key/value pair
                    # and must not overwrite the id.
                    continue
                key, value = line.split(':', 1)
                key = key.strip()
                value = value.strip()
                if key == 'data':
                    try:
                        # Decode JSON payloads; plain-text chunks stay as str.
                        value = json.loads(value)
                    except JSONDecodeError:
                        pass
                item[key] = value
            result.append(item)
        return result

    async def __listen_response(self, response):
        """Capture the answer and search results from the chat API response.

        Responses whose URL does not contain ``/api/common/chat/v2`` are
        ignored.  For a matching response the whole body is read, event
        ``200`` payloads are concatenated into the answer text, and the
        ``search_result`` payload of event ``102`` supplies the reference
        list.  A reference is marked as referenced ("1") when its 1-based
        position appears in the answer as ``[N]``, otherwise "0".

        :param response: the response object passed by the page's
            ``response`` event.
        """
        if '/api/common/chat/v2' not in response.url:
            return
        # Read the complete streamed body.
        stream = await response.body()
        response_text = stream.decode('utf-8')
        datas = self.__parse_event_data(response_text)

        answer_chunks = []
        search_result_list = []
        for data in datas:
            event = data.get('event', '')
            if event == '200':
                answer_chunks.append(str(data.get('data', '')))
            elif event == '102':
                # Structured payload; it may still be a string if the first
                # JSON decode failed.
                payload = data.get('data', {})
                if isinstance(payload, str):
                    payload = parse_nested_json(payload)
                if not isinstance(payload, dict):
                    # Not a JSON object: nothing to extract from this event.
                    continue
                if payload.get('type', '') == 'search_result':
                    # "or []" covers a present-but-null list.
                    search_result_list = glom(payload, 'message.list', default=[]) or []
        answer = ''.join(answer_chunks)

        # Citation numbers that appear in the answer as "[N]".
        referenced_indexes = set(re.findall(r'\[(\d+)\]', answer))
        ai_search_result_list = []
        for position, search_result in enumerate(search_result_list, start=1):
            title = search_result.get('title', '')
            url = search_result.get('url', '')
            body = search_result.get('summary', '')
            host_name = search_result.get('site', '未知')
            publish_time = search_result.get('date', 0)
            is_referenced = "1" if str(position) in referenced_indexes else "0"
            ai_search_result_list.append(
                AiSearchResult(title, url, host_name, body, publish_time, is_referenced)
            )
            logger.debug(f"ai参考资料: [{host_name}]{title}({url})")
        self.ai_answer.search_result = ai_search_result_list
        self.ai_answer.answer = answer
        logger.debug(f'ai回复: {answer}')
        self.completed_event.set()

    def get_platform_id(self) -> int:
        """Return the numeric id of this platform."""
        return 7

    def get_platform_name(self) -> str:
        """Return the display name of this platform."""
        return 'Nano'

    def handle_listen_response_error(self, func):
        """Decorate a response callback so that its exceptions are recorded.

        If the wrapped coroutine raises, the error is logged, stored in
        ``self.fail_exception`` with ``self.fail_status`` set to True, and
        ``self.completed_event`` is set so that the waiting ``_do_spider``
        wakes up and re-raises it.

        :param func: coroutine function to wrap.
        :return: the wrapping coroutine function.
        """

        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"{self.get_platform_name()}响应异常: {e}", exc_info=True)
                # Mark the run as failed and keep the exception for the caller.
                self.fail_status = True
                self.fail_exception = e
                self.completed_event.set()

        return wrapper