清华开源ChatGPT自动编程ChatDev项目chat_chain.py解读

  • ChatChain类是一个用于实现软件开发的多智能体协作系统,它可以根据用户的自然语言描述来创建定制的软件。
  • __init__()方法是类的构造函数,它接受以下几个参数:
    • config_path: 一个字符串,表示ChatChainConfig.json文件的路径,这个文件包含了ChatChain的基本配置信息,例如智能体链、招聘条件、是否清理结构、是否进行头脑风暴等。
    • config_phase_path: 一个字符串,表示PhaseConfig.json文件的路径,这个文件包含了软件开发的各个阶段的配置信息,例如阶段名称、阶段目标、阶段角色、阶段限制等。
    • config_role_path: 一个字符串,表示RoleConfig.json文件的路径,这个文件包含了软件开发的各个角色的配置信息,例如角色名称、角色描述、角色提示等。
    • task_prompt: 一个字符串,表示用户输入的软件想法,例如“我想要一个五子棋游戏”。
    • project_name: 一个字符串,表示用户输入的软件名称,例如“Gomoku”。
    • org_name: 一个字符串,表示用户所属的组织名称,例如“OpenBMB”。
    • model_type: 一个枚举类型,表示使用的大型语言模型(LLM)的类型,例如ModelType.GPT_3_5_TURBO。
  • __init__()方法首先将这些参数保存在类的属性中,然后使用open()函数和json模块来打开和解析这些配置文件,并将配置信息保存在类的属性中。接着它根据配置信息初始化了ChatChain的智能体链和招聘条件,并设置了默认的最大对话轮数为10。然后它根据配置信息创建了一个ChatEnvConfig对象和一个ChatEnv对象,用于管理软件开发的环境和资源。接下来它将用户输入的软件想法保存在类的属性中,并根据配置信息决定是否对其进行自我改进(这个过程在类的preprocess()方法中实现)。然后它根据配置信息创建了一个字典对象self.role_prompts,用于存储各个角色的提示信息。最后它调用类的get_logfilepath()方法来获取日志文件的路径,并将其保存在类的属性中。
  • check_bool()函数是一个辅助函数,用于将字符串转换为布尔值。它接受一个参数s,表示一个字符串。它将s转换为小写,并判断是否等于”true”。如果是,则返回True,否则返回False。
  • get_logfilepath()方法用于获取日志文件的路径,并返回一个元组(start_time, log_filepath),其中start_time表示开始时间,log_filepath表示日志文件路径。这个方法首先使用datetime模块来获取当前时间,并将其格式化为”%Y-%m-%d %H:%M:%S”形式,并赋值给start_time变量。然后它使用os模块来获取当前工作目录,并将其与”logs”和start_time拼接起来,得到log_filepath变量。最后它返回(start_time, log_filepath)元组。
  • make_recruitment()方法用于招聘合适的智能体,并将他们加入到ChatEnv对象中。它遍历配置信息中定义的招聘条件(self.recruitments属性),对于每个条件(即智能体名称),它调用ChatEnv对象的recruit()方法,将智能体名称作为参数传递,表示招聘该智能体。
  • execute_step()方法用于执行单个软件开发阶段,它接受一个参数phase_item,表示配置信息中定义的单个阶段信息。它首先获取阶段的名称、类型等信息,并根据不同的类型来执行不同的操作。如果阶段类型是”SimplePhase”,则表示这是一个简单的阶段,它会从self.phases属性中获取相应的SimplePhase对象,并调用其execute()方法,将ChatEnv对象、最大对话轮数、是否需要反思等参数传递给该方法,并将返回的新的ChatEnv对象赋值给self.chat_env属性。如果阶段类型是”ComposedPhase”,则表示这是一个复合的阶段,它会从self.compose_phase_module模块中获取相应的ComposedPhase类,并创建一个ComposedPhase对象,将阶段名称、循环次数、组成部分、配置信息、模型类型、日志文件路径等参数传递给其构造函数,并将该对象赋值给compose_phase_instance变量。然后它调用compose_phase_instance对象的execute()方法,将ChatEnv对象作为参数传递,并将返回的新的ChatEnv对象赋值给self.chat_env属性。如果阶段类型是其他类型,则抛出一个异常,表示未实现该类型。
  • execute_chain()方法用于执行整个软件开发过程,它遍历配置信息中定义的智能体链(self.chain属性),对于每个阶段信息,它调用execute_step()方法来执行该阶段。
  • get_logfilepath()方法用于获取日志文件的路径,并返回一个元组(start_time, log_filepath),其中start_time表示开始时间,log_filepath表示日志文件路径。这个方法首先使用datetime模块来获取当前时间,并将其格式化为”%Y-%m-%d %H:%M:%S”形式,并赋值给start_time变量。然后它使用os模块来获取当前工作目录,并将其与”logs”和start_time拼接起来,得到log_filepath变量。最后它返回(start_time, log_filepath)元组。
  • pre_processing()方法用于进行预处理,例如删除无用的文件和记录一些全局的配置信息。它不接受任何参数,也不返回任何值。它首先判断ChatEnv对象的配置信息中是否需要清理结构(self.chat_env.config.clear_structure属性),如果是,则使用os模块来遍历WareHouse目录中的所有文件,并删除除了.py和.log以外的文件,并打印出删除的文件路径。然后它获取软件保存的目录(由项目名称、组织名称和开始时间拼接而成),并调用ChatEnv对象的set_directory()方法,将该目录作为参数传递,表示设置该目录为软件目录。接着它使用shutil模块的copy()函数来将配置文件复制到软件目录中,并使用open()函数和write()方法来将用户输入的软件想法写入到软件目录中的一个.prompt文件中。然后它创建一个字符串preprocess_msg,并赋值为”[Preprocessing]\n\n”,表示开始预处理。接着它创建一个ChatGPTConfig对象,并将其赋值给chat_gpt_config变量,表示LLM的配置信息。然后它在preprocess_msg字符串后面追加一些信息,例如开始时间、配置文件路径、软件想法、项目名称、日志文件路径、ChatDevConfig对象、ChatGPTConfig对象等,并使用log_and_print_online()函数来将preprocess_msg字符串记录到日志文件中,并打印出来。最后它判断配置信息中是否需要进行自我改进(self.config[‘self_improve’]属性),如果是,则调用self.self_task_improve()方法,将用户输入的软件想法作为参数传递,并将返回的更完善的想法赋值给self.chat_env.env_dict[‘task_prompt’]属性。如果不是,则直接将用户输入的软件想法赋值给self.chat_env.env_dict[‘task_prompt’]属性。
  • post_processing()方法用于进行后处理,例如总结产出和移动日志文件到软件目录中。它不接受任何参数,也不返回任何值。它首先调用ChatEnv对象的write_meta()方法,用于写入元数据信息到软件目录中。然后它使用os模块来获取当前工作目录,并将其赋值给filepath变量。接着它使用os模块来获取当前工作目录的父目录,并将其赋值给root变量。然后它创建一个字符串post_info,并赋值为”[Post Info]\n\n”,表示开始后处理。接着它使用datetime模块来获取当前时间,并将其格式化为”%Y%m%d%H%M%S”形式,并赋值给now_time变量。然后它使用datetime模块和strptime()函数来将开始时间和当前时间转换为datetime对象,并分别赋值给datetime1和datetime2变量。接着它使用total_seconds()方法来计算两个datetime对象之间的差异,并将其赋值给duration变量,表示软件开发所花费的时间。然后它在post_info字符串后面追加”Software Info: {}“.format(get_info(self.chat_env.env_dict[‘directory’], self.log_filepath) + “\n\n🕑duration={:.2f}s\n\n”.format(duration)),表示显示软件的信息和开发时间。接着它在post_info字符串后面追加”ChatDev Starts ({})”.format(self.start_time) + “\n\n”,表示显示开始时间。最后它在post_info字符串后面追加”ChatDev Ends ({})”.format(now_time) + “\n\n”,表示显示结束时间。
  • 这段代码定义了一个self_task_improve()方法,用于对用户输入的软件想法进行自我改进,让LLM更好地理解这些想法。它接受一个参数task_prompt,表示用户输入的软件想法。它返回一个字符串revised_task_prompt,表示经过改进的软件想法。
  • 这个方法首先创建一个字符串self_task_improve_prompt,并赋值为一段提示信息,表示要求用户将一个简短的软件设计需求重写为一个详细的提示,让LLM能够根据这个提示来更好地制作这个软件。这个提示信息中包含了用户输入的软件想法(task_prompt参数),以及一些注意事项,例如提示的长度、格式等。然后它创建一个RolePlaying对象role_play_session,并将其赋值给role_play_session变量,表示一个角色扮演的会话。它将以下几个参数传递给RolePlaying类的构造函数:
    • assistant_role_name: 一个字符串,表示助理的角色名称,为”Prompt Engineer”。
    • assistant_role_prompt: 一个字符串,表示助理的角色描述,为”You are an professional prompt engineer that can improve user input prompt to make LLM better understand these prompts.”。
    • user_role_prompt: 一个字符串,表示用户的角色描述,为”You are an user that want to use LLM to build software.”。
    • user_role_name: 一个字符串,表示用户的角色名称,为”User”。
    • task_type: 一个枚举类型,表示任务类型,为TaskType.CHATDEV。
    • task_prompt: 一个字符串,表示任务描述,为”Do prompt engineering on user query”。
    • with_task_specify: 一个布尔值,表示是否需要指定任务类型,为False。
    • model_type: 一个枚举类型,表示使用的LLM的类型,为self.model_type属性。
  • 接着它调用role_play_session对象的init_chat()方法,将None、None和self_task_improve_prompt作为参数传递,并将返回的两个值赋值给_和input_user_msg变量。这个方法用于初始化角色扮演的会话,并返回助理和用户的第一轮对话。其中input_user_msg变量表示用户输入的重写后的软件想法。然后它调用role_play_session对象的step()方法,将input_user_msg和True作为参数传递,并将返回的两个值赋值给assistant_response和user_response变量。这个方法用于进行角色扮演的一步对话,并返回助理和用户的回复。其中assistant_response变量表示助理回复的内容。接着它使用split()方法和strip()方法来从助理回复中提取出改进后的软件想法,并将其赋值给revised_task_prompt变量。然后它调用log_and_print_online()函数来记录并打印助理回复的内容。最后它调用log_and_print_online()函数来记录并打印原始和改进后的软件想法,并返回revised_task_prompt变量。

chat_chain.py的源代码如下:

import importlib
import json
import os
import shutil
from datetime import datetime
import logging
import time

from camel.agents import RolePlaying
from camel.configs import ChatGPTConfig
from camel.typing import TaskType, ModelType
from chatdev.chat_env import ChatEnv, ChatEnvConfig
from chatdev.statistics import get_info
from chatdev.utils import log_and_print_online, now


def check_bool(s):
    return s.lower() == "true"


class ChatChain:

    def __init__(self,
                 config_path: str = None,
                 config_phase_path: str = None,
                 config_role_path: str = None,
                 task_prompt: str = None,
                 project_name: str = None,
                 org_name: str = None,
                 model_type: ModelType = ModelType.GPT_3_5_TURBO) -> None:
        """

        Args:
            config_path: path to the ChatChainConfig.json
            config_phase_path: path to the PhaseConfig.json
            config_role_path: path to the RoleConfig.json
            task_prompt: the user input prompt for software
            project_name: the user input name for software
            org_name: the organization name of the human user
        """

        # load config file
        self.config_path = config_path
        self.config_phase_path = config_phase_path
        self.config_role_path = config_role_path
        self.project_name = project_name
        self.org_name = org_name
        self.model_type = model_type

        with open(self.config_path, 'r', encoding="utf8") as file:
            self.config = json.load(file)
        with open(self.config_phase_path, 'r', encoding="utf8") as file:
            self.config_phase = json.load(file)
        with open(self.config_role_path, 'r', encoding="utf8") as file:
            self.config_role = json.load(file)

        # init chatchain config and recruitments
        self.chain = self.config["chain"]
        self.recruitments = self.config["recruitments"]

        # init default max chat turn
        self.chat_turn_limit_default = 10

        # init ChatEnv
        self.chat_env_config = ChatEnvConfig(clear_structure=check_bool(self.config["clear_structure"]),
                                             brainstorming=check_bool(self.config["brainstorming"]),
                                             gui_design=check_bool(self.config["gui_design"]),
                                             git_management=check_bool(self.config["git_management"]))
        self.chat_env = ChatEnv(self.chat_env_config)

        # the user input prompt will be self-improved (if set "self_improve": "True" in ChatChainConfig.json)
        # the self-improvement is done in self.preprocess
        self.task_prompt_raw = task_prompt
        self.task_prompt = ""

        # init role prompts
        self.role_prompts = dict()
        for role in self.config_role:
            self.role_prompts[role] = "\n".join(self.config_role[role])

        # init log
        self.start_time, self.log_filepath = self.get_logfilepath()

        # init SimplePhase instances
        # import all used phases in PhaseConfig.json from chatdev.phase
        # note that in PhaseConfig.json there only exist SimplePhases
        # ComposedPhases are defined in ChatChainConfig.json and will be imported in self.execute_step
        self.compose_phase_module = importlib.import_module("chatdev.composed_phase")
        self.phase_module = importlib.import_module("chatdev.phase")
        self.phases = dict()
        for phase in self.config_phase:
            assistant_role_name = self.config_phase[phase]['assistant_role_name']
            user_role_name = self.config_phase[phase]['user_role_name']
            phase_prompt = "\n\n".join(self.config_phase[phase]['phase_prompt'])
            phase_class = getattr(self.phase_module, phase)
            phase_instance = phase_class(assistant_role_name=assistant_role_name,
                                         user_role_name=user_role_name,
                                         phase_prompt=phase_prompt,
                                         role_prompts=self.role_prompts,
                                         phase_name=phase,
                                         model_type=self.model_type,
                                         log_filepath=self.log_filepath)
            self.phases[phase] = phase_instance



    def make_recruitment(self):
        """
        recruit all employees
        Returns: None

        """
        for employee in self.recruitments:
            self.chat_env.recruit(agent_name=employee)

    def execute_step(self, phase_item: dict):
        """
        execute single phase in the chain
        Args:
            phase_item: single phase configuration in the ChatChainConfig.json

        Returns:

        """

        phase = phase_item['phase']
        phase_type = phase_item['phaseType']
        # For SimplePhase, just look it up from self.phases and conduct the "Phase.execute" method
        if phase_type == "SimplePhase":
            max_turn_step = phase_item['max_turn_step']
            need_reflect = check_bool(phase_item['need_reflect'])
            if phase in self.phases:
                self.chat_env = self.phases[phase].execute(self.chat_env,
                                                           self.chat_turn_limit_default if max_turn_step <= 0 else max_turn_step,
                                                           need_reflect)
            else:
                raise RuntimeError(f"Phase '{phase}' is not yet implemented in chatdev.phase")
        # For ComposedPhase, we create instance here then conduct the "ComposedPhase.execute" method
        elif phase_type == "ComposedPhase":
            cycle_num = phase_item['cycleNum']
            composition = phase_item['Composition']
            compose_phase_class = getattr(self.compose_phase_module, phase)
            if not compose_phase_class:
                raise RuntimeError(f"Phase '{phase}' is not yet implemented in chatdev.compose_phase")
            compose_phase_instance = compose_phase_class(phase_name=phase,
                                                         cycle_num=cycle_num,
                                                         composition=composition,
                                                         config_phase=self.config_phase,
                                                         config_role=self.config_role,
                                                         model_type=self.model_type,
                                                         log_filepath=self.log_filepath)
            self.chat_env = compose_phase_instance.execute(self.chat_env)
        else:
            raise RuntimeError(f"PhaseType '{phase_type}' is not yet implemented.")

    def execute_chain(self):
        """
        execute the whole chain based on ChatChainConfig.json
        Returns: None

        """
        for phase_item in self.chain:
            self.execute_step(phase_item)

    def get_logfilepath(self):
        """
        get the log path (under the software path)
        Returns:
            start_time: time for starting making the software
            log_filepath: path to the log

        """
        start_time = now()
        filepath = os.path.dirname(__file__)
        # root = "/".join(filepath.split("/")[:-1])
        root = os.path.dirname(filepath)
        # directory = root + "/WareHouse/"
        directory = os.path.join(root, "WareHouse")
        log_filepath = os.path.join(directory, "{}.log".format("_".join([self.project_name, self.org_name,start_time])))
        return start_time, log_filepath

    def pre_processing(self):
        """
        remove useless files and log some global config settings
        Returns: None

        """
        if self.chat_env.config.clear_structure:
            filepath = os.path.dirname(__file__)
            # root = "/".join(filepath.split("/")[:-1])
            root = os.path.dirname(filepath)
            # directory = root + "/WareHouse"
            directory = os.path.join(root, "WareHouse")
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                # logs with error trials are left in WareHouse/
                if os.path.isfile(file_path) and not filename.endswith(".py") and not filename.endswith(".log"):
                    os.remove(file_path)
                    print("{} Removed.".format(file_path))

        software_path = os.path.join(directory, "_".join([self.project_name, self.org_name, self.start_time]))
        self.chat_env.set_directory(software_path)

        # copy config files to software path
        shutil.copy(self.config_path, software_path)
        shutil.copy(self.config_phase_path, software_path)
        shutil.copy(self.config_role_path, software_path)

        # write task prompt to software path
        with open(os.path.join(software_path, self.project_name + ".prompt"), "w") as f:
            f.write(self.task_prompt_raw)

        preprocess_msg = "**[Preprocessing]**\n\n"
        chat_gpt_config = ChatGPTConfig()

        preprocess_msg += "**ChatDev Starts** ({})\n\n".format(self.start_time)
        preprocess_msg += "**Timestamp**: {}\n\n".format(self.start_time)
        preprocess_msg += "**config_path**: {}\n\n".format(self.config_path)
        preprocess_msg += "**config_phase_path**: {}\n\n".format(self.config_phase_path)
        preprocess_msg += "**config_role_path**: {}\n\n".format(self.config_role_path)
        preprocess_msg += "**task_prompt**: {}\n\n".format(self.task_prompt_raw)
        preprocess_msg += "**project_name**: {}\n\n".format(self.project_name)
        preprocess_msg += "**Log File**: {}\n\n".format(self.log_filepath)
        preprocess_msg += "**ChatDevConfig**:\n {}\n\n".format(self.chat_env.config.__str__())
        preprocess_msg += "**ChatGPTConfig**:\n {}\n\n".format(chat_gpt_config)
        log_and_print_online(preprocess_msg)

        # init task prompt
        if check_bool(self.config['self_improve']):
            self.chat_env.env_dict['task_prompt'] = self.self_task_improve(self.task_prompt_raw)
        else:
            self.chat_env.env_dict['task_prompt'] = self.task_prompt_raw

    def post_processing(self):
        """
        summarize the production and move log files to the software directory
        Returns: None

        """

        self.chat_env.write_meta()
        filepath = os.path.dirname(__file__)
        # root = "/".join(filepath.split("/")[:-1])
        root = os.path.dirname(filepath)

        post_info = "**[Post Info]**\n\n"
        now_time = now()
        time_format = "%Y%m%d%H%M%S"
        datetime1 = datetime.strptime(self.start_time, time_format)
        datetime2 = datetime.strptime(now_time, time_format)
        duration = (datetime2 - datetime1).total_seconds()

        post_info += "Software Info: {}".format(
            get_info(self.chat_env.env_dict['directory'], self.log_filepath) + "\n\n🕑**duration**={:.2f}s\n\n".format(duration))

        post_info += "ChatDev Starts ({})".format(self.start_time) + "\n\n"
        post_info += "ChatDev Ends ({})".format(now_time) + "\n\n"

        if self.chat_env.config.clear_structure:
            directory = self.chat_env.env_dict['directory']
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isdir(file_path) and file_path.endswith("__pycache__"):
                    shutil.rmtree(file_path, ignore_errors=True)
                    post_info += "{} Removed.".format(file_path) + "\n\n"

        log_and_print_online(post_info)

        logging.shutdown()
        time.sleep(1)

        shutil.move(self.log_filepath,
                    os.path.join(root + "/WareHouse", "_".join([self.project_name, self.org_name, self.start_time]),
                                 os.path.basename(self.log_filepath)))

    # @staticmethod
    def self_task_improve(self, task_prompt):
        """
        ask agent to improve the user query prompt
        Args:
            task_prompt: original user query prompt

        Returns:
            revised_task_prompt: revised prompt from the prompt engineer agent

        """
        self_task_improve_prompt = """I will give you a short description of a software design requirement, 
please rewrite it into a detailed prompt that can make large language model know how to make this software better based this prompt,
the prompt should ensure LLMs build a software that can be run correctly, which is the most import part you need to consider.
remember that the revised prompt should not contain more than 200 words, 
here is the short description:\"{}\". 
If the revised prompt is revised_version_of_the_description, 
then you should return a message in a format like \"<INFO> revised_version_of_the_description\", do not return messages in other formats.""".format(
            task_prompt)
        role_play_session = RolePlaying(
            assistant_role_name="Prompt Engineer",
            assistant_role_prompt="You are an professional prompt engineer that can improve user input prompt to make LLM better understand these prompts.",
            user_role_prompt="You are an user that want to use LLM to build software.",
            user_role_name="User",
            task_type=TaskType.CHATDEV,
            task_prompt="Do prompt engineering on user query",
            with_task_specify=False,
            model_type=self.model_type,
        )

        # log_and_print_online("System", role_play_session.assistant_sys_msg)
        # log_and_print_online("System", role_play_session.user_sys_msg)

        _, input_user_msg = role_play_session.init_chat(None, None, self_task_improve_prompt)
        assistant_response, user_response = role_play_session.step(input_user_msg, True)
        revised_task_prompt = assistant_response.msg.content.split("<INFO>")[-1].lower().strip()
        log_and_print_online(role_play_session.assistant_agent.role_name, assistant_response.msg.content)
        log_and_print_online(
            "**[Task Prompt Self Improvement]**\n**Original Task Prompt**: {}\n**Improved Task Prompt**: {}".format(
                task_prompt, revised_task_prompt))
        return revised_task_prompt

关注公众号“大模型全栈程序员”回复“小程序”获取1000个小程序打包源码。更多免费资源在http://www.gitweixin.com/?p=2627

发表评论

邮箱地址不会被公开。 必填项已用*标注