You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

176 lines
6.7 KiB

# coding=utf-8
import asyncio
import json
import re
from datetime import datetime
from functools import partial, wraps
from json import JSONDecodeError
from glom import glom
from playwright.async_api import Playwright, Browser, async_playwright
from abs_spider import AbstractAiSeoSpider
from domain.ai_seo import AiAnswer, AiSearchResult
from utils import create_logger, parse_nested_json
# Module-level logger, obtained from the project's create_logger helper.
logger = create_logger(__name__)
class NanometerSpider(AbstractAiSeoSpider):
    """AI-SEO spider for the n.cn ("Nano") AI search site.

    Opens the home page, enables "deep thinking", submits ``self.prompt``,
    captures the ``/api/common/chat/v2`` response to fill ``self.ai_answer``
    (answer text and search results) and screenshots the rendered answer.
    """

    # Citation markers such as "[3]" in the answer text; the captured number
    # is matched against the 1-based position of each search result.
    _CITATION_PATTERN = re.compile(r'\[(\d+)\]')

    def __init__(self, browser: Browser, prompt: str, keyword: str):
        super().__init__(browser, prompt, keyword)
        # NOTE(review): presumably tells the base class not to restore a saved
        # session -- confirm against AbstractAiSeoSpider.
        self.load_session = False
        # Wrap the response callback so that any exception raised inside it
        # is recorded in fail_status/fail_exception and completed_event is
        # set (see handle_listen_response_error).
        self.__listen_response = self.handle_listen_response_error(self.__listen_response)

    def get_home_url(self) -> str:
        return 'https://www.n.cn/'

    async def _do_spider(self) -> AiAnswer:
        """Run one query against n.cn and return the populated answer.

        :return: ``self.ai_answer`` with answer text, search results and the
            screenshot path filled in.
        :raises Exception: the exception recorded by the response listener
            (``self.fail_exception``) when processing the chat stream failed.
        """
        # Initialise per-run data.
        self._init_data()
        await self.browser_page.goto(self.get_home_url(), timeout=600000)
        # Enable "deep thinking" mode.
        # NOTE(review): this absolute XPath breaks on any layout change.
        await self.browser_page.locator(
            '//*[@id="nworld-app-container"]/div/div[1]/div[1]/div/div/div/div/div[2]/div[1]/div[1]/div[2]/div[1]/section/div'
        ).click()
        chat_input_element = self.browser_page.locator("//textarea[@id='composition-input']")
        # Type the prompt.
        await chat_input_element.fill(self.prompt)
        # Register the listener BEFORE submitting, so the chat response cannot
        # arrive before anyone is listening for it. The same object is kept so
        # it can be unregistered afterwards.
        listener = self.__listen_response
        self.browser_page.on('response', listener)
        try:
            await self.browser_page.keyboard.press('Enter')
            await asyncio.sleep(2)
            await self.completed_event.wait()
        finally:
            # The stream has been handled (or failed); stop listening.
            self.browser_page.remove_listener('response', listener)
        # Surface any error recorded by the listener.
        if self.fail_status:
            raise self.fail_exception
        # The last rendered article holds the answer to this prompt.
        answer_element = self.browser_page.locator("//div[@class='js-article-content']").nth(-1)
        box = await answer_element.bounding_box()
        logger.debug(f'answer_element: {box}')
        if box is None:
            # bounding_box() returns None for a non-visible element; keep the
            # current viewport instead of failing on box['height'].
            logger.warning(f'{self.get_platform_name()} answer element has no bounding box, keeping viewport size')
        else:
            # Make the viewport tall enough for the whole answer plus 500px.
            await self.browser_page.set_viewport_size({
                'width': 1920,
                'height': int(box['height'] + 500),
            })
        # Screenshot.
        screenshot_path = self._get_screenshot_path()
        await self.browser_page.screenshot(path=screenshot_path, full_page=True)
        self.ai_answer.screenshot_file = screenshot_path
        return self.ai_answer

    @staticmethod
    def _decode_event_data(value: str):
        """Return ``value`` JSON-decoded when it is a JSON object or string.

        Other JSON values (``true``, ``null``, numbers, arrays) are returned as
        the raw text, so plain answer tokens are not re-rendered by ``str()``
        as ``True``/``None``/``1000.0``. Non-JSON text is returned unchanged.
        """
        try:
            decoded = json.loads(value)
        except JSONDecodeError:
            return value
        return decoded if isinstance(decoded, (dict, str)) else value

    def __parse_event_data(self, data_str: str) -> list:
        """Parse the event-stream text of a chat response into records.

        A record starts at each line beginning with ``id:``. The following
        ``key:value`` lines (e.g. ``event``, ``data``) are stored in that
        record's dict, with keys and values whitespace-stripped. Only ``id:``
        at the start of a line opens a record, so payloads that merely contain
        the substring (``Android:``, ``valid:``) are not cut apart. Repeated
        ``data`` lines are joined with ``\\n``, as in server-sent events.
        ``data`` is then decoded via ``_decode_event_data``.

        Text before the first ``id:`` line, blank lines and lines without a
        colon are ignored.

        :param data_str: the whole decoded response body.
        :return: list of dicts, one per record, in stream order.
        """
        records = []
        for raw_line in data_str.strip().split('\n'):
            line = raw_line.strip()
            if line.startswith('id:'):
                records.append({'id': line[len('id:'):].strip()})
                continue
            if not records or ':' not in line:
                continue
            key, value = line.split(':', 1)
            key = key.strip()
            value = value.strip()
            record = records[-1]
            if key == 'data' and 'data' in record:
                record['data'] = record['data'] + '\n' + value
            else:
                record[key] = value
        for record in records:
            if 'data' in record:
                record['data'] = self._decode_event_data(record['data'])
        return records

    async def __listen_response(self, response):
        """Collect the answer and search results from the chat response.

        Every response except ``/api/common/chat/v2`` is ignored. For the chat
        response, this method:

        - concatenates the ``event == '200'`` payloads into the answer text;
        - takes the search results from the ``event == '102'`` record whose
          ``type`` is ``search_result``;
        - marks a result referenced ("1") when its 1-based position appears
          as ``[n]`` in the answer, and "0" otherwise;
        - stores answer and results on ``self.ai_answer`` and sets
          ``completed_event``.
        """
        if '/api/common/chat/v2' not in response.url:
            return
        # Read the whole (streamed) body.
        raw = await response.body()
        records = self.__parse_event_data(raw.decode('utf-8'))
        answer_parts = []
        search_result_list = []
        for record in records:
            event = record.get('event', '')
            if event == '200':
                answer_parts.append(str(record.get('data', '')))
            elif event == '102':
                # JSON payload; may still be (nested) JSON text.
                payload = record.get('data', {})
                if isinstance(payload, str):
                    payload = parse_nested_json(payload)
                if payload.get('type', '') == 'search_result':
                    # `or []` also covers an explicit null list.
                    search_result_list = glom(payload, 'message.list', default=[]) or []
        answer = ''.join(answer_parts)
        cited = set(self._CITATION_PATTERN.findall(answer))
        ai_search_result_list = []
        for position, search_result in enumerate(search_result_list, start=1):
            title = search_result.get('title', '')
            url = search_result.get('url', '')
            summary = search_result.get('summary', '')
            host_name = search_result.get('site', '未知')
            publish_time = search_result.get('date', 0)
            is_referenced = "1" if str(position) in cited else "0"
            ai_search_result_list.append(
                AiSearchResult(title, url, host_name, summary, publish_time, is_referenced)
            )
            logger.debug(f"ai参考资料: [{host_name}]{title}({url})")
        self.ai_answer.search_result = ai_search_result_list
        self.ai_answer.answer = answer
        logger.debug(f'ai回复: {answer}')
        self.completed_event.set()

    def get_platform_id(self) -> int:
        return 7

    def get_platform_name(self) -> str:
        return 'Nano'

    def handle_listen_response_error(self, func):
        """Wrap an async response callback so its exceptions are recorded.

        On exception the wrapper logs it, sets ``self.fail_status = True``,
        stores it in ``self.fail_exception`` and sets ``self.completed_event``
        so that a waiter is released; the exception is not re-raised.

        :param func: async callable to wrap.
        :return: the wrapping coroutine function.
        """
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"{self.get_platform_name()}响应异常: {e}", exc_info=True)
                # Record the failure and unblock whoever waits on the event.
                self.fail_status = True
                self.fail_exception = e
                self.completed_event.set()
        return wrapper