def route_tools(state: State):
    """Pick the next graph edge based on whether the last message requests tools.

    Args:
        state: Either a list of messages, or a mapping that holds the
            message list under the ``"messages"`` key.

    Returns:
        ``"tools"`` when the last message carries at least one tool call,
        otherwise ``END``.

    Raises:
        ValueError: If ``state`` is not a list and contains no messages.
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    # Truthiness covers a missing attribute, ``None`` and an empty list alike,
    # so a message whose ``tool_calls`` is ``None`` no longer breaks ``len()``.
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return END
# Function-calling mode: decide whether the agent loop should keep running.
# NOTE(review): both variants below are named ``should_continue``; if they
# live in the same scope the later definition wins — presumably only one is
# defined per agent type; confirm against the full file.
def should_continue(messages):
    """Return ``'continue'`` if the last message requests a tool call, else ``'end'``.

    ``system_message`` is read from the enclosing scope (not shown here).
    When it contains ``'|<instruct>|'`` and the reply matches the
    ``Answer: ... Grounded answer`` layout, the last message's ``content``
    is replaced in place with the captured answer text before finishing.
    """
    last_message = messages[-1]
    # If there is a function call, we continue
    if 'tool_calls' in last_message.additional_kwargs:
        return 'continue'

    # Otherwise there is no function call, so we finish
    if '|<instruct>|' in system_message:
        # cohere model
        pattern = r'Answer:(.+)\nGrounded answer'
        match = re.search(pattern, last_message.content)
        if match:
            last_message.content = match.group(1)
    return 'end'


# ========================================
# For models without function-calling support, reason with ReAct and decide
# whether to keep running.
def should_continue(data):
    """Return ``'end'`` once ``data['agent_outcome']`` is an ``AgentFinish``, else ``'continue'``."""
    if isinstance(data['agent_outcome'], AgentFinish):
        return 'end'
    else:
        return 'continue'


# ========================================
# Assemble the tools and the agent node: bind the tools (converted to the
# OpenAI tool format) to the LLM only when any tools are configured.
if tools:
    llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
else:
    llm_with_tools = llm
# Maps node_id -> node instance.
self.nodes_map = {}
# For each node, the nodes that fan in to it: node_id -> [node_ids].
self.nodes_fan_in = {}
# For each node, the nodes that come next after it: node_id -> {node_ids}.
self.nodes_next_nodes = {}
def get_all_variables(self) -> Dict[str, Any]:
    """Return every variable in the pool, keyed as ``"<node_id>.<key>"``.

    Each value is resolved through ``self.get_variable(node_id, key)``.
    A ``preset_question`` entry whose raw value is a dict is additionally
    flattened into ``"<node_id>.preset_question#<sub_key>"`` entries.

    Returns:
        A new dict; ``self.variables_pool`` is not modified.
    """
    ret = {}
    for node_id, node_variables in self.variables_pool.items():
        for key, value in node_variables.items():
            ret[f'{node_id}.{key}'] = self.get_variable(node_id, key)
            # Special handling for the preset_question key: expose each
            # sub-entry too. Non-dict values (e.g. None) are skipped instead
            # of failing on ``.items()``.
            if key == 'preset_question' and isinstance(value, dict):
                for k, v in value.items():
                    ret[f'{node_id}.{key}#{k}'] = v
    return ret
async def handle_user_input(self, data: dict):
    """Forward a user's input to the workflow if it is waiting for input.

    Args:
        data: Mapping of ``node_id`` to a dict with a required ``'data'``
            entry and optional ``'message_id'`` / ``'message'`` entries.
            Only the first entry is used.

    If the workflow status is not ``WorkflowStatus.INPUT`` a warning is
    logged and nothing else happens.
    """
    status_info = self.workflow.get_workflow_status(user_cache=False)
    if status_info['status'] != WorkflowStatus.INPUT.value:
        logger.warning(f'workflow is not input status: {status_info}')
        return

    user_input = {}
    message_id = None
    new_message = None
    # Only a single input node is supported for now, so take the first entry.
    first_entry = next(iter(data.items()), None)
    if first_entry is not None:
        node_id, node_info = first_entry
        user_input[node_id] = node_info['data']
        message_id = node_info.get('message_id')
        new_message = node_info.get('message')

    self.workflow.set_user_input(user_input, message_id=message_id, message_content=new_message)
    self.workflow.set_workflow_status(WorkflowStatus.INPUT_OVER.value)


def set_user_input(self, data: dict, message_id: int | None = None, message_content: str | None = None):
    """Store the user's input in Redis for the asynchronous workflow task.

    Args:
        data: The user input payload to store.
        message_id: Optional id of an existing chat message; when given
            together with a truthy ``self.chat_id`` and the message exists,
            it is passed to ``self.update_old_message``.
        message_content: Optional content forwarded to
            ``self.update_old_message``.
    """
    if self.chat_id and message_id:
        message_db = ChatMessageDao.get_message_by_id(message_id)
        if message_db:
            self.update_old_message(data, message_db, message_content)
    # Notify the asynchronous task that user input has arrived.
    self.redis_client.set(self.workflow_input_key, data, expiration=self.workflow_expire_time)
使用输入:
1 2 3 4 5 6 7 8 9 10 11
def get_user_input(self) -> dict | None:
    """Fetch the pending user input from Redis and consume it.

    Returns:
        Whatever ``self.redis_client.get`` returns for
        ``self.workflow_input_key``.
    """
    ret = self.redis_client.get(self.workflow_input_key)
    # Consume the key whenever a value was found, including an empty payload
    # such as {}; a truthiness check would leave it behind to be re-read.
    if ret is not None:
        self.redis_client.delete(self.workflow_input_key)
    return ret
def_execute_workflow(unique_id: str, workflow_id: str, chat_id: str, user_id: str): redis_callback = RedisCallback(unique_id, workflow_id, chat_id, user_id) try: # update workflow status redis_callback.set_workflow_status(WorkflowStatus.RUNNING.value) # get workflow data workflow_data = redis_callback.get_workflow_data() ifnot workflow_data: raise Exception('workflow data not found maybe data is expired')