You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
182 lines
7.7 KiB
182 lines
7.7 KiB
# -*- coding: utf-8 -*-
|
|
import asyncio
|
|
from functools import partial, wraps
|
|
from json import JSONDecodeError
|
|
|
|
import ftfy
|
|
from glom import glom
|
|
from playwright.async_api import Browser
|
|
|
|
from abs_spider import AbstractAiSeoSpider
|
|
from domain.ai_seo import AiAnswer, AiSearchResult
|
|
from utils import create_logger, parse_nested_json
|
|
|
|
# Module-level logger for this spider, created via the project's ``create_logger`` helper.
logger = create_logger(__name__)
|
|
|
|
|
|
class DouBaoSpider(AbstractAiSeoSpider):
    """Spider that submits a prompt to DouBao (www.doubao.com) and scrapes the reply.

    The answer text and the cited search-result positions are parsed from the
    ``/samantha/chat/completion`` response captured by a Playwright
    ``response`` listener; the search-result cards and a screenshot of the
    answer are scraped from the rendered page.
    """

    def __init__(self, browser: Browser, prompt: str, keyword: str, think: bool = False):
        super().__init__(browser, prompt, keyword, think)
        # Wrap the response listener so an exception raised inside the
        # Playwright callback is recorded (fail_status / fail_exception) and
        # wakes up _do_spider instead of being lost.
        self.__listen_response = self.handle_listen_response_error(self.__listen_response)

    def get_home_url(self) -> str:
        return 'https://www.doubao.com/chat'

    async def _do_spider(self) -> AiAnswer:
        """Run one DouBao query and return the populated answer.

        Opens the chat page, optionally enables deep thinking, submits
        ``self.prompt``, waits until the response listener has parsed the
        completion stream, then scrapes the search results and screenshots
        the answer.

        :return: ``self.ai_answer`` with answer text, search results and
            screenshot path filled in
        :raises Exception: the exception recorded by the response listener
            in ``self.fail_exception`` when parsing failed
        """
        # Reset per-run state
        self._init_data()
        await self.browser_page.goto(self.get_home_url(), timeout=600000)
        await asyncio.sleep(3)
        if self.think:
            await self._enable_deep_think()

        # Type the prompt into the chat box
        chat_input_element = self.browser_page.locator("//textarea[@data-testid='chat_input_input']")
        await chat_input_element.fill(self.prompt)

        # Register the listener *before* submitting so the completion response
        # cannot arrive between the key press and the registration.
        self.browser_page.on('response', self.__listen_response)
        try:
            await self.browser_page.keyboard.press('Enter')
            await asyncio.sleep(2)
            # NOTE(review): no timeout here (completed_event is presumably an
            # asyncio.Event) — if neither the listener nor its error handler
            # sets it, this waits forever.
            await self.completed_event.wait()
        finally:
            # Stop listening so later completion responses on this page cannot
            # overwrite the answer that was already parsed.
            self.browser_page.remove_listener('response', self.__listen_response)

        # Re-raise any error recorded by the response listener
        if self.fail_status:
            raise self.fail_exception

        # Close the sidebar if it is open
        sider_bar_element = self.browser_page.locator("//button[@data-testid='siderbar_close_btn']")
        if await sider_bar_element.is_visible():
            await sider_bar_element.click()

        self.ai_answer.search_result = await self._collect_search_results()
        await self._capture_answer_screenshot()
        return self.ai_answer

    async def _enable_deep_think(self) -> None:
        """Click the "deep thinking" toggle when it is visible and not already active."""
        think_btn = self.browser_page.locator("//button[@title='深度思考']")
        if not await think_btn.is_visible():
            return
        # The active state is only exposed through a CSS class starting with
        # "active-"; get_attribute may return None when there is no class.
        class_attr = await think_btn.get_attribute('class') or ''
        if not any(name.startswith('active-') for name in class_attr.split()):
            await think_btn.click()
            await asyncio.sleep(2)

    async def _collect_search_results(self) -> list[AiSearchResult]:
        """Open the references panel if needed and scrape every search-result card.

        A card's ``is_referenced`` is ``"1"`` when its 1-based position is in
        ``self.index_data`` (the citation markers parsed from the stream),
        otherwise ``"0"``.
        """
        # References popup panel
        search_result_popup_element = self.browser_page.locator("//div[contains(@class, 'search-item-transition-')]")
        # Buttons that open the references panel
        search_result_btn_list = self.browser_page.locator("//div[contains(@class, 'entry-btn-')]")
        if await search_result_btn_list.count() > 0 and await search_result_popup_element.count() == 0:
            # Use the last button (presumably the one for the newest answer)
            await search_result_btn_list.nth(-1).click()
            await asyncio.sleep(2)

        referenced = set(self.index_data)
        ai_search_result_list = []
        cards = await self.browser_page.locator("//a[contains(@class, 'search-')]").all()
        for position, card in enumerate(cards, start=1):
            url = await card.get_attribute('href')
            title = await self._visible_text(card, "xpath=.//div[contains(@class, 'search-item-title-')]")
            desc = await self._visible_text(card, "xpath=.//div[contains(@class, 'search-item-summary-')]")
            host_name = await self._visible_text(card, "xpath=.//span[contains(@class, 'footer-title-')]")
            ai_search_result_list.append(AiSearchResult(
                title=title,
                url=url,
                host_name=host_name,
                body=desc,
                is_referenced="1" if position in referenced else "0",
            ))
            logger.debug(f'搜索结果: [{host_name}]{title}({url})')
        return ai_search_result_list

    @staticmethod
    async def _visible_text(parent, selector: str) -> str:
        """Return the inner text of ``selector`` under ``parent``, or '' if it is not visible."""
        element = parent.locator(selector)
        if await element.is_visible():
            return await element.inner_text()
        return ''

    async def _capture_answer_screenshot(self) -> None:
        """Size the viewport to the latest answer and save a full-page screenshot.

        Stores the screenshot path in ``self.ai_answer.screenshot_file``.
        """
        answer_element = self.browser_page.locator("//div[@data-testid='receive_message']").nth(-1)
        box = await answer_element.bounding_box()
        logger.debug(f'answer_element: {box}')
        if box is None:
            # bounding_box() returns None for a hidden element; keep the
            # current viewport instead of crashing on box['height'].
            logger.warning(f'{self.get_platform_name()} answer element is not visible, keeping the current viewport size')
        else:
            # Make the viewport 500px taller than the answer element
            await self.browser_page.set_viewport_size({
                'width': 1920,
                'height': int(box['height'] + 500),
            })
        screenshot_path = self._get_screenshot_path()
        await self.browser_page.screenshot(path=screenshot_path, full_page=True)
        self.ai_answer.screenshot_file = screenshot_path

    def handle_listen_response_error(self, func):
        """Decorator that records exceptions raised inside a response callback.

        On failure the exception is logged, ``fail_status`` is set to True,
        the exception is stored in ``fail_exception`` and ``completed_event``
        is set so the coroutine waiting on it resumes and re-raises.

        :param func: async callback to wrap
        :return: wrapped async callback
        """

        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"{self.get_platform_name()}响应异常: {e}", exc_info=True)
                # Mark the run as failed and remember the exception
                self.fail_status = True
                self.fail_exception = e
                self.completed_event.set()

        return wrapper

    @staticmethod
    def _citation_index(meta_infos):
        """Return the first meta-info entry's ``info.insert_text`` as an int.

        Returns None when the entry is missing or malformed, or when the
        marker is not a plain decimal number.
        """
        try:
            insert_text = meta_infos[0]['info']['insert_text']
        except (IndexError, KeyError, TypeError):
            return None
        marker = str(insert_text)
        # isdecimal (not isdigit) so int() cannot fail on characters like '²'
        return int(marker) if marker.isdecimal() else None

    async def __listen_response(self, response):
        """Parse the streamed chat-completion response captured by Playwright.

        Responses whose URL does not contain ``/samantha/chat/completion`` are
        ignored. For the completion response the whole body is read, split
        into blank-line-separated chunks, and each ``data: `` payload is
        decoded with ``parse_nested_json``. ``message.content.text`` fragments
        are joined into ``self.ai_answer.answer``; numeric ``insert_text``
        citation markers are stored as ints in ``self.index_data``. Finally
        ``completed_event`` is set.

        :param response: Playwright response object
        """
        if '/samantha/chat/completion' not in response.url:
            return

        import re
        # 1-10 CJK ideographs wrapped in double quotes. Their quotes are
        # stripped before decoding — presumably because such quotes appear
        # unescaped inside the nested JSON and break parsing; TODO confirm.
        quoted_cjk = re.compile(r'"[\u4e00-\u9fa5]{1,10}"')

        # Read the streamed body once, then normalise it with ftfy
        raw_text = await response.text()
        logger.debug(f"await data: {raw_text}")
        response_text = ftfy.fix_text(raw_text)
        logger.debug(f"response_text: {response_text}")

        answer_parts = []
        citations = set()
        for chunk in response_text.split("\n\n"):
            if not chunk.startswith('data: '):
                continue
            line = chunk[6:]
            logger.debug(f"line_1: {line}")
            for quoted in quoted_cjk.findall(line):
                line = line.replace(quoted, quoted[1:-1])
            try:
                data = parse_nested_json(line)
            except JSONDecodeError:
                continue
            logger.debug(f"data: {data}")
            if not isinstance(data, dict):
                continue
            event_data = data.get('event_data', {})

            text = glom(event_data, 'message.content.text', default=None)
            if text is not None:
                answer_parts.append(str(text))

            meta_infos = glom(event_data, 'message.content.meta_infos', default=None)
            if meta_infos:
                citation = self._citation_index(meta_infos)
                if citation is not None:
                    logger.debug(f"index: {citation}")
                    citations.add(citation)

        answer = ''.join(answer_parts)
        logger.debug(f"ai回复: {answer}")
        # Ints, so they compare equal to the 1-based card positions used in
        # _collect_search_results.
        self.index_data = sorted(citations)
        self.ai_answer.answer = answer
        self.completed_event.set()

    def get_platform_id(self) -> int:
        return 5

    def get_platform_name(self) -> str:
        return 'DouBao'
|