# coding=utf-8
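"""Playwright-based spider for DeepSeek chat (https://chat.deepseek.com/).

Opens the chat page, enables web search (and deep thinking when self.think
is set), submits the prompt, parses the streamed completion for the answer,
thinking text and cited search results, and saves a cropped screenshot.
"""
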
import asyncio
import json
import re
from functools import partial, wraps
from json import JSONDecodeError

from glom import glom
from playwright.async_api import Browser

from abs_spider import AbstractAiSeoSpider
from domain.ai_seo import AiAnswer, AiSearchResult
from utils import create_logger, css_to_dict
from utils.image_utils import crop_image_left

logger = create_logger(__name__)


class DeepseekSpider(AbstractAiSeoSpider):

    def __init__(self, browser: Browser, prompt: str, keyword: str):
        super().__init__(browser, prompt, keyword)
        self.__listen_response = self.handle_listen_response_error(self.__listen_response)

    def get_home_url(self) -> str:
        return 'https://chat.deepseek.com/'

    def get_platform_id(self) -> int:
        return 1

    def get_platform_name(self) -> str:
        return 'DeepSeek'

    async def _do_spider(self) -> AiAnswer:
        self._init_data()
        self.search_result_count = 0
        await self.browser_page.goto(self.get_home_url(), timeout=600000)
        await asyncio.sleep(3)
        # Enable web search
        search_btn = self.browser_page.locator("span:text('联网搜索')").locator('..')
        if await search_btn.is_visible():
            await search_btn.click()
        self.think = True
        if self.think:
            # Enable deep thinking
            think_btn = self.browser_page.locator("span:text('深度思考')").locator('..')
            if await think_btn.is_visible():
                # styles = css_to_dict(await think_btn.get_attribute('style'))
                # if styles.get('--ds-button-color') == '#fff':
                await think_btn.click()
        await asyncio.sleep(1)
        chat_input_element = self.browser_page.locator("//textarea[@placeholder='给 DeepSeek 发送消息 ']")
        await chat_input_element.click()
        # Type the prompt
        await self.browser_page.keyboard.type(self.prompt)
        await asyncio.sleep(1)
        await self.browser_page.keyboard.press('Enter')
        # Listen for network responses
        self.browser_page.on('response', partial(self.__listen_response))
        await self.completed_event.wait()
        # Check for a recorded failure
        if self.fail_status:
            raise self.fail_exception
        # Expand the search results panel
        search_btn_text = f'已搜索到 {self.search_result_count} 个网页'
        search_btn = self.browser_page.locator(f"div:text('{search_btn_text}')")
        # search_btn = self.browser_page.locator('div:has-text("搜索到")')
        if await search_btn.count() > 0:
            await search_btn.click()
            await asyncio.sleep(2)
        if self.think:
            # Expand the deep-thinking element
            think_element = self.browser_page.locator("text=已深度思考(")
            think_element_count = await think_element.count()
            if think_element_count > 0:
                await think_element.nth(-1).click()
                await asyncio.sleep(2)
        # Locate the answer element
        # answer = self.browser_page.locator("//div[@class='ds-markdown ds-markdown--block']").nth(-1)
        answer = self.browser_page.locator("//div[contains(@class, 'ds-message')]").nth(-1)
        box = await answer.bounding_box()
        # Resize the viewport so the whole answer fits in the screenshot
        await self.browser_page.set_viewport_size({
            'width': 1920,
            'height': int(box['height']) + 500
        })
        # Screenshot
        screenshot_path = self._get_screenshot_path()
        await self.browser_page.screenshot(path=screenshot_path)
        # Crop the image
        crop_image_left(screenshot_path, 250)
        self.ai_answer.screenshot_file = screenshot_path
        return self.ai_answer

    async def do_check_session(self) -> bool:
        try:
            await self.browser_page.goto(self.get_home_url(), timeout=600000)
            await asyncio.sleep(3)
            chat_input_element = self.browser_page.locator("//textarea[@id='chat-input']")
            await chat_input_element.click()
            # Type the prompt to confirm the session is usable
            await self.browser_page.keyboard.type(self.prompt)
            return True
        except Exception:
            return False

    def handle_listen_response_error(self, func):
        """
        Decorator that handles exceptions raised inside the response callback.
        :param func:
        :return:
        """
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"DeepSeek响应异常: {e}", exc_info=True)
                # Mark the failure and record the exception
                self.fail_status = True
                self.fail_exception = e
                self.completed_event.set()
        return wrapper
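
    # The completion endpoint streams "data: {...}" chunks separated by blank
    # lines; each JSON chunk carries a patch path 'p' and value 'v' (web search
    # results, deep-thinking fragments under response/fragments/1, answer
    # fragments under response/fragments/2), which is what the parser below
    # relies on.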
    async def __listen_response(self, response):
        if '/api/v0/chat/completion' not in response.url:
            return
        # Read the streamed data
        response_text = ''
        thinking_text = ''
        search_result_lists = list()
        start_content = False
        start_thinking = False
        stream = await response.body()
        body = stream.decode('utf-8')
        datas = body.split("\n\n")
        for data_str in datas:
            # Skip empty chunks
            if not data_str:
                continue
            data_str = data_str.replace('data: ', '')
            try:
                data = json.loads(data_str)
                # Server busy: mark the failure so _do_spider can re-raise it
                if glom(data, 'v.0.v', default='') == 'TIMEOUT':
                    self.fail_status = True
                    self.fail_exception = RuntimeError('DeepSeek服务器繁忙')
                    logger.error("DeepSeek服务器繁忙")
            except JSONDecodeError:
                continue
            # Collect the web search results
            if data.get('p', '') == 'response/search_results' or isinstance(data.get('v', ''), list):
                logger.debug("获取到联网搜索结果")
                search_result_list = data.get('v', [])
                search_result_lists.extend(search_result_list)
                # Save the search results
                ai_search_result_list = []
                for search_result in search_result_list:
                    url = search_result.get('url', '')
                    title = search_result.get('title', '')
                    body = search_result.get('snippet', '')
                    publish_time = search_result.get('published_at', '')
                    host_name = search_result.get('site_name', '未知')
                    ai_result = AiSearchResult(url=url, title=title, body=body, publish_time=publish_time, host_name=host_name)
                    if ai_result.title and ai_result.url:
                        ai_search_result_list.append(ai_result)
                        logger.debug(f"ai参考资料: [{host_name}]{title}({url})")
                if ai_search_result_list:
                    self.ai_answer.search_result = ai_search_result_list
                    self.search_result_count = len(self.ai_answer.search_result)
                continue

            # Deep-thinking content starts here
            if data.get('p', '') == 'response/fragments/1/content':
                start_thinking = True
            if data.get('p', '') == 'response/fragments/1/elapsed_secs':
                start_thinking = False
            if start_thinking:
                # Accumulate the deep-thinking reply
                value = data.get('v', None)
                if isinstance(value, dict):
                    continue
                if value is None:
                    target = 'choices.0.delta.content'
                    value = glom(data, target, default="")
                thinking_text = thinking_text + str(value)
            # Answer content starts here
            if data.get('p', '') == 'response/fragments/2/content':
                start_content = True
            if start_content:
                # Accumulate the AI reply
                value = data.get('v', None)
                if isinstance(value, dict):
                    continue
                if value is None:
                    target = 'choices.0.delta.content'
                    value = glom(data, target, default="")
                response_text = response_text + str(value)
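
        # The answer marks its sources with tokens like "[citation:3]"; the
        # numbers appear to be 1-based indexes into the search results
        # collected above, which is what the flagging loop below assumes.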
        # Match the numbers in the "citation:" markers
        citation = list()
        citations = re.findall(r'citation:(\d+)', response_text)
        if citations:
            citation = list(set(citations))
        # Save the search results, flagging the ones the answer cites
        ai_search_result_list = []
        for index, search_result in enumerate(search_result_lists):
            url = search_result.get('url', '')
            title = search_result.get('title', '')
            body = search_result.get('snippet', '')
            publish_time = search_result.get('published_at', '')
            host_name = search_result.get('site_name', '未知')
            if str(index + 1) in citation:
                is_referenced = "1"
            else:
                is_referenced = "0"
            ai_result = AiSearchResult(url=url, title=title, body=body, publish_time=publish_time, host_name=host_name, is_referenced=is_referenced)
            if ai_result.title and ai_result.url:
                ai_search_result_list.append(ai_result)
                logger.debug(f"ai参考资料: [{host_name}]{title}({url})")
        if ai_search_result_list:
            self.ai_answer.search_result = ai_search_result_list
            self.search_result_count = len(self.ai_answer.search_result)
        logger.debug(response_text)
        self.ai_answer.answer = response_text
        self.ai_answer.thinking = thinking_text
        self.completed_event.set()