```python
batch_size = 8
n_epochs = 8000
```
Experiment 3: Learning to regenerate
In this third experiment we train the automaton to regenerate its form after a corruption, such as deleting a part of it or creating a hole in it.
```python
path = '../images/emoji_u1f98e.png'
img_tensor = load_image(path)
```
Pool Training with holes
We want to extend the pool of samples so that it can also produce corrupted samples, i.e. samples containing random holes.
We start by creating a function that receives a batch of images and cuts a random hole in each of them.
create_hole (batch)
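Only the signature is rendered above, so here is a minimal sketch of one way `create_hole` could be implemented, assuming a circular hole of random size and position per image; the `radius_range` parameter is a hypothetical addition of this sketch, not necessarily the notebook's actual code.

```python
import torch

def create_hole(batch, radius_range=(4, 12)):
    # Assumed sketch: zero out a random circular region in every image of the batch.
    b, c, h, w = batch.shape
    out = batch.clone()
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    for i in range(b):
        r = torch.randint(radius_range[0], radius_range[1] + 1, (1,)).item()
        cy = torch.randint(r, h - r, (1,)).item()  # keep the hole inside the image
        cx = torch.randint(r, w - r, (1,)).item()
        mask = (ys - cy) ** 2 + (xs - cx) ** 2 <= r ** 2
        out[i, :, mask] = 0.0  # wipe all channels inside the hole
    return out
```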
```python
corrupted_image = create_hole(img_tensor)
```
```python
plt.imshow(corrupted_image[0].detach().cpu().permute(1, 2, 0))
plt.show()
```
CorruptedPool (pool_size=1024, loss_fn=None, device='cpu')
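As with `create_hole`, only the signature is rendered here. Below is a rough, assumed sketch of how such a pool could work, loosely following the pool trick from the Growing Neural Cellular Automata paper: reseed the worst sample in a batch and damage the best few. The `make_seed` helper, the per-sample loss ranking, and the `n_damaged` count are assumptions of this sketch, not the notebook's actual code.

```python
import torch

class CorruptedPool:
    # Assumed sketch of a sample pool that can also hand out damaged samples.
    def __init__(self, pool_size=1024, loss_fn=None, device='cpu'):
        self.pool = make_seed(pool_size).to(device)  # make_seed: hypothetical seed-state helper
        self.loss_fn = loss_fn
        self.idx = None

    def sample_with_damage(self, batch_size=8, n_damaged=3):
        # Draw a random batch of grown states from the pool.
        self.idx = torch.randint(0, len(self.pool), (batch_size,))
        batch = self.pool[self.idx].clone()
        if self.loss_fn is not None:
            # Sort worst-first (assumes loss_fn returns an unreduced, per-pixel error tensor).
            losses = self.loss_fn(batch[:, :4]).flatten(1).mean(1)
            order = losses.argsort(descending=True)
            batch, self.idx = batch[order], self.idx[order]
        # Reseed the worst sample; punch holes into the best few.
        batch[0] = make_seed(1)[0]
        batch[-n_damaged:] = create_hole(batch[-n_damaged:])
        return batch

    def update(self, res):
        # Write the updated states back into the pool.
        self.pool[self.idx] = res.detach()
```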
```python
loss_fn = partial(mse, target=img_tensor.repeat(batch_size, 1, 1, 1).to(def_device))
pool = CorruptedPool(1024, loss_fn=loss_fn)
batch = pool.sample_with_damage()
vis_batch(batch)
```
Training Loop
```python
# Instantiate the model
ca = CAModel(CHANNEL_N).to(def_device)

# Optimization hyperparameters
lr = 2e-3
lr_gamma = 0.9999
betas = (0.5, 0.5)
optimizer = torch.optim.Adam(ca.parameters(), lr=lr, betas=betas)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)

# The training target: one copy of the image per batch element, on the training device
target = img_tensor.repeat(batch_size, 1, 1, 1).to(def_device)

for i in tqdm(range(n_epochs)):
    # zero the optimizer
    optimizer.zero_grad()
    # set the number of steps to take
    steps = torch.randint(64, 96, (1,)).item()
    # sample the pool to get the input
    model_in = pool.sample_with_damage()
    # activate the model
    res = ca(model_in, steps=steps)
    # calculate the loss
    loss = F.mse_loss(res[:, :4], target)  # we only care about the RGBA channels
    # update the pool
    pool.update(res)
    # log the loss
    if i % 500 == 0:
        print(f"Epoch: {i} Loss: {loss.item()}")
    # backpropagate the loss and update the weights
    loss.backward()
    optimizer.step()
    scheduler.step()
```
```
Epoch: 0 Loss: 0.11324034631252289
Epoch: 500 Loss: 0.029659921303391457
Epoch: 1000 Loss: 0.01154416985809803
Epoch: 1500 Loss: 0.009176425635814667
Epoch: 2000 Loss: 0.004854139406234026
Epoch: 2500 Loss: 0.0035146409645676613
Epoch: 3000 Loss: 0.0019474619766697288
Epoch: 3500 Loss: 0.0017854450270533562
Epoch: 4000 Loss: 0.0013569763395935297
Epoch: 4500 Loss: 0.0009505663765594363
Epoch: 5000 Loss: 0.0013342727907001972
Epoch: 5500 Loss: 0.001305947545915842
Epoch: 6000 Loss: 0.00037199087091721594
Epoch: 6500 Loss: 0.0005692397826351225
Epoch: 7000 Loss: 0.00025373895186930895
Epoch: 7500 Loss: 0.0003689015575218946
```
```python
images = ca.grow_animation(seed, 200)
display_animation(images)
```
The automaton manages to maintain its shape as time passes.
Visualize a batch from the updated pool.
During training we update the samples inside the pool. Let's visualize what a batch looks like after training is complete.
```python
batch = pool.sample_with_damage()
vis_batch(batch)
```
Test the reconstruction ability of the model
```python
# Generate an automaton which we want to corrupt
generated = ca(seed, steps=96).detach()
```
Corruption 1:
```python
corrupted_input1 = generated.clone()
corrupted_input1[..., 20:] = 0  # zero out columns 20 onward (the right side of the image)
plt.imshow(corrupted_input1[0].permute(1, 2, 0)[:, :, :4].cpu())
```
```python
images = ca.grow_animation(corrupted_input1, 300)
display_animation(images)
```
Corruption 2:
```python
corrupted_input2 = generated.clone()
corrupted_input2[..., :20, :] = 0  # zero out the first 20 rows (the top of the image)
plt.imshow(corrupted_input2[0].permute(1, 2, 0)[:, :, :4].cpu())
```
```python
images = ca.grow_animation(corrupted_input2, 300)
display_animation(images)
```