
How to Switch Models Through the Stable Diffusion API

Author: 愛吃餅幹de大叔

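Anyone who has studied the Stable Diffusion API documentation will have noticed that it does not offer a model parameter. So how can a model be switched through the API?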

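Let's start with the original sd-webui code: find the core code that receives the request parameters, then modify the source so that those parameters are passed into that core function.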

StableDiffusionProcessingTxt2Img

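First, let's look at the most important piece, the txt2img code. The core class is StableDiffusionProcessingTxt2Img in modules.processing, and its __init__ function accepts the following parameters: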

def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs)           

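The abbreviation hr in the code stands for Hires. fix in the webui, and these parameters correspond to the Hires. fix options shown in the webui:

[Screenshot: Hires. fix options in the webui]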
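To make the role of hr_scale, hr_resize_x and hr_resize_y more concrete, here is a small sketch of how the second-pass target size could be derived from them. It is a simplified reading of the behaviour, not the webui implementation: the function name hires_target_size is hypothetical, and the handling of the case where both resize values are set (the webui also crops to keep the aspect ratio) is left out.

```python
def hires_target_size(width, height, hr_scale=2.0, hr_resize_x=0, hr_resize_y=0):
    """Return the (width, height) that the Hires. fix second pass would aim for.

    Hypothetical helper for illustration only; the defaults mirror the
    __init__ signature shown above.
    """
    # No explicit target size: scale both sides by hr_scale.
    if hr_resize_x == 0 and hr_resize_y == 0:
        return int(width * hr_scale), int(height * hr_scale)
    # Only the width is given: derive the height from the aspect ratio.
    if hr_resize_y == 0:
        return hr_resize_x, hr_resize_x * height // width
    # Only the height is given: derive the width from the aspect ratio.
    if hr_resize_x == 0:
        return hr_resize_y * width // height, hr_resize_y
    # Both given: use them as the target (cropping is omitted in this sketch).
    return hr_resize_x, hr_resize_y


default_size = hires_target_size(512, 512)
width_only_size = hires_target_size(512, 768, hr_resize_x=1000)
```

With the default hr_scale of 2.0, a 512x512 first pass is upscaled to 1024x1024; setting only hr_resize_x keeps the original aspect ratio.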

Many other parameters do not appear in this signature. They are declared in the __init__ of StableDiffusionProcessingTxt2Img's parent class, StableDiffusionProcessing:

def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
    self.outpath_samples: str = outpath_samples  # where generated images are saved; used together with do_not_save_samples below
    self.outpath_grids: str = outpath_grids
    self.prompt: str = prompt  # positive prompt
    self.prompt_for_display: str = None
    self.negative_prompt: str = (negative_prompt or "")  # negative prompt
    self.styles: list = styles or []
    self.seed: int = seed  # seed; -1 means use a random seed
    self.sampler_name: str = sampler_name  # sampling method, e.g. "DPM++ SDE Karras"
    self.batch_size: int = batch_size  # number of images generated per batch
    self.n_iter: int = n_iter
    self.steps: int = steps  # "Sampling steps" in the UI
    self.cfg_scale: float = cfg_scale  # "CFG Scale" in the UI: how closely the image follows the prompt
    self.width: int = width  # width of the generated image
    self.height: int = height  # height of the generated image
    self.restore_faces: bool = restore_faces  # whether to apply face restoration
    self.tiling: bool = tiling  # whether to generate tileable images
    self.do_not_save_samples: bool = do_not_save_samples
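These constructor parameters are what a client fills in when it calls the txt2img endpoint. The sketch below builds such a request with the standard library only. The base URL assumes a local webui started with the --api flag on its default port, and make_txt2img_request is a hypothetical helper name; the request is constructed but not sent, so the snippet runs without a server.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:7860"  # assumption: local webui launched with --api


def make_txt2img_request(base_url, **params):
    """Build (but do not send) a POST request for /sdapi/v1/txt2img."""
    body = json.dumps(params).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/sdapi/v1/txt2img",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = make_txt2img_request(
    BASE_URL,
    prompt="a lighthouse at dusk",
    negative_prompt="blurry, low quality",
    seed=-1,                       # -1 asks for a random seed
    sampler_name="DPM++ SDE Karras",
    batch_size=1,
    n_iter=1,
    steps=20,
    cfg_scale=7.0,
    width=512,
    height=512,
    restore_faces=False,
    tiling=False,
)

# To actually send it (needs a running server):
# with urllib.request.urlopen(request) as resp:
#     images = json.loads(resp.read())["images"]  # list of base64-encoded images
```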

How the API loads the model

Let's look at the text2imgapi code in modules/api/api.py:

def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        ......
        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            ......
            return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())           

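The code shows that the model comes from shared.sd_model. A model loaded this way is global rather than per user, so when the model requested through the API differs from the current one, it has to be reloaded. To do that, call the reload_model_weights(sd_model=None, info=None) function in modules/sd_models.py directly. Only the info argument needs to be passed; it specifies the model to load. The function itself checks whether the requested model is the same as the current one, and skips loading if it is.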
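The "reload only when the requested model differs" behaviour can be pictured with a toy stand-in for the global model. Nothing here is webui code: ModelHolder, ensure_loaded and the fake loader are hypothetical names used purely to illustrate the idea.

```python
class ModelHolder:
    """Toy stand-in for a single, globally shared model."""

    def __init__(self, loader):
        self._loader = loader        # callable that loads weights from a path
        self.current_path = None     # path of the checkpoint currently loaded
        self.model = None
        self.load_count = 0          # how many real loads have happened

    def ensure_loaded(self, path):
        """Load `path` unless it is already the active checkpoint."""
        if path == self.current_path:
            return self.model        # same model requested: skip the load
        self.model = self._loader(path)
        self.current_path = path
        self.load_count += 1
        return self.model


holder = ModelHolder(loader=lambda path: f"weights from {path}")
holder.ensure_loaded("a.safetensors")   # first load
holder.ensure_loaded("a.safetensors")   # no-op, already active
holder.ensure_loaded("b.safetensors")   # different model, so it is loaded
```

Because there is only one shared model, requests that alternate between checkpoints pay the loading cost on every switch.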

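It is hard to tell from the function signature what kind of argument info is. After studying the code, I found that info is an instance of the following class: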

class CheckpointInfo:
    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)
        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)
        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]
        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)
        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
        self.shorthash = self.sha256[0:10] if self.sha256 else None
        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])           
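To see what the name-related fields end up holding, the sketch below repeats only the string handling from the class above, without hashing or touching the disk. The function derive_names, the example paths and the example short hash are all made up for illustration; abspath resolution and the ckpt_dir branch are omitted.

```python
import os


def derive_names(filename, model_dir, shorthash=None):
    """Mirror the name handling of CheckpointInfo.__init__ on plain strings."""
    # Strip the model directory prefix if the file lives under it.
    if filename.startswith(model_dir):
        name = filename.replace(model_dir, "")
    else:
        name = os.path.basename(filename)
    # Drop a leading path separator left over from the prefix removal.
    if name.startswith("\\") or name.startswith("/"):
        name = name[1:]
    return {
        "name": name,
        "name_for_extra": os.path.splitext(os.path.basename(filename))[0],
        "model_name": os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0],
        "title": name if shorthash is None else f"{name} [{shorthash}]",
    }


info = derive_names(
    "/models/Stable-diffusion/anime/foo.safetensors",  # made-up path
    "/models/Stable-diffusion",
    shorthash="0123456789",                            # made-up short hash
)
```

For a checkpoint in a subfolder, name keeps the relative path while model_name flattens the separators into underscores.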

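Most of what happens in __init__ can be ignored; only filename needs to be supplied. The following example code is therefore enough to load a specified model manually: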

from modules import sd_models
checkpoint_info = sd_models.CheckpointInfo("full path to the model file")
sd_models.reload_model_weights(info=checkpoint_info)           
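Since CheckpointInfo expects a full path, a model name coming from an API caller has to be turned into one first. If that name is joined onto the checkpoint directory unchecked, a value such as "../something" could point outside it. The helper below is one possible way to resolve and validate the path beforehand; resolve_checkpoint_path is a hypothetical name, not part of the webui, and the demo uses a temporary directory instead of a real models folder.

```python
import os
import tempfile


def resolve_checkpoint_path(checkpoint_dir, model_name):
    """Return the absolute path of model_name inside checkpoint_dir.

    Raises ValueError if the name escapes the directory and
    FileNotFoundError if no such file exists.
    """
    root = os.path.realpath(checkpoint_dir)
    candidate = os.path.realpath(os.path.join(root, model_name))
    # commonpath equals root only when candidate lies inside root.
    if os.path.commonpath([root, candidate]) != root:
        raise ValueError(f"model name escapes the checkpoint directory: {model_name!r}")
    if not os.path.isfile(candidate):
        raise FileNotFoundError(candidate)
    return candidate


with tempfile.TemporaryDirectory() as tmp:
    # Create an empty placeholder file standing in for a checkpoint.
    open(os.path.join(tmp, "demo.safetensors"), "wb").close()
    resolved_name = os.path.basename(resolve_checkpoint_path(tmp, "demo.safetensors"))
    try:
        resolve_checkpoint_path(tmp, "../outside.ckpt")
        escaped_rejected = False
    except ValueError:
        escaped_rejected = True
```

The returned path could then be handed to sd_models.CheckpointInfo as in the snippet above.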

With that understood, we can modify the source directly:

1. Modify StableDiffusionTxt2ImgProcessingAPI in modules/api/models.py to add a model name field:

StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [
        {"key": "sampler_index", "type": str, "default": "Euler"},
        {"key": "script_name", "type": str, "default": None},
        {"key": "script_args", "type": list, "default": []},
        {"key": "send_images", "type": bool, "default": True},
        {"key": "save_images", "type": bool, "default": False},
        {"key": "alwayson_scripts", "type": dict, "default": {}},
        {"key": "model_name", "type": str, "default": None},
    ]
).generate_model()           

2. Modify StableDiffusionProcessingTxt2Img in modules/processing.py so that it accepts the model name:

def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', model_name: str = None, **kwargs):
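Why does the constructor need to know about model_name at all? The API handler unpacks the request fields into the constructor, and anything the subclass does not name lands in **kwargs, which is presumably forwarded to the parent class. The toy classes below (Base, ChildUnpatched and ChildPatched are invented for this illustration) show what happens to an extra key in each case.

```python
class Base:
    """Stands in for the parent class, which has a fixed parameter list."""

    def __init__(self, prompt="", steps=50):
        self.prompt = prompt
        self.steps = steps


class ChildUnpatched(Base):
    def __init__(self, enable_hr=False, **kwargs):
        super().__init__(**kwargs)   # unknown keys reach Base and fail there
        self.enable_hr = enable_hr


class ChildPatched(Base):
    def __init__(self, enable_hr=False, model_name=None, **kwargs):
        super().__init__(**kwargs)   # model_name was consumed here, not forwarded
        self.enable_hr = enable_hr
        self.model_name = model_name  # stored only so the demo can inspect it


args = {"prompt": "a cat", "steps": 20, "model_name": "foo.safetensors"}

try:
    ChildUnpatched(**args)
    unpatched_error = None
except TypeError as exc:
    unpatched_error = str(exc)       # unexpected keyword argument 'model_name'

patched = ChildPatched(**args)
```

An alternative that avoids touching processing.py would be to remove the key from the argument dictionary in the API handler before constructing the object; which option is cleaner is a matter of taste.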

3. Modify the text2imgapi code in modules/api/api.py:

def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        ......
        # Note: os, sd_models and models_path must be imported in api.py for the lines below to work.
        model_name = txt2imgreq.model_name
        if model_name is None:
            raise HTTPException(status_code=404, detail="model_name not found")
        ......
        with self.queue_lock:
            # Build the checkpoint path from the requested name and reload;
            # reload_model_weights skips the load when the model is unchanged.
            checkpoint_info = sd_models.CheckpointInfo(os.path.join(models_path, 'Stable-diffusion', model_name))
            sd_models.reload_model_weights(info=checkpoint_info)
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
            ......
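Note that the reload happens inside the queue_lock block. Because the model is global, switching it and generating with it need to form one uninterrupted unit; otherwise another request could swap the model between the two steps. The toy server below illustrates that idea with threads. ToyServer and its fields are hypothetical, and assigning current_model merely stands in for the real reload.

```python
import threading


class ToyServer:
    """Minimal model of a handler that switches a shared model under a lock."""

    def __init__(self):
        self.queue_lock = threading.Lock()
        self.current_model = None    # the single, globally shared model
        self.log = []                # (requested, model actually used, prompt)

    def txt2img(self, model_name, prompt):
        with self.queue_lock:
            # Switch only when needed (stands in for reload_model_weights).
            if self.current_model != model_name:
                self.current_model = model_name
            # "Generation" reads the shared model; the lock guarantees it is
            # still the one this request asked for.
            self.log.append((model_name, self.current_model, prompt))


server = ToyServer()
threads = [
    threading.Thread(target=server.txt2img, args=(f"model-{i % 2}", f"prompt {i}"))
    for i in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

mismatches = [entry for entry in server.log if entry[0] != entry[1]]
```

On the client side, the only change needed after the three steps is an extra "model_name" key in the JSON body sent to the txt2img endpoint.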

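With that, the txt2img API endpoint can switch models. The same approach can be used to add model switching to the img2img API. The next post will cover how to add a task id and query task progress by that id. We have also built a small painting-and-chat mini program, which can be tried by scanning the QR code in the original post.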
