Experiment 3: Learning to regenerate

In this third experiment, we train the automaton to regenerate its form after it has been corrupted, for example by deleting a part of it or cutting a hole in it.
batch_size = 8
n_epochs = 8000
path = '../images/emoji_u1f98e.png'
img_tensor = load_image(path)

Pool Training with holes

We want the sample pool to produce corrupted samples as well, that is, samples containing random holes.

We start with a function that takes a batch of images and cuts a random hole in each of them.
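The notebook shows only the signature create_hole (batch). The following is a minimal sketch of one way such a function could work. It assumes the batch has shape (B, C, H, W) and that a "hole" is a circle with a random centre and radius, inside which every channel is zeroed. The name `create_hole_sketch` and the radius range are hypothetical and do not come from the notebook's implementation. Passing `indexing` explicitly to `torch.meshgrid` also silences PyTorch's warning that this argument will become required.

```python
import torch

def create_hole_sketch(batch, min_radius=0.1, max_radius=0.4):
    """Hypothetical sketch: zero a random circular region in every sample.

    batch: (B, C, H, W); a (C, H, W) tensor is treated as a batch of one.
    Radii are fractions of the half-width of the [-1, 1] coordinate grid (assumed values).
    """
    if batch.dim() == 3:
        batch = batch.unsqueeze(0)
    b, _, h, w = batch.shape
    device = batch.device
    # Pixel coordinates in [-1, 1]; indexing="ij" makes both grids (H, W)
    ys, xs = torch.meshgrid(
        torch.linspace(-1.0, 1.0, h, device=device),
        torch.linspace(-1.0, 1.0, w, device=device),
        indexing="ij",
    )
    # One random centre in [-0.5, 0.5]^2 and one random radius per sample
    cy = torch.rand(b, 1, 1, device=device) - 0.5
    cx = torch.rand(b, 1, 1, device=device) - 0.5
    r = min_radius + (max_radius - min_radius) * torch.rand(b, 1, 1, device=device)
    # (B, H, W) boolean mask that is True inside each circle
    inside = (ys - cy) ** 2 + (xs - cx) ** 2 < r ** 2
    # Zero all channels inside the hole; returns a new tensor, input is untouched
    return batch * (~inside).unsqueeze(1).to(batch.dtype)
```

Because the mask is applied to every channel, the hidden channels are erased along with the RGBA ones. This matters because the automaton must regrow the missing cells from their neighbours alone.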



create_hole

 create_hole (batch)
corrupted_image = create_hole(img_tensor)
plt.imshow(corrupted_image[0].detach().cpu().permute(1, 2, 0))
plt.show()



CorruptedPool

 CorruptedPool (pool_size=1024, loss_fn=None, device='cpu')

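A pool of pool_size training samples that can also hand out corrupted samples. sample_with_damage() returns a batch drawn from the pool, with damage applied, and update() writes the model's outputs back into the pool. loss_fn is an optional loss supplied at construction.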
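Only the signature of CorruptedPool is shown here. The sketch below illustrates one plausible design, loosely modelled on the sample-pool scheme from the Growing Neural Cellular Automata article (Mordvintsev et al., Distill, 2020). It sorts the sampled batch by loss, reseeds the worst sample and damages the best ones. The class name `CorruptedPoolSketch`, the constructor arguments (`damage_fn`, `batch_size`, `n_damage`) and the sorting rule are all assumptions. Only the method names `sample_with_damage` and `update` mirror the calls used in this notebook.

```python
import torch

class CorruptedPoolSketch:
    """Hypothetical sample pool that reseeds and damages parts of each batch."""

    def __init__(self, seed, damage_fn, loss_fn, pool_size=1024, batch_size=8, n_damage=3):
        self.seed = seed                                      # (C, H, W) initial state
        self.pool = seed.unsqueeze(0).repeat(pool_size, 1, 1, 1)
        self.damage_fn = damage_fn                            # (k, C, H, W) -> (k, C, H, W)
        self.loss_fn = loss_fn                                # (B, C, H, W) -> (B,) per-sample loss
        self.batch_size = batch_size
        self.n_damage = n_damage                              # assumed number of damaged samples
        self.idxs = None

    def sample_with_damage(self):
        # Draw distinct pool indices for this batch
        self.idxs = torch.randperm(len(self.pool))[: self.batch_size]
        batch = self.pool[self.idxs].clone()
        # Order the batch from highest to lowest loss
        order = torch.argsort(self.loss_fn(batch), descending=True)
        self.idxs, batch = self.idxs[order], batch[order]
        # Replace the worst sample with a fresh seed so growth from scratch is still trained
        batch[0] = self.seed
        # Damage the best (lowest-loss) samples so regeneration is trained
        if self.n_damage > 0:
            batch[-self.n_damage:] = self.damage_fn(batch[-self.n_damage:])
        return batch

    def update(self, result):
        # Write the model outputs back into the slots they were sampled from
        self.pool[self.idxs] = result.detach()
```

In this notebook's terms, `damage_fn` could be a hole-cutting function like `create_hole`, and `loss_fn` a per-sample MSE against the target image. Damaging only low-loss samples means the model is asked to repair shapes it has already grown well, not ones that are still malformed.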

loss_fn = partial(mse, target=img_tensor.repeat(batch_size, 1, 1, 1).to(def_device))
pool = CorruptedPool(1024, loss_fn=loss_fn)
batch = pool.sample_with_damage()
vis_batch(batch)

Training Loop

# Instantiate the model
ca = CAModel(CHANNEL_N).to(def_device)

# Optimization
lr = 2e-3
lr_gamma = 0.9999
betas = (0.5, 0.5)
optimizer = torch.optim.Adam(ca.parameters(), lr=lr, betas=betas)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)

# Target image repeated for every sample in the batch, on the model's device
target = img_tensor.repeat(batch_size, 1, 1, 1).to(def_device)
for i in tqdm(range(n_epochs)):
    # reset the gradients from the previous iteration
    optimizer.zero_grad()

    # pick a random number of CA steps in [64, 96)
    steps = torch.randint(64, 96, (1,)).item()

    # sample a batch from the pool (some samples are damaged)
    model_in = pool.sample_with_damage()

    # run the automaton
    res = ca(model_in, steps=steps)

    # compute the loss on the RGBA channels only
    loss = F.mse_loss(res[:, :4], target)

    # write the results back into the pool
    pool.update(res)

    # log the loss
    if i % 500 == 0:
        print(f"Epoch: {i} Loss: {loss.item()}")

    # backpropagate the loss and update the weights
    loss.backward()
    optimizer.step()
    scheduler.step()
100%|███████████████████████████████████████| 8000/8000 [09:35<00:00, 13.90it/s]
Epoch: 0 Loss: 0.11324034631252289
Epoch: 500 Loss: 0.029659921303391457
Epoch: 1000 Loss: 0.01154416985809803
Epoch: 1500 Loss: 0.009176425635814667
Epoch: 2000 Loss: 0.004854139406234026
Epoch: 2500 Loss: 0.0035146409645676613
Epoch: 3000 Loss: 0.0019474619766697288
Epoch: 3500 Loss: 0.0017854450270533562
Epoch: 4000 Loss: 0.0013569763395935297
Epoch: 4500 Loss: 0.0009505663765594363
Epoch: 5000 Loss: 0.0013342727907001972
Epoch: 5500 Loss: 0.001305947545915842
Epoch: 6000 Loss: 0.00037199087091721594
Epoch: 6500 Loss: 0.0005692397826351225
Epoch: 7000 Loss: 0.00025373895186930895
Epoch: 7500 Loss: 0.0003689015575218946
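ExponentialLR multiplies the learning rate by lr_gamma on every scheduler.step(). After the 8000 iterations above, the rate has therefore decayed to lr * lr_gamma**n_epochs, which is roughly 9e-4 or about 45% of the starting value. The standalone snippet below checks this with a single dummy parameter. It reuses the loop's hyperparameters but is only an illustration of the schedule, not part of the training code.

```python
import torch

# Same hyperparameters as the training loop above
lr, lr_gamma, n_epochs = 2e-3, 0.9999, 8000

# A single dummy parameter is enough to drive the optimizer and the scheduler
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=lr, betas=(0.5, 0.5))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)

for _ in range(n_epochs):
    optimizer.step()   # no gradients exist, so the parameter is left unchanged
    scheduler.step()   # multiplies the learning rate by lr_gamma

final_lr = scheduler.get_last_lr()[0]
closed_form = lr * lr_gamma ** n_epochs
print(f"final lr: {final_lr:.3e}, closed form: {closed_form:.3e}")
```

With lr_gamma this close to 1, the decay is gentle: the rate only halves after about ln(2) / 0.0001 ≈ 6900 iterations.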
images = ca.grow_animation(seed, 200)
display_animation(images)

The automaton maintains its shape as time passes.

Visualize a batch from the updated pool.

During training, the samples inside the pool are continually updated. Let's visualize what the pool samples look like once training is complete.

batch = pool.sample_with_damage()
vis_batch(batch)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
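The clipping warning appears because cell states are not constrained to [0, 1], so imshow clips the RGBA values before drawing them. Clamping explicitly produces the same picture without the warning. The helper below is a hypothetical sketch (`to_display_rgba` is not part of the notebook), and vis_batch may prepare images differently.

```python
import torch

def to_display_rgba(state):
    """Hypothetical helper: turn one CA state of shape (C, H, W), C >= 4,
    into an (H, W, 4) RGBA image with values in [0, 1] for plt.imshow."""
    return state[:4].detach().clamp(0.0, 1.0).permute(1, 2, 0).cpu()

# Example: a random 16-channel state whose values fall outside [0, 1]
state = torch.randn(16, 40, 40) * 2
img = to_display_rgba(state)
print(img.shape, img.min().item(), img.max().item())
```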

Test the reconstruction ability of the model

# Grow an automaton from the seed; this is the state we will corrupt
generated = ca(seed, steps=96).detach()

Corruption 1:

corrupted_input1 = generated.clone()
corrupted_input1[..., 20:] = 0

plt.imshow(corrupted_input1[0].permute(1, 2, 0)[:, :, :4].clamp(0, 1).cpu())
plt.show()

images = ca.grow_animation(corrupted_input1, 300)
display_animation(images)

Corruption 2:

corrupted_input2 = generated.clone()
corrupted_input2[..., :20, :] = 0

plt.imshow(corrupted_input2[0].permute(1, 2, 0)[:, :, :4].clamp(0, 1).cpu())
plt.show()

images = ca.grow_animation(corrupted_input2, 300)
display_animation(images)
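Both corruptions rely on ellipsis indexing over the (B, C, H, W) state. In [..., 20:] the slice applies to the last (width) axis, so Corruption 1 erases every column from index 20 to the right edge. In [..., :20, :] the slice applies to the height axis, so Corruption 2 erases the top 20 rows. In both cases every channel is zeroed, including the hidden ones. The toy tensor below uses arbitrary sizes to make the difference visible.

```python
import torch

# A toy 1-channel 4x6 "image" of ones, laid out as (B, C, H, W)
x = torch.ones(1, 1, 4, 6)

right_cut = x.clone()
right_cut[..., 3:] = 0      # like corrupted_input1[..., 20:] = 0: zeroes columns 3..5 (right side)

top_cut = x.clone()
top_cut[..., :2, :] = 0     # like corrupted_input2[..., :20, :] = 0: zeroes rows 0..1 (top)

print(right_cut[0, 0])
print(top_cut[0, 0])
```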