From c3cda0f2035c458e3a365cb7848ba1d20e2183f6 Mon Sep 17 00:00:00 2001 From: Roland Li Date: Mon, 25 Mar 2024 11:09:32 -0700 Subject: [PATCH 1/2] handle new startup --- plugin.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/plugin.py b/plugin.py index 44d4038..920a620 100644 --- a/plugin.py +++ b/plugin.py @@ -19,7 +19,7 @@ from basicsr.utils.img_util import tensor2img from basicsr.utils import img2tensor from basicsr.archs.swinir_arch import SwinIR -from .BasicSR.inference.inference_swinir import define_model +from .inference.inference_swinir import define_model from torch.nn import functional as F @@ -44,22 +44,22 @@ def set_config(update: dict): sr_plugin.set_config(update) # TODO: Validate config dict are all valid keys return sr_plugin.get_config() -@app.on_event("startup") -async def startup_event(): +@app.get("/startup/{plugin_name}") +async def startup_event(plugin_name: str): print("Starting up") # A slight delay to ensure the app has started up. try: - set_model() + set_model(plugin_name) print("Successfully started up") + print(sr_plugin.plugin_name) sr_plugin.notify_main_system_of_startup("True") except Exception as e: - # raise e sr_plugin.notify_main_system_of_startup("False") @app.get("/set_model/") -def set_model(): +def set_model(plugin_name): global sr_plugin - args = {"plugin": plugin, "config": config, "endpoints": endpoints} + args = {"plugin": plugin, "config": config, "endpoints": endpoints, "name": plugin_name} sr_plugin = SR(Namespace(**args)) # try: # sd_plugin.set_model(args["model_name"], dtype=args["model_dtype"]) @@ -117,8 +117,7 @@ class SR(Plugin): """ def __init__(self, arguments: "Namespace") -> None: super().__init__(arguments) - self.plugin_name = "BasicSR" - model_folder = "plugin/BasicSR/experiments/pretrained_models/" + model_folder = f"plugin/{self.plugin_name}/experiments/pretrained_models/" self.esrgan_model_path = os.path.join(model_folder, arguments.config["esrgan_model"]) self.swinir_model_path = os.path.join(model_folder, arguments.config["swinir_model"]) if sys.platform == "darwin": @@ -144,8 +143,9 @@ def set_model(self) -> None: # Load SwinIR if self.swinir_model_path is not None: - split_name = self.swinir_model_path.split("_") - task, scale, patch_size = split_name[2], int(split_name[-1].split("x")[1].split(".")[0]), int(split_name[4][1:3]) + target_model = self.swinir_model_path.split("/")[-1] + split_name = target_model.split("_") + task, scale, patch_size = split_name[1], int(split_name[-1].split("x")[1].split(".")[0]), int(split_name[3][1:3]) if task == "classicalSR": task = "classical_sr" swin_args = {"task": task, "scale": scale, "patch_size": patch_size, "model_path": self.swinir_model_path} @@ -154,10 +154,10 @@ def set_model(self) -> None: self.swin_scale = scale self.load_model(self.swinir_model_path, self.swin_model) - # elif self.method == "BasicVSR": - # self.model = BasicVSR(num_feat=64, num_block=30) - # self.interval = 15 - # self.save_path = "plugin/BasicSuperRes/results/BasicVSR" + # self.vsr_model = BasicVSR(num_feat=64, num_block=30) + # self.vsr_model.to(self.device) + # self.interval = 12 + # self.save_path = "plugin/BasicSuperRes/results/BasicVSR" def super_res(self, inputs, model="esrgan"): From f40214c344309ecc840e08e2aea0de34373fc981 Mon Sep 17 00:00:00 2001 From: Roland Li Date: Mon, 25 Mar 2024 11:10:34 -0700 Subject: [PATCH 2/2] inference script --- .../inference_swinir.cpython-38.pyc | Bin 0 -> 3851 bytes inference/inference_swinir.py | 199 ++++++++++++++++++ 2 files changed, 199 insertions(+) create mode 100644 inference/__pycache__/inference_swinir.cpython-38.pyc create mode 100644 inference/inference_swinir.py diff --git a/inference/__pycache__/inference_swinir.cpython-38.pyc b/inference/__pycache__/inference_swinir.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d1fedbf4c7a94320fc7e27516679585e30781e GIT binary patch literal 3851 zcmb7GOOG4J5uO)^!^d)0yDRNOFY4j7D3Qz6D-i<2mSx$nVLg(S6$^xe!HC@?XLmS5 z_cV7`#4rh<#5W;2<&p#-kGcDj{DL4So%;|32#|otC5IRYj0DKn&E+oFi2`H>UHz!) zs_Lrls(L&#lhyG3^*4IsM<+GyPjWK-8JOHgi?;xXM)W|djeEVOtGQ7#RBzTS)u(DH z)!Q{2y%D6_PA#L$J!X(?=W02OEt2|)R?8Ecq|wjN0&%tt)}Z<48nF5!t&|x-p7x-m zi)^DC*0~>rUVxeXXx9%v+N?|LuCC4q2Co@wk1+2`Hxv7<7SeYjaQ!9&DIZjSYtt^pcvO-p0nESF% zLeeR6dORPbpry|3t#qPqIphpE`@+yQo*86$W-B+y4>k6R6hniY8|r<7&m{VDy>BMw zpfH4;(9b3YFeY}-k@MTeOPv|yLf;B4aLtDac~?oC&(K^&_Bn2S9_`d z;CG)H{Z!vp_$39G6MLu+b+Wdi_0xQXyqTnu^mFLWR|h8(o3C!28XAMs{LI!_P+dBy z^&Mzi`QpP z3~e&~EGXU>YM)vC9R7TNCdnt(35~36n>gK{6KgzQ*7})b<^W?b4Hz{CCd=0bZ$h87 zB#V}5SfnyEA=^wG*sqWjz9FG-MDsOB$iP~c7PV!0k~Q$n9?lyUY*YADyQzIo+ti+D zNnxurxIEP4`36_e%4CgPhLtKfFJx`4lXZB;V82yXDoiXhEmK*kIy59bf!skAe592( z;N1djfAzqJPLlqGMmBzB42`5P#BF)0e}0i@crjNW^wO1zA9lLDcbs_Kiz$z*k0`%c zy|?+)ukYRHEh%h~)0h|iw%4S^MidarN~W;*vknz$LK|K;;KFLrpwn|Im8i=l!v#iT ztb?C=eu`~xW{lit(z7cS2;CO>df>&eU-tqxX1zPpvvM);n=QUeB?PuyWK@xri%sS| zb4ge()}tU|YP8j%O}Fl`UcOT4c)Z?nW4}j7xADw-3!Hg==!eas!W1P&oXn>xp+&GhLIoBkqjQaB{ePvbcY6@ya`gYn;Y^>Qj5_g zNpV@hB= zri`%fkeHPWVMApEJZ4L&IK0~A;-IbGL%P?Y%x}{W&gd|z%s|Ol=z)7SHmX($8P+s8@spMu>&jj-Mt&@9o{M#Od5-&06sx{b?BZSb2r)+R=rC+Vb(ic;Y>U+ zN`q|V00{ju)6v(`pWrsZ4}ZZ#)BYIU9?6t3lDLw|PM}kCY)LK{g={nY;-i@5ipU;XK95481lLsScBh zn)RJ)!uH#E3Pj{Ty1)7UyLa999^QMniHi`ccWHzjk!}Q$$FE;QLS#|fh3L*Rk%9li zxD&xHDTNoAZWw>wrL>3F97Zml6||DCL+Bp}vx8YmU1MR%7$MU8#B5HU-o~YAcU+Mg zbGJg`w?&FY-4NTlQ1+Hc`|VwZlaywKtt`xhC2h}{%(n;3`;XmP2D+d3TX3f&cYOX$1oJa+tAV~^2&{b98o@@l<`cf{fg)QPnFVG}FX(yIPN>=2<{8W{>We!2 z6mshF%1=J|;EtiGdSwyM1$7Fl9#lXda$VId2B0Zx7Lq#rn>MPJqw_RzWZm*2*VkX_ z;m;MH9Qb@O;e#4xic8Z%MYGQV<1{Hu?S7(m&#Ee75o0TU#dXv>#!((}|9{ow*o0Sj z=XLwWvYQ!HP4aX{eypa+F;r9PG~w@+Hj7G|C*XHn)m2kxd95Bx2YP&=cq}U(zf?R7 z)Mo!5k2{LTisJEyeLYsmNz`yhJ@KEOInaZr1~#~)WerD8C&spoO6F^Fx^JRRJEN*M zCovLPDV-%nSpg=p0zUVE$!N<~Bv=4wqT&17J0qDFMqdJ`cq#H7e~$)ntJ^?86{)NB zDBLOKM7rAQS=$*ex=0iU86?03#IHAVy4UKSer`?>>Vd;I%h>O=W0$r) z6{Rqi$<0P`@aw#kVI@f_6AycLpMFwa)N)f|qubvvS3aIYK{71&$W` zSTef@P)kc>NO8bz`b>s96~o5h`;ux6;2~Q43P3Zm2t6_Yje}4DBiZnOm7Sg7#^FiM zWK+0;3`FDgF`VA%6s9qY*dZJKuN>@~#vR@DKX9|J$JsM^=!LMT&l~D=!f{<0XkE87 z%T{qnc3pz91n*1mZ3&7JC<{9(6R(FMlf?sj2SBVKPx|$kt$Rq)@%lJ~Bc0by` rdLT@i*>$H3Ez<^*))vlNa9lT_Z%fv(EJ$P=pZs|vZ#s9JT}S^nZC?p@ literal 0 HcmV?d00001 diff --git a/inference/inference_swinir.py b/inference/inference_swinir.py new file mode 100644 index 0000000..c82224b --- /dev/null +++ b/inference/inference_swinir.py @@ -0,0 +1,199 @@ +# Modified from https://github.com/JingyunLiang/SwinIR +import argparse +import cv2 +import glob +import numpy as np +import os +import torch +from torch.nn import functional as F + +from basicsr.archs.swinir_arch import SwinIR + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--input', type=str, default='datasets/Set5/LRbicx4', help='input test image folder') + parser.add_argument('--output', type=str, default='results/SwinIR/Set5', help='output folder') + parser.add_argument( + '--task', + type=str, + default='classical_sr', + help='classical_sr, lightweight_sr, real_sr, gray_dn, color_dn, jpeg_car') + # dn: denoising; car: compression artifact removal + # TODO: it now only supports sr, need to adapt to dn and jpeg_car + parser.add_argument('--patch_size', type=int, default=64, help='training patch size') + parser.add_argument('--scale', type=int, default=4, help='scale factor: 1, 2, 3, 4, 8') # 1 for dn and jpeg car + parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50') + parser.add_argument('--jpeg', type=int, default=40, help='scale factor: 10, 20, 30, 40') + parser.add_argument('--large_model', action='store_true', help='Use large model, only used for real image sr') + parser.add_argument( + '--model_path', + type=str, + default='experiments/pretrained_models/SwinIR/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth') + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # set up model + model = define_model(args) + model.eval() + model = model.to(device) + + if args.task == 'jpeg_car': + window_size = 7 + else: + window_size = 8 + + for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))): + # read image + imgname = os.path.splitext(os.path.basename(path))[0] + print('Testing', idx, imgname) + # read image + img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255. + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + img = img.unsqueeze(0).to(device) + + # inference + with torch.no_grad(): + # pad input image to be a multiple of window_size + mod_pad_h, mod_pad_w = 0, 0 + _, _, h, w = img.size() + if h % window_size != 0: + mod_pad_h = window_size - h % window_size + if w % window_size != 0: + mod_pad_w = window_size - w % window_size + img = F.pad(img, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + + output = model(img) + _, _, h, w = output.size() + output = output[:, :, 0:h - mod_pad_h * args.scale, 0:w - mod_pad_w * args.scale] + + # save image + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + if output.ndim == 3: + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + cv2.imwrite(os.path.join(args.output, f'{imgname}_SwinIR.png'), output) + + +def define_model(args): + # 001 classical image sr + if args.task == 'classical_sr': + model = SwinIR( + upscale=args.scale, + in_chans=3, + img_size=args.patch_size, + window_size=8, + img_range=1., + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler='pixelshuffle', + resi_connection='1conv') + + # 002 lightweight image sr + # use 'pixelshuffledirect' to save parameters + elif args.task == 'lightweight_sr': + model = SwinIR( + upscale=args.scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1., + depths=[6, 6, 6, 6], + embed_dim=60, + num_heads=[6, 6, 6, 6], + mlp_ratio=2, + upsampler='pixelshuffledirect', + resi_connection='1conv') + + # 003 real-world image sr + elif args.task == 'real_sr': + if not args.large_model: + # use 'nearest+conv' to avoid block artifacts + model = SwinIR( + upscale=4, + in_chans=3, + img_size=64, + window_size=8, + img_range=1., + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler='nearest+conv', + resi_connection='1conv') + else: + # larger model size; use '3conv' to save parameters and memory; use ema for GAN training + model = SwinIR( + upscale=4, + in_chans=3, + img_size=64, + window_size=8, + img_range=1., + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=248, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler='nearest+conv', + resi_connection='3conv') + + # 004 grayscale image denoising + elif args.task == 'gray_dn': + model = SwinIR( + upscale=1, + in_chans=1, + img_size=128, + window_size=8, + img_range=1., + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler='', + resi_connection='1conv') + + # 005 color image denoising + elif args.task == 'color_dn': + model = SwinIR( + upscale=1, + in_chans=3, + img_size=128, + window_size=8, + img_range=1., + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler='', + resi_connection='1conv') + + # 006 JPEG compression artifact reduction + # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's slightly better than 1 + elif args.task == 'jpeg_car': + model = SwinIR( + upscale=1, + in_chans=1, + img_size=126, + window_size=7, + img_range=255., + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler='', + resi_connection='1conv') + + loadnet = torch.load(args.model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + + return model + + +if __name__ == '__main__': + main() \ No newline at end of file