batch_size = 8
n_epochs = 8000
Experiment 2: What persists, exists
In the second experiment, we train the cellular automaton to maintain its form once it has reached the target shape.
path = '../images/emoji_u1f98e.png'
img_tensor = load_image(path)
Pool Training
We create a pool of samples containing states from different stages of the growth process, which serve as initial states during training.
SamplePool
SamplePool (pool_size=1024, loss_fn=None, device='cpu')
The SamplePool uses a loss function to pick the sample in the batch with the highest loss and replace it with the original seed. For this we use the MSE loss.
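A minimal sketch of such a pool follows, assuming the pool starts out filled with copies of the seed and that `loss_fn` returns one loss per sample. The class `SamplePoolSketch` and its `seed` and `batch_size` arguments are hypothetical; this is an illustration of the idea, not the library's `SamplePool` implementation.

```python
import torch


class SamplePoolSketch:
    """Hypothetical pool of CA states (not the library's SamplePool)."""

    def __init__(self, seed, pool_size=1024, loss_fn=None, batch_size=8):
        self.seed = seed                      # seed state, shape (C, H, W)
        # start with every pool entry equal to the seed
        self.pool = seed.unsqueeze(0).repeat(pool_size, 1, 1, 1)
        self.loss_fn = loss_fn                # maps (B, C, H, W) -> (B,) losses
        self.batch_size = batch_size
        self.idxs = None                      # indices of the last sampled batch

    def sample(self):
        # draw a random subset of pool entries without replacement
        self.idxs = torch.randperm(len(self.pool))[: self.batch_size]
        batch = self.pool[self.idxs].clone()
        if self.loss_fn is not None:
            # replace the sample with the highest loss by a fresh seed
            worst = self.loss_fn(batch).argmax()
            batch[worst] = self.seed
        return batch

    def update(self, new_states):
        # write the model outputs back to the entries they were sampled from
        self.pool[self.idxs] = new_states.detach()
```

Writing the outputs back with `update` is what lets the pool accumulate states from later stages of growth, while reseeding the worst sample keeps the model practising growth from scratch.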
mse
mse (pred, target, dim=1)
loss_fn = partial(mse, target=img_tensor.repeat(batch_size, 1, 1, 1).to(def_device))
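To make the `partial` call concrete, here is a sketch of a per-sample MSE with the target bound in advance. `per_sample_mse` is a hypothetical stand-in: we assume it compares only the RGBA channels and averages over everything except the batch dimension, and we do not model the `dim` argument of the `mse` function documented above.

```python
from functools import partial

import torch


def per_sample_mse(pred, target):
    """Hypothetical: MSE over the RGBA channels, one value per sample."""
    # pred: (B, C, H, W) CA states; target: (B, 4, H, W) RGBA image
    return ((pred[:, :4] - target) ** 2).mean(dim=(1, 2, 3))


# bind the target once so the pool can call loss_fn(batch)
target = torch.ones(8, 4, 40, 40)
loss_fn = partial(per_sample_mse, target=target)

losses = loss_fn(torch.zeros(8, 16, 40, 40))  # one loss per sample, shape (8,)
```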
pool = SamplePool(1024, loss_fn=loss_fn)
batch = pool.sample()
batch.shape
torch.Size([8, 16, 40, 40])
We also need a function that visualizes a batch of inputs, so we can check that sampling from the pool works as expected.
vis_batch
vis_batch (batch)
vis_batch(batch)
Training Loop
import torch
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm

# Instantiate the model
ca = CAModel(CHANNEL_N).to(def_device)

# Optimization
lr = 2e-3
lr_gamma = 0.9999
betas = (0.5, 0.5)
optimizer = torch.optim.Adam(ca.parameters(), lr=lr, betas=betas)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)

# keep the target on the same device as the model output
target = img_tensor.repeat(batch_size, 1, 1, 1).to(def_device)

for i in tqdm(range(n_epochs)):
    # reset the gradients
    optimizer.zero_grad()
    # pick a random number of CA steps in [64, 96)
    steps = torch.randint(64, 96, (1,)).item()
    # sample the pool to get the input
    model_in = pool.sample()
    # run the model for `steps` updates
    res = ca(model_in, steps=steps)
    # calculate the loss; we only care about the RGBA channels
    loss = F.mse_loss(res[:, :4], target)
    # write the new states back into the pool
    pool.update(res)
    # log the loss
    if i % 500 == 0:
        print(f"Epoch: {i} Loss: {loss.item()}")
    # backpropagate the loss and update the weights
    loss.backward()
    optimizer.step()
    scheduler.step()
100%|███████████████████████████████████████| 8000/8000 [09:43<00:00, 13.70it/s]
Epoch: 0 Loss: 0.11318174749612808
Epoch: 500 Loss: 0.025092191994190216
Epoch: 1000 Loss: 0.015459218993782997
Epoch: 1500 Loss: 0.008353400975465775
Epoch: 2000 Loss: 0.0046185762621462345
Epoch: 2500 Loss: 0.0031142488587647676
Epoch: 3000 Loss: 0.0015645724488422275
Epoch: 3500 Loss: 0.0010012438287958503
Epoch: 4000 Loss: 0.0005030333995819092
Epoch: 4500 Loss: 0.0005954373045824468
Epoch: 5000 Loss: 0.0009438891429454088
Epoch: 5500 Loss: 0.000864389818161726
Epoch: 6000 Loss: 0.00047235662350431085
Epoch: 6500 Loss: 0.00031021377071738243
Epoch: 7000 Loss: 0.00025668973103165627
Epoch: 7500 Loss: 0.00032552078482694924
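The `ExponentialLR` scheduler multiplies the learning rate by `lr_gamma` on every `scheduler.step()`, so after n steps it equals 2e-3 × 0.9999^n; after 8000 steps that is roughly 9e-4, about 45% of the initial value. The sketch below checks this closed form against PyTorch's scheduler using a dummy parameter, an assumption made purely for illustration: nothing is trained.

```python
import torch

# A dummy parameter: Adam skips it because it never receives a gradient,
# so only the learning-rate schedule changes.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2e-3, betas=(0.5, 0.5))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999)

n_steps = 8000
for _ in range(n_steps):
    optimizer.step()   # no-op here (param.grad is None)
    scheduler.step()   # lr <- lr * gamma

final_lr = optimizer.param_groups[0]["lr"]
closed_form = 2e-3 * 0.9999 ** n_steps
print(f"{final_lr:.3e} vs {closed_form:.3e}")
```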
images = ca.grow_animation(seed, 200)
display_animation(images)
The automaton maintains its shape as time passes.
Visualize a batch from the updated pool.
The samples in the pool are updated throughout training. Let's visualize what they look like once training is complete.
batch = pool.sample()
vis_batch(batch)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
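The clipping warnings indicate that some RGB values passed to `imshow` fall outside [0, 1]. One way to avoid them is to composite the RGBA channels onto a white background and clamp before plotting, as the hypothetical helper `to_rgb` below does. It assumes the RGB channels are premultiplied by alpha; we do not know whether `vis_batch` follows this convention.

```python
import torch


def to_rgb(x):
    """Hypothetical helper: composite RGBA onto white and clamp to [0, 1].

    x: tensor of shape (B, C, H, W) with C >= 4; returns (B, H, W, 3).
    """
    rgb, alpha = x[:, :3], x[:, 3:4].clamp(0.0, 1.0)
    # premultiplied RGB over a white background: rgb + (1 - alpha) * 1
    img = (1.0 - alpha + rgb).clamp(0.0, 1.0)
    return img.permute(0, 2, 3, 1)  # channels-last layout for imshow
```

Empty cells (all channels zero) map to white, and no value handed to `imshow` can leave the valid range.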