308 lines
12 KiB
Python
308 lines
12 KiB
Python
import time
|
||
import os
|
||
import json
|
||
import requests
|
||
import re
|
||
import threading
|
||
from queue import Queue
|
||
|
||
# Every action dispatched by the agent routes is recorded here as an
# (action, args) tuple. NOTE(review): no consumer of this queue is visible
# in this file — confirm one exists elsewhere, otherwise it only grows.
task_queue = Queue()

# Connection settings for the LLM chat endpoint. They start empty and are
# filled in at runtime via POST /agent_config; llm_call() short-circuits to a
# "stop" action while any of them is empty.
llm_config = dict(
    api_url="",
    api_key="",
    model="",
)
|
||
|
||
def load_skills_md():
    """Return the contents of ``agent/skills.md``, or ``""`` if it cannot be read.

    The path is relative to the current working directory. Any error raised
    while opening or decoding the file is printed and swallowed, so callers
    always receive a string.
    """
    skills_path = 'agent/skills.md'
    try:
        with open(skills_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except Exception as err:
        print(f"加载技能文档出错: {err}")
        content = ""
    return content
|
||
|
||
def _resolve_chat_url(raw_url):
    """Return the endpoint URL that ``llm_call`` actually POSTs to.

    The configured URL is stripped of surrounding whitespace. iFlytek MaaS
    URLs (containing ``maas-api.cn``) are used verbatim; any other URL gets
    ``/chat/completions`` appended unless it already ends with it.
    """
    url = raw_url.strip()
    if "maas-api.cn" in url or url.endswith("/chat/completions"):
        return url
    return url + ("chat/completions" if url.endswith("/") else "/chat/completions")


def _build_auth_headers(api_key):
    """Build JSON request headers, picking the auth scheme from *api_key*'s format.

    A key containing ``:`` is treated as ``client_id:client_secret`` (iFlytek
    MaaS style) and sent as HTTP Basic auth; anything else is sent as an
    OpenAI-style Bearer token.
    """
    import base64

    headers = {"Content-Type": "application/json"}
    if ":" in api_key:
        token = base64.b64encode(api_key.encode()).decode()
        headers["Authorization"] = f"Basic {token}"
    else:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers


def _build_chat_payload(model, prompt, is_maas):
    """Build the chat-completions request body for *prompt*.

    Both API flavours share the model/messages shape; MaaS requests
    additionally pin ``temperature`` and ``max_tokens``.
    """
    data = {
        "model": model,
        "messages": [
            {"role": "system", "content": "你是一个智能垃圾桶控制助手,请根据用户输入返回对应的动作指令。"},
            {"role": "user", "content": prompt},
        ],
    }
    if is_maas:
        data["temperature"] = 0.7
        data["max_tokens"] = 500
    return data


def _first_choice_content(result):
    """Return ``choices[0]["message"]["content"]`` from a decoded response, or None.

    Returns None when *result* is not a dict or has no (or empty) ``choices``.
    A malformed first choice raises KeyError/TypeError/IndexError for the
    caller to handle.
    """
    if not isinstance(result, dict):
        return None
    choices = result.get("choices")
    if not choices:
        return None
    return choices[0]["message"]["content"]


def llm_call(prompt):
    """Ask the configured chat-completions endpoint to turn *prompt* into an action.

    Args:
        prompt: Text sent as the user message (agent_chat passes the skills
            document followed by the user's input).

    Returns:
        The result of ``safe_parse`` applied to the first choice's message
        content. Any failure yields ``{"action": "stop", "args": {}}``: missing
        configuration, a non-200 status, a non-JSON body, an unexpected
        response shape, or a network error.
    """
    if not llm_config["api_url"] or not llm_config["api_key"] or not llm_config["model"]:
        return {"action": "stop", "args": {}}

    try:
        api_url = _resolve_chat_url(llm_config["api_url"])
        is_maas = "maas-api.cn" in api_url
        headers = _build_auth_headers(llm_config["api_key"])
        data = _build_chat_payload(llm_config["model"], prompt, is_maas)

        print(f"调用LLM API: {api_url}")
        print(f"请求数据: {json.dumps(data, ensure_ascii=False)}")

        response = requests.post(api_url, headers=headers, json=data, timeout=15)
        print(f"API响应状态码: {response.status_code}")
        print(f"API响应内容: {response.text}")

        if response.status_code != 200:
            print(f"API调用失败,状态码: {response.status_code}")
            return {"action": "stop", "args": {}}

        try:
            result = response.json()
        except ValueError as e:
            # requests may raise its own JSONDecodeError (simplejson-based when
            # simplejson is installed), which is not a json.JSONDecodeError;
            # every variant derives from ValueError.
            print(f"JSON解析错误: {e}")
            return {"action": "stop", "args": {}}

        # MaaS and OpenAI-compatible endpoints share the same response shape.
        content = _first_choice_content(result)
        if content is None:
            return {"action": "stop", "args": {}}
        return safe_parse(content)
    except Exception as e:
        # Boundary handler: the robot must fall back to "stop" on any failure.
        print(f"LLM调用出错: {e}")
        return {"action": "stop", "args": {}}
|
||
|
||
def safe_parse(result):
    """Extract an action dict (or a list of action dicts) from an LLM reply.

    Accepted forms, in order of preference:
      * an already-parsed ``dict`` or ``list``, which is returned unchanged;
      * a reply whose whole text is a JSON object or array, optionally wrapped
        in a Markdown code fence such as ```` ```json ````;
      * the first JSON object, or non-empty array of objects, embedded in
        surrounding prose. Multi-line JSON and nested ``args`` objects are
        handled.

    Anything else yields the safe default ``{"action": "stop", "args": {}}``:
    non-string input, an empty code fence, or no parseable JSON.
    """
    if isinstance(result, (dict, list)):
        return result
    if not isinstance(result, str):
        return {"action": "stop", "args": {}}

    text = result.strip()
    if text.startswith("```"):
        # Skip the opening fence line(s), e.g. "```json", then drop the closing fence.
        lines = text.split("\n")
        for i, line in enumerate(lines):
            if not line.startswith("```") and line.strip():
                text = "\n".join(lines[i:])
                break
        if text.startswith("```"):
            # The fence contained nothing but fence lines.
            return {"action": "stop", "args": {}}
        if text.endswith("```"):
            text = text[:-3].strip()

    # The whole reply is JSON. Objects are accepted as well as arrays, so
    # pretty-printed multi-line objects are recognised.
    try:
        parsed = json.loads(text)
    except ValueError:
        parsed = None
    if isinstance(parsed, (dict, list)):
        return parsed

    # JSON embedded in prose: try to decode at every "{" / "[" in order.
    # raw_decode honours nesting, so objects with nested "args" survive.
    decoder = json.JSONDecoder()
    for match in re.finditer(r"[{\[]", text):
        try:
            candidate, _ = decoder.raw_decode(text, match.start())
        except ValueError:
            continue
        if isinstance(candidate, dict):
            return candidate
        if isinstance(candidate, list) and candidate and all(isinstance(step, dict) for step in candidate):
            return candidate

    return {"action": "stop", "args": {}}
|
||
|
||
# Every action name execute_skill accepts.
_ALLOWED_ACTIONS = frozenset({
    "move_forward", "move_backward", "turn_left", "turn_right",
    "stop", "play_path", "list_paths", "delete_path", "save_path",
})

# Timed movements: action -> (motor method name, console log, reply verb).
# NOTE(review): "move_forward" drives motor_module.backward() and
# "move_backward" drives forward(), exactly as the previous code did —
# presumably the motors are mounted reversed; confirm on the hardware.
_TIMED_MOVES = {
    "move_forward": ("backward", "控制垃圾桶前进", "前进"),
    "move_backward": ("forward", "控制垃圾桶后退", "后退"),
    "turn_left": ("rotate_left", "控制垃圾桶左旋转", "左转"),
    "turn_right": ("rotate_right", "控制垃圾桶右旋转", "右转"),
}

# Path actions that need args["name"]: action -> reply verb.
_PATH_VERBS = {"play_path": "播放", "delete_path": "删除", "save_path": "保存"}

# Pending auto-stop for the most recent timed move (shared by all callers).
_stop_timer = None
_stop_timer_lock = threading.Lock()


def _parse_duration(raw):
    """Convert an LLM-supplied duration to seconds; return None if unusable.

    Numeric strings such as ``"2"`` are accepted. Negative, NaN, infinite,
    overflowing or non-numeric values are rejected, so a bad value can never
    leave the motors running without a scheduled stop.
    """
    try:
        seconds = float(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    # The chained comparison is False for NaN, +/-inf and negative values.
    return seconds if 0 <= seconds < float("inf") else None


def _cancel_pending_stop():
    """Cancel the scheduled auto-stop of the previous timed move, if any."""
    global _stop_timer
    with _stop_timer_lock:
        if _stop_timer is not None:
            _stop_timer.cancel()
            _stop_timer = None


def _schedule_stop(motor_module, seconds):
    """Arrange for ``motor_module.stop()`` to run after *seconds* on a timer thread."""
    global _stop_timer
    timer = threading.Timer(seconds, motor_module.stop)
    with _stop_timer_lock:
        if _stop_timer is not None:
            _stop_timer.cancel()
        _stop_timer = timer
    timer.start()


def execute_skill(action, args, motor_module):
    """Execute one whitelisted robot action and report the outcome.

    Args:
        action: Action name, e.g. ``"move_forward"``.
        args: Action arguments. Timed moves read ``duration`` (seconds, default
            1); path actions read ``name``. ``None`` or any non-dict value is
            treated as ``{}``.
        motor_module: Object exposing ``forward``/``backward``/``rotate_left``/
            ``rotate_right`` (each called with ``speed=0.6``) and ``stop()``.

    Returns:
        A dict with ``status`` (``"success"`` or ``"error"``) and ``message``.
        Timed moves return immediately; a background timer stops the motors
        after ``duration`` seconds. A newer move or a ``stop`` cancels that
        timer, so an earlier command can no longer cut a later one short.
    """
    if action not in _ALLOWED_ACTIONS:
        return {"status": "error", "message": "不允许的动作"}
    if not isinstance(args, dict):
        args = {}

    try:
        if action in _TIMED_MOVES:
            motor_method, log_text, verb = _TIMED_MOVES[action]
            raw_duration = args.get("duration", 1)
            seconds = _parse_duration(raw_duration)
            # Validate before touching the motors: an invalid duration used to
            # crash the stop thread and leave the motors running indefinitely.
            if seconds is None:
                return {"status": "error", "message": "无效的持续时间"}
            print(log_text)
            _cancel_pending_stop()
            getattr(motor_module, motor_method)(speed=0.6)
            _schedule_stop(motor_module, seconds)
            return {"status": "success", "message": f"{verb}{raw_duration}秒"}

        if action == "stop":
            print("停止垃圾桶")
            _cancel_pending_stop()
            motor_module.stop()
            return {"status": "success", "message": "已停止"}

        if action == "list_paths":
            return {"status": "success", "message": "获取轨迹列表"}

        # Remaining actions (play_path / delete_path / save_path) need a path name.
        path_name = args.get("name")
        if not path_name:
            return {"status": "error", "message": "路径名称不能为空"}
        return {"status": "success", "message": f"{_PATH_VERBS[action]}轨迹{path_name}"}
    except Exception as e:
        print(f"执行技能出错: {e}")
        return {"status": "error", "message": f"执行技能出错: {e}"}
|
||
|
||
def _step_duration(step):
    """Return how many seconds to wait for *step* to finish before starting the next one.

    Uses ``step["args"]["duration"]`` with a default of 1, the same default
    execute_skill uses. Missing, non-numeric, negative, NaN or infinite
    values fall back to 1, so one malformed step cannot kill the sequencing
    thread.
    """
    args = step.get("args") if isinstance(step, dict) else None
    raw = args.get("duration", 1) if isinstance(args, dict) else 1
    try:
        seconds = float(raw)
    except (TypeError, ValueError, OverflowError):
        return 1.0
    return seconds if 0 <= seconds < float("inf") else 1.0


def _split_step(step):
    """Return ``(action, args)`` for one parsed LLM step.

    A non-dict step becomes ``("stop", {})``; missing or non-dict ``args``
    become ``{}``.
    """
    if not isinstance(step, dict):
        return "stop", {}
    args = step.get("args", {})
    return step.get("action", "stop"), (args if isinstance(args, dict) else {})


def create_agent_routes(app, motor_module):
    """Register the LLM-agent HTTP endpoints on *app*.

    Routes:
        POST /agent_config -- store ``api_url`` / ``api_key`` / ``model`` in llm_config.
        GET  /agent_config -- return the current llm_config.
        POST /agent_chat   -- send the user's ``text`` through the LLM and run the
                              resulting action(s) on *motor_module*.

    Args:
        app: Flask application, or anything with a compatible ``route`` decorator.
        motor_module: Motor controller forwarded to execute_skill.
    """

    @app.route('/agent_config', methods=['POST'])
    def agent_config():
        from flask import request, jsonify

        # silent=True: a missing or non-JSON body returns None instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'status': 'error', 'message': '请求数据格式错误'})

        api_url = data.get('api_url', '')
        api_key = data.get('api_key', '')
        model = data.get('model', '')

        llm_config['api_url'] = api_url
        llm_config['api_key'] = api_key
        llm_config['model'] = model

        print(f"LLM配置已更新: API={api_url}, Model={model}")
        return jsonify({'status': 'success', 'message': '配置已保存'})

    @app.route('/agent_config', methods=['GET'])
    def get_agent_config():
        from flask import jsonify

        # NOTE(review): this echoes the API key in plain text to any client that
        # can reach the endpoint; consider masking it if the UI does not need it.
        return jsonify({
            'status': 'success',
            'api_url': llm_config['api_url'],
            'api_key': llm_config['api_key'],
            'model': llm_config['model'],
        })

    @app.route('/agent_chat', methods=['POST'])
    def agent_chat():
        from flask import request, jsonify

        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'status': 'error', 'message': '请求数据格式错误'})
        text = data.get('text')
        if not text:
            return jsonify({'status': 'error', 'message': '输入文本不能为空'})

        try:
            prompt = load_skills_md() + "\n用户输入:" + text
            parsed_result = safe_parse(llm_call(prompt))

            if isinstance(parsed_result, list) and parsed_result:
                # Compound command: run the first step now, the rest in sequence.
                steps = parsed_result
                action, args = _split_step(steps[0])
                task_queue.put((action, args))
                execute_skill(action, args, motor_module)

                if len(steps) > 1:
                    def execute_next_actions():
                        previous = steps[0]
                        for step in steps[1:]:
                            # Wait for the PREVIOUS step to finish, plus a 0.5 s buffer.
                            time.sleep(_step_duration(previous) + 0.5)
                            next_action, next_args = _split_step(step)
                            task_queue.put((next_action, next_args))
                            execute_skill(next_action, next_args, motor_module)
                            previous = step

                    threading.Thread(target=execute_next_actions).start()

                return jsonify({"status": "success", "message": f"开始执行复合指令,共{len(steps)}个动作"})

            # Single action. An empty list or other non-dict result becomes a stop.
            action, args = _split_step(parsed_result)
            task_queue.put((action, args))
            response = execute_skill(action, args, motor_module)
            return jsonify(response)
        except Exception as e:
            print(f"Agent聊天出错: {e}")
            return jsonify({'status': 'error', 'message': f'Agent聊天出错: {e}'})