You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

176 lines
7.1 KiB

# coding=utf-8
import asyncio
import re
from functools import partial, wraps
from json import JSONDecodeError
from glom import glom
from playwright.async_api import Browser
from abs_spider import AbstractAiSeoSpider
from domain.ai_seo import AiAnswer, AiSearchResult
from utils import create_logger, parse_nested_json
from utils.image_utils import crop_image_left
# Module-level logger, created through the project's create_logger helper
# and keyed by this module's name.
logger = create_logger(__name__)
class TongyiSpider(AbstractAiSeoSpider):
    """Spider that submits a prompt to Tongyi and records the reply.

    The flow is: open the home page, toggle the mode button selected by
    ``think``, type the prompt, capture the ``/dialog/conversation`` response
    to extract the answer text and its search references, then take a
    screenshot of the page.
    """

    def __init__(self, browser: Browser, prompt: str, keyword: str, think: bool = False):
        super().__init__(browser, prompt, keyword, think)
        # Wrap the response callback so that any exception raised inside it is
        # recorded on the spider instead of being lost in the event handler.
        self.__listen_response = self.handle_listen_response_error(self.__listen_response)

    def get_home_url(self) -> str:
        """Return the URL of the Tongyi home page."""
        return 'https://tongyi.aliyun.com'

    async def _do_spider(self) -> AiAnswer:
        """Run one prompt through the Tongyi web UI.

        :return: ``self.ai_answer`` with ``screenshot_file`` set; the answer
            and search results are filled in by the response listener.
        :raises Exception: re-raises ``self.fail_exception`` when the response
            listener recorded a failure.
        """
        # Initialise per-run state.
        self._init_data()
        await self.browser_page.goto(self.get_home_url(), timeout=600000)

        # Toggle the "deep thinking" button in think mode, otherwise the
        # "web search" button, when it is visible.
        mode_selector = "div:text('深度思考')" if self.think else "div:text('联网搜索')"
        mode_btn = self.browser_page.locator(mode_selector)
        if await mode_btn.is_visible():
            await mode_btn.click()
            await asyncio.sleep(1)

        # Focus the chat textarea (matched by its ant-input class) and type the prompt.
        chat_input_element = self.browser_page.locator("//textarea[contains(@class, 'ant-input')]")
        await chat_input_element.click()
        await self.browser_page.keyboard.type(self.prompt)
        await asyncio.sleep(2)

        # Attach the response listener BEFORE submitting, so a fast
        # conversation response cannot arrive while nobody is listening
        # (which would leave completed_event unset forever).
        self.browser_page.on('response', self.__listen_response)
        await self.browser_page.keyboard.press('Enter')
        await asyncio.sleep(2)
        await self.completed_event.wait()

        # Surface any error recorded by the response listener.
        if self.fail_status:
            raise self.fail_exception

        # Measure the last answer element and grow the viewport to fit it.
        answer_element = self.browser_page.locator("//div[contains(@class, 'answerItem')]").nth(-1)
        box = await answer_element.bounding_box()
        logger.debug(f'answer_element: {box}')
        if box:
            await self.browser_page.set_viewport_size({
                'width': 1920,
                'height': int(box['height'] + 500),
            })
        else:
            # bounding_box() yields None when the element is not visible;
            # keep the current viewport rather than failing the whole run.
            logger.warning('answer element has no bounding box; viewport left unchanged')

        # Open the search-result list of the last answer, if present.
        search_list_element = self.browser_page.locator("//div[contains(@class, 'linkTitle')]").nth(-1)
        if await search_list_element.is_visible():
            await search_list_element.click()
            await asyncio.sleep(2)

        # Close the sidebar, if present.
        side_console_element = self.browser_page.locator("//span[contains(@class, 'sc-frniUE')]")
        if await side_console_element.is_visible():
            await side_console_element.click()

        # Take the screenshot and crop it.
        screenshot_path = self._get_screenshot_path()
        await self.browser_page.screenshot(path=screenshot_path)
        # NOTE(review): presumably trims 340 px from the left edge (the
        # sidebar area) -- confirm against crop_image_left.
        crop_image_left(screenshot_path, 340)
        self.ai_answer.screenshot_file = screenshot_path
        return self.ai_answer

    async def __listen_response(self, response):
        """Extract the answer and search references from the conversation response.

        Responses whose URL does not contain ``/dialog/conversation`` are
        ignored. The body is read as newline-separated event lines; the last
        line that parses successfully is used as the result. Always sets
        ``completed_event`` once a matching response has been processed.

        :param response: the response object passed by the page's
            ``response`` event.
        """
        if '/dialog/conversation' not in response.url:
            return

        # Read the streamed body and keep the last line that parses.
        raw_body = await response.body()
        data = {}
        for raw_line in raw_body.decode('utf-8').split("\n"):
            line = raw_line.strip()
            # Strip only the leading field name; a global replace would also
            # mangle any 'data: ' text inside the JSON payload itself.
            if line.startswith('data:'):
                line = line[len('data:'):].strip()
            if not line or line == '[DONE]':
                continue
            try:
                data = parse_nested_json(line)
            except JSONDecodeError:
                continue
        logger.debug(f"结果: {data}")

        # Walk the content parts: 'plugin' carries search results, 'text'
        # carries the answer itself.
        search_result_list = []
        for content in data.get('contents', []) or []:
            content_type = content.get('contentType', '')
            if content_type == 'plugin':
                logger.debug("获取到联网搜索结果")
                if self.think:
                    result_path = 'content.pluginResult'
                else:
                    result_path = 'content.pluginResult.-1.search_results'
                # glom's default only covers a missing path; coalesce an
                # explicit null to an empty list as well.
                search_result_list = glom(content, result_path, default=[]) or []
            elif content_type == 'text':
                logger.debug('获取到ai回复结果')
                answer = content.get('content', '')
                logger.debug(f"ai回复: {answer}")
                self.ai_answer.answer = answer

        # Reference numbers cited in the answer text, kept as strings.
        referenced_indexes = set(re.findall(r'ty-reference]\((\d+)\)', self.ai_answer.answer or ''))

        # Build the search-result records; a result counts as referenced when
        # its 1-based position appears among the cited numbers.
        ai_search_result_list = []
        for position, search_result in enumerate(search_result_list, start=1):
            url = search_result.get('url', '')
            title = search_result.get('title', '')
            body = search_result.get('body', '')
            host_name = search_result.get('host_name', '未知')
            publish_time = search_result.get('time', 0)
            is_referenced = "1" if str(position) in referenced_indexes else "0"
            logger.debug(f"ai参考资料: [{host_name}]{title}({url})")
            ai_search_result_list.append(
                AiSearchResult(
                    title=title,
                    url=url,
                    body=body,
                    host_name=host_name,
                    publish_time=publish_time,
                    is_referenced=is_referenced,
                )
            )
        if ai_search_result_list:
            self.ai_answer.search_result = ai_search_result_list
        self.completed_event.set()

    def handle_listen_response_error(self, func):
        """Decorator that handles exceptions raised inside the response callback.

        On any exception the error is logged, ``fail_status`` and
        ``fail_exception`` are recorded, and ``completed_event`` is set so the
        waiting spider can resume and re-raise.

        :param func: the async callback to wrap.
        :return: the wrapped async callback.
        """
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"{self.get_platform_name()}响应异常: {e}", exc_info=True)
                # Mark the run as failed and keep the exception for the caller.
                self.fail_status = True
                self.fail_exception = e
                self.completed_event.set()
        return wrapper

    def get_platform_id(self) -> int:
        """Return the numeric identifier of this platform."""
        return 2

    def get_platform_name(self) -> str:
        """Return the display name of this platform."""
        return 'TongYi'