
Commit f3777f6

Support & instructions for MPS (Silicon Mac M1/M2) and CPU
ochafik committed Jun 26, 2023
1 parent d7f7319 commit f3777f6
Showing 8 changed files with 49 additions and 28 deletions.
README.md (14 changes: 13 additions & 1 deletion)
@@ -35,7 +35,19 @@

## Requirements

-Please follow the requirements of [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
+If you have a CUDA graphics card, please follow the requirements of [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
+
+Otherwise (for GPU acceleration on macOS with Apple Silicon M1/M2, or CPU only), try the following:
+
+```sh
+cat environment.yml | \
+grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
+conda env create -f environment-no-nvidia.yml
+conda activate stylegan3
+
+# On MacOS
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+```

## Download pre-trained StyleGAN2 weights

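For context, `PYTORCH_ENABLE_MPS_FALLBACK=1` (exported in the shell as above, before Python starts) tells PyTorch to run operators that have no MPS kernel on the CPU instead of raising an error. A minimal sketch, assuming the environment above (PyTorch >= 2.0) is active and not part of this commit, to confirm which backend the scripts will pick up:

```python
# Sketch only: report the device the generation scripts would select.
import os
import torch

device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print('Selected device:', device)

if device.type == 'mps':
    # Ops without an MPS kernel only fall back to CPU when this flag is set.
    print('MPS fallback enabled:', os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1')
```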
environment.yml (17 changes: 10 additions & 7 deletions)
@@ -5,20 +5,23 @@ channels:
dependencies:
- python >= 3.8
- pip
-- numpy>=1.20
+- numpy>=1.25
- click>=8.0
-- pillow=8.3.1
-- scipy=1.7.1
-- pytorch=1.9.1
+- pillow=9.4.0
+- scipy=1.11.0
+- pytorch>=2.0.1
+- torchvision>=0.15.2
- cudatoolkit=11.1
- requests=2.26.0
- tqdm=4.62.2
- ninja=1.10.2
- matplotlib=3.4.2
- imageio=2.9.0
- pip:
-- imgui==1.3.0
-- glfw==2.2.0
+- imgui==2.0.0
+- glfw==2.6.1
+- gradio==3.35.2
- pyopengl==3.1.5
- imageio-ffmpeg==0.4.3
-- pyspng
+# pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
+- pyspng-seunglab
gen_images.py (7 changes: 4 additions & 3 deletions)
@@ -103,9 +103,10 @@ def generate_images(
"""

print('Loading networks from "%s"...' % network_pkl)
-device = torch.device('cuda')
+device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
-G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
# import pickle
# G = legacy.load_network_pkl(f)
# output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
@@ -126,7 +127,7 @@
# Generate images.
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
-z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)

# Construct an inverse rotation/translation matrix and pass to the generator. The
# generator expects this matrix as an inverse to avoid potentially failing numerical
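The remaining files all apply the same two-line change: select CUDA when available, otherwise MPS, otherwise CPU, and drop latents from float64 to float32 on MPS, since the MPS backend has no float64 support. A rough standalone sketch of that pattern (the commit inlines the two lines in each file; the helper name and the 512-dim latent here are illustrative, not from the repository):

```python
# Illustrative sketch of the device/dtype selection used across this commit.
import numpy as np
import torch

def pick_device_and_dtype():
    device = torch.device(
        'cuda' if torch.cuda.is_available()
        else 'mps' if torch.backends.mps.is_available()
        else 'cpu'
    )
    # MPS cannot hold float64 tensors, so latents are cast down to float32 there;
    # on CUDA/CPU the float64 values produced by np.random.randn are kept.
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    return device, dtype

device, dtype = pick_device_and_dtype()
# randn() yields float64; .to(device, dtype=dtype) moves and casts in one step.
z = torch.from_numpy(np.random.RandomState(0).randn(1, 512)).to(device, dtype=dtype)
print(z.device, z.dtype)
```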
stylegan_human/generate.py (7 changes: 4 additions & 3 deletions)
@@ -63,9 +63,10 @@ def generate_images(

else:
import torch
-device = torch.device('cuda')
+device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
-G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
os.makedirs(outdir, exist_ok=True)


@@ -92,7 +93,7 @@

else: ## stylegan v2/v3
label = torch.zeros([1, G.c_dim], device=device)
-z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
if target_z.size==0:
target_z= z.cpu()
else:
stylegan_human/interpolation.py (9 changes: 5 additions & 4 deletions)
@@ -116,9 +116,10 @@ def main(
):


-device = torch.device('cuda')
+device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
-G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore

outdir = os.path.join(outdir)
if not os.path.exists(outdir):
@@ -132,8 +133,8 @@
print('Require two seeds, randomly generate two now.')
seeds = [seeds[0],random.randint(0,10000)]

-z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device)
-z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device)
+z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device, dtype=dtype)
+z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device, dtype=dtype)
img1 = generate_image_from_z(G, z1, noise_mode, truncation_psi, device)
img2 = generate_image_from_z(G, z2, noise_mode, truncation_psi, device)
img1.save(f'{outdir}/seed{seeds[0]:04d}.png')
stylegan_human/style_mixing.py (7 changes: 4 additions & 3 deletions)
@@ -49,16 +49,17 @@ def generate_style_mix(
):

print('Loading networks from "%s"...' % network_pkl)
-device = torch.device('cuda')
+device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
-G = legacy.load_network_pkl(f)['G_ema'].to(device)
+G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)

os.makedirs(outdir, exist_ok=True)

print('Generating W vectors...')
all_seeds = list(set(row_seeds + col_seeds))
all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
-all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
+all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
w_avg = G.mapping.w_avg
all_w = w_avg + (all_w - w_avg) * truncation_psi
w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
stylegan_human/stylemixing_video.py (11 changes: 6 additions & 5 deletions)
@@ -65,9 +65,10 @@ def style_mixing_video(network_pkl: str,
print('col_seeds: ', dst_seeds)
num_frames = int(np.rint(duration_sec * mp4_fps))
print('Loading networks from "%s"...' % network_pkl)
-device = torch.device('cuda')
+device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
-Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
+Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)

print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
max_style = int(2 * np.log2(Gs.img_resolution)) - 3
@@ -80,14 +81,14 @@
src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2- 1), mode="wrap")
src_z /= np.sqrt(np.mean(np.square(src_z)))
# Map into the detangled latent space W and do truncation trick
-src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
+src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
w_avg = Gs.mapping.w_avg
src_w = w_avg + (src_w - w_avg) * truncation_psi

# Top row latents (fixed reference)
print('Generating Destination W vectors...')
dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
-dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
+dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
dst_w = w_avg + (dst_w - w_avg) * truncation_psi
# Get the width and height of each image:
H = Gs.img_resolution # 1024
@@ -120,7 +121,7 @@ def make_frame(t):
for col, dst_image in enumerate(list(dst_images)):
# Select the pertinent latent w column:
w_col = np.stack([dst_w[col].cpu()]) # [18, 512] -> [1, 18, 512]
-w_col = torch.from_numpy(w_col).to(device)
+w_col = torch.from_numpy(w_col).to(device, dtype=dtype)
# Replace the values defined by col_styles:
w_col[:, col_styles] = src_w[frame_idx, col_styles]#.cpu()
# Generate these synthesized images:
viz/renderer.py (5 changes: 3 additions & 2 deletions)
@@ -69,7 +69,8 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):

class Renderer:
def __init__(self, disable_timing=False):
-self._device = torch.device('cuda')
+self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
self._networks = dict() # {cache_key: torch.nn.Module, ...}
self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
@@ -241,7 +242,7 @@ def init_network(self, res,

if self.w_load is None:
# Generate random latents.
-z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device).float()
+z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)

# Run mapping network.
label = torch.zeros([1, G.c_dim], device=self._device)
