From 3affc9415da3e9e0cf62d4a90e770d410cc3ff2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 18 Apr 2026 17:30:48 +0800 Subject: [PATCH] fix(train): unsupported gloo device on win --- infer/modules/train/train.py | 141 +++++++++++++++++++++++++---------- 1 file changed, 102 insertions(+), 39 deletions(-) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index e8f7156..8c32671 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -106,23 +106,28 @@ def main(): # patch to unblock people without gpus. there is probably a better way. print("NO GPU DETECTED: falling back to CPU - this may take a while") n_gpus = 1 - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(randint(20000, 55555)) - children = [] logger = utils.get_logger(hps.model_dir) - for i in range(n_gpus): - subproc = mp.Process( - target=run, - args=(i, n_gpus, hps, logger), - ) - children.append(subproc) - subproc.start() + if n_gpus == 1: + # Single GPU: run directly without distributed to avoid gloo issues on Windows + run(0, 1, hps, logger) + else: + master_port = str(randint(20000, 55555)) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = master_port + children = [] + for i in range(n_gpus): + subproc = mp.Process( + target=run, + args=(i, n_gpus, hps, logger, master_port), + ) + children.append(subproc) + subproc.start() - for i in range(n_gpus): - children[i].join() + for i in range(n_gpus): + children[i].join() -def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger): +def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger, master_port: str = "29500"): global global_step if rank == 0: # logger = utils.get_logger(hps.model_dir) @@ -131,24 +136,81 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger): writer = SummaryWriter(log_dir=hps.model_dir) writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) - try: - dist.init_process_group( - backend=( - "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl" - ), - init_method="env://", - world_size=n_gpus, - rank=rank, - ) - except: - dist.init_process_group( - backend=( - "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl" - ), - init_method="env://?use_libuv=False", - world_size=n_gpus, - rank=rank, - ) + use_distributed = n_gpus > 1 + if use_distributed: + if os.name == "nt" or not torch.cuda.is_available(): + # On Windows, gloo's create_device(hostname=...) is gated to Linux only + # in the C++ layer (makeDeviceForHostname). We must use the interface- + # based path instead: create_device(interface=...) calls + # makeDeviceForInterface which is not platform-gated. + import socket as _socket + + try: + store = dist.TCPStore( + host_name="127.0.0.1", + port=int(master_port), + world_size=n_gpus, + is_master=(rank == 0), + ) + except Exception: + store = dist.TCPStore( + host_name="127.0.0.1", + port=int(master_port), + world_size=n_gpus, + is_master=(rank == 0), + use_libuv=False, + ) + + # Discover a working network interface for gloo device creation + gloo_device = None + try: + for idx, ifname in _socket.if_nameindex(): + try: + gloo_device = dist.ProcessGroupGloo.create_device( + interface=ifname + ) + print("Try device", idx, "name", ifname) + break + except RuntimeError as e: + print("Try device", idx, "name", ifname, "err:", e) + continue + except (OSError, AttributeError) as e: + print(e.with_traceback(None)) + + if gloo_device is None: + raise RuntimeError( + "Cannot create gloo device on Windows. " + "No usable network interface found. " + "Try adding your hostname to " + "C:\\Windows\\System32\\drivers\\etc\\hosts " + "with: 127.0.0.1 " + _socket.gethostname() + ) + + pg_options = dist.ProcessGroupGloo._Options() + pg_options._devices = [gloo_device] + dist.init_process_group( + backend="gloo", + store=store, + world_size=n_gpus, + rank=rank, + pg_options=pg_options, + ) + else: + init_url = f"tcp://127.0.0.1:{master_port}" + try: + dist.init_process_group( + backend="nccl", + init_method=init_url, + world_size=n_gpus, + rank=rank, + ) + except: + dist.init_process_group( + backend="nccl", + init_method=init_url + "?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) torch.manual_seed(hps.train.seed) if torch.cuda.is_available(): torch.cuda.set_device(rank) @@ -221,14 +283,15 @@ def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger): ) # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) - if hasattr(torch, "xpu") and torch.xpu.is_available(): - pass - elif torch.cuda.is_available(): - net_g = DDP(net_g, device_ids=[rank]) - net_d = DDP(net_d, device_ids=[rank]) - else: - net_g = DDP(net_g) - net_d = DDP(net_d) + if use_distributed: + if hasattr(torch, "xpu") and torch.xpu.is_available(): + pass + elif torch.cuda.is_available(): + net_g = DDP(net_g, device_ids=[rank]) + net_d = DDP(net_d, device_ids=[rank]) + else: + net_g = DDP(net_g) + net_d = DDP(net_d) try: # 如果能加载自动resume _, _, _, epoch_str = utils.load_checkpoint(