You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

175 lines
6.7 KiB

# coding=utf-8
import asyncio
import re
from datetime import datetime
from functools import partial, wraps
from playwright.async_api import Playwright, Browser, async_playwright
from abs_spider import AbstractAiSeoSpider
from domain.ai_seo import AiAnswer, AiSearchResult
from utils import create_logger
from glom import glom, Coalesce
from utils.image_utils import crop_image_left
# Module-level logger, created from this module's __name__.
logger = create_logger(__name__)
class YuanBaoSpider(AbstractAiSeoSpider):
    """AI-SEO spider for Tencent YuanBao (https://yuanbao.tencent.com/).

    Drives ``self.browser_page`` to:
    1. optionally enable deep thinking,
    2. make sure online search is switched on,
    3. submit ``self.prompt``,
    4. wait for the conversation-detail response, which the response
       listener parses into ``self.ai_answer``,
    5. take a screenshot of the final answer.
    """

    # Matches citation markers such as "citation:3" in the answer text; the
    # captured number is the 1-based position of the referenced search doc.
    _CITATION_PATTERN = re.compile(r'citation:(\d+)')
    # Pixels cropped from the left edge of the screenshot via crop_image_left.
    # NOTE(review): presumably the sidebar width - confirm.
    _SCREENSHOT_CROP_LEFT = 260

    def __init__(self, browser: Browser, prompt: str, keyword: str, think: bool = False):
        super().__init__(browser, prompt, keyword, think)
        # Wrap the response listener so an exception raised inside the
        # Playwright callback is recorded (fail_status / fail_exception) and
        # unblocks completed_event instead of being lost.
        self.__listen_response = self.handle_listen_response_error(self.__listen_response)

    def get_home_url(self) -> str:
        return 'https://yuanbao.tencent.com/'

    async def _do_spider(self) -> AiAnswer:
        """Run ``self.prompt`` against YuanBao and return the captured answer.

        :return: ``self.ai_answer`` filled with answer text, search results
            and the screenshot path.
        :raises Exception: the exception recorded by the response listener
            (``self.fail_exception``) when handling the detail response failed.
        """
        # Reset per-run state
        self._init_data()
        self.is_get_detail = False
        await self.browser_page.goto(self.get_home_url(), timeout=600000)
        await asyncio.sleep(2)

        # Enable deep thinking
        if self.think:
            think_button = self.browser_page.locator("//button[@dt-button-id='deep_think']")
            if await think_button.is_visible():
                # Only toggle when the deep_seek model is not already selected.
                model_id = await think_button.get_attribute('dt-model-id')
                if model_id != 'deep_seek':
                    await think_button.click()
                    await asyncio.sleep(2)

        # Enable online search
        search_button = self.browser_page.locator("//div[@dt-button-id='online_search']")
        if await search_button.is_visible():
            # get_attribute returns None when the attribute is missing.
            class_names = (await search_button.get_attribute('class') or '').split()
            if 'checked' not in class_names:
                logger.debug('未开启联网搜索')
                await search_button.click()
                await asyncio.sleep(1)

        # Focus the chat input and type the prompt
        chat_input_element = self.browser_page.locator("//div[contains(@class, 'chat-input-editor')]")
        await chat_input_element.click()
        await self.browser_page.keyboard.type(self.prompt)
        await asyncio.sleep(2)

        # Register the listener before submitting so the detail response cannot
        # arrive before we listen for it; always detach it afterwards.
        self.browser_page.on('response', self.__listen_response)
        try:
            await self.browser_page.keyboard.press('Enter')
            await self.completed_event.wait()
        finally:
            self.browser_page.remove_listener('response', self.__listen_response)

        # Surface any error recorded by the listener
        if self.fail_status:
            raise self.fail_exception

        # Grow the viewport to the last answer's height so the screenshot covers it.
        answer_element = self.browser_page.locator("//div[@class='agent-chat__list__item__content']").nth(-1)
        box = await answer_element.bounding_box()
        logger.debug(f'answer_element: {box}')
        if box is not None:  # None when the element is not visible
            await self.browser_page.set_viewport_size({
                'width': 1920,
                'height': int(box['height'] + 500),
            })

        # Expand the online-search references panel, if shown
        search_list_element = self.browser_page.locator("(//div[contains(@data-title, '资料作为参考')])[1]/span")
        if await search_list_element.is_visible():
            await search_list_element.click()

        # Screenshot
        screenshot_path = self._get_screenshot_path()
        await self.browser_page.screenshot(path=screenshot_path)
        crop_image_left(screenshot_path, self._SCREENSHOT_CROP_LEFT)
        self.ai_answer.screenshot_file = screenshot_path
        return self.ai_answer

    async def __listen_response(self, response):
        """Parse YuanBao's conversation-detail response into ``self.ai_answer``.

        Ignores every response except ``/agent/conversation/v1/detail`` ones,
        and all of them once a detail has been captured. A detail response
        with an empty ``convs`` list is skipped; a missing ``convs`` key raises
        and is recorded as a failure by the wrapper installed in ``__init__``.
        On success, sets ``answer`` and ``search_result`` and signals
        ``completed_event``.
        """
        if '/agent/conversation/v1/detail' not in response.url or self.is_get_detail:
            return
        json_data = await response.json()
        convs = json_data['convs']
        if not convs:
            return

        # Use the first conversation turn with more than one speech item;
        # otherwise fall back to the last turn's content.
        content = []
        for conv in convs:
            content = glom(conv, 'speechesV2.0.content', default=[])
            if len(content) > 1:
                break

        # Collect the answer text and the online-search docs from the speech items.
        text = None
        search_list = None
        for item in content:
            item_type = item.get('type')
            if item_type == 'text':
                text = item.get('msg', '')
            elif item_type == 'searchGuid':
                search_list = item.get('docs', [])

        logger.debug(f'ai回复内容: {text}')
        self.ai_answer.answer = text
        self.ai_answer.search_result = self._build_search_results(text, search_list)
        if search_list:
            logger.debug(f'ai参考资料: {search_list}')
        self.is_get_detail = True
        self.completed_event.set()

    def _build_search_results(self, text, search_list) -> list:
        """Convert YuanBao ``searchGuid`` docs into ``AiSearchResult`` objects.

        The doc at 1-based position ``i`` gets ``is_referenced="1"`` when
        ``text`` contains ``citation:i``, otherwise ``"0"``.

        :param text: the answer text; may be None if no text item was found.
        :param search_list: list of doc dicts, or None/empty.
        :return: list of AiSearchResult, empty when there are no docs.
        """
        if not search_list:
            return []
        # Build the set once; text may be None when the response had no text item.
        cited = set(self._CITATION_PATTERN.findall(text or ''))
        return [
            AiSearchResult(
                title=doc.get('title', ''),
                url=doc.get('url', ''),
                host_name=doc.get('web_site_name', ''),
                body=doc.get('quote', ''),
                publish_time=doc.get('publish_time', 0),
                is_referenced='1' if str(position) in cited else '0',
            )
            for position, doc in enumerate(search_list, start=1)
        ]

    def handle_listen_response_error(self, func):
        """Decorator that handles exceptions raised in the response callback.

        Logs the exception, sets ``fail_status`` / ``fail_exception`` and sets
        ``completed_event`` so the waiting ``_do_spider`` can re-raise it.

        :param func: the async callback to wrap.
        :return: the wrapped async callback.
        """
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"{self.get_platform_name()}响应异常: {e}", exc_info=True)
                # Record the failure and unblock the waiter
                self.fail_status = True
                self.fail_exception = e
                self.completed_event.set()
        return wrapper

    def get_platform_id(self) -> int:
        return 3

    def get_platform_name(self) -> str:
        return 'YuanBao'