autoaug.autoaugment_learners.GruLearner

class autoaug.autoaugment_learners.GruLearner(num_sub_policies=5, p_bins=11, m_bins=10, exclude_method=[], batch_size=8, toy_size=1, learning_rate=0.1, max_epochs=inf, early_stop_num=20, alpha=0.2, cont_mb_size=4, cont_lr=0.03)

An AutoAugment learner with a GRU controller

The original AutoAugment paper (http://arxiv.org/abs/1805.09501) uses an LSTM controller updated via Proximal Policy Optimization (see Section 3 of the AutoAugment paper).
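For readers unfamiliar with PPO, here is a minimal sketch of its standard clipped surrogate loss, averaged over a minibatch of sampled policies (the role that cont_mb_size plays for this learner). It is not autoaug's implementation: the function name ppo_clip_loss, the clip range of 0.2, and the assumption that each advantage is the child network's accuracy minus some baseline are all illustrative. Each log-probability is that of a whole sampled policy, i.e. the sum of the log-probabilities of the controller's individual decisions.

```python
import math


def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped PPO surrogate loss averaged over a minibatch (lower is better).

    new_log_probs / old_log_probs: log-probability of each sampled policy under
    the current controller and under the controller that sampled it.
    advantages: assumed to be reward (e.g. accuracy) minus a baseline.
    """
    losses = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(policy) / pi_old(policy)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # PPO maximises the pessimistic (smaller) of the two surrogate terms;
        # negate it so that minimising the loss maximises the objective.
        losses.append(-min(ratio * adv, clipped_ratio * adv))
    return sum(losses) / len(losses)


# Three hypothetical sampled policies (a minibatch of size 3).
loss = ppo_clip_loss(
    new_log_probs=[-1.0, -2.0, -0.5],
    old_log_probs=[-1.1, -2.0, -0.9],
    advantages=[0.05, -0.02, 0.10],
)
print(loss)
```

Clipping the probability ratio to [1 - clip_eps, 1 + clip_eps] removes any incentive to move the controller far from the policy that produced the samples in a single update.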

The GRU has been shown to perform comparably to the LSTM on sequence modeling tasks while using fewer parameters and, in some of the reported experiments, converging faster (https://arxiv.org/abs/1412.3555). This is why we replaced the LSTM with a GRU.
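One concrete source of the GRU's smaller size: a GRU layer computes three gated transformations where an LSTM computes four, so it has three quarters of the parameters. The sketch below counts the parameters of a single-layer, unidirectional recurrent layer using the weight layout that PyTorch uses for nn.LSTM and nn.GRU with biases enabled (weight_ih, weight_hh, bias_ih, bias_hh); the function name and the sizes 10 and 20 are illustrative only.

```python
def rnn_param_count(input_size: int, hidden_size: int, gates: int) -> int:
    """Parameters of one single-layer, unidirectional recurrent layer with
    input-to-hidden weights, hidden-to-hidden weights and two bias vectors."""
    weight_ih = gates * hidden_size * input_size
    weight_hh = gates * hidden_size * hidden_size
    biases = 2 * gates * hidden_size  # bias_ih and bias_hh
    return weight_ih + weight_hh + biases


LSTM_GATES = 4  # input, forget, cell candidate, output
GRU_GATES = 3   # reset, update, new candidate

lstm = rnn_param_count(10, 20, LSTM_GATES)
gru = rnn_param_count(10, 20, GRU_GATES)
print(lstm, gru)  # 2560 1920
```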

Parameters
  • num_sub_policies (int, optional) – number of subpolicies per policy. Defaults to 5.

  • p_bins (int, optional) – number of bins into which the probability interval [0, 1] is divided, e.g. (0.0, 0.1, …, 1.0). Defaults to 11.

  • m_bins (int, optional) – number of bins into which the magnitude space is divided. Defaults to 10.

  • exclude_method (list, optional) – list of names (str) of the image operations the user wants to exclude from the search space. Defaults to [].

  • batch_size (int, optional) – child_network training parameter. Defaults to 8.

  • toy_size (float, optional) – child_network training parameter: the fraction of the original dataset used in the toy dataset. Defaults to 1.

  • learning_rate (float, optional) – child_network training parameter. Defaults to 1e-1.

  • max_epochs (Union[int, float], optional) – child_network training parameter. Defaults to float('inf').

  • early_stop_num (int, optional) – child_network training parameter. Defaults to 20.

  • alpha (float, optional) – Exploration parameter. The operation tensors are multiplied by alpha before the softmax is applied. The lower this value, the smoother (closer to uniform) the softmax output, and hence the more exploration. Defaults to 0.2.

  • cont_mb_size (int, optional) – Controller minibatch size: how many policies are tested to compute each PPO (proximal policy optimization) gradient used to update the controller. Defaults to 4.

  • cont_lr (float, optional) – The learning rate used when updating the GRU controller via proximal policy optimization. Defaults to 0.03.
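The parameters p_bins, m_bins and alpha describe how a discrete policy is read off the controller's output. The sketch below shows one plausible reading, assuming each decision (probability bin, magnitude bin) is sampled from a softmax over alpha-scaled scores; the helper names and the scores are hypothetical, and autoaug's internals may differ.

```python
import math
import random


def scaled_softmax(scores, alpha=0.2):
    """Softmax of alpha * scores; a smaller alpha gives a flatter distribution."""
    scaled = [alpha * s for s in scores]
    largest = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def probability_values(p_bins=11):
    """p_bins evenly spaced probabilities on [0, 1]: 0.0, 0.1, ..., 1.0 for 11 bins."""
    return [i / (p_bins - 1) for i in range(p_bins)]


def sample_bin(scores, alpha, rng):
    """Sample a bin index from the alpha-scaled softmax over `scores`."""
    weights = scaled_softmax(scores, alpha)
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


rng = random.Random(0)
p_bins, m_bins = 11, 10
# Hypothetical controller scores for one probability and one magnitude decision.
prob_scores = [0.5 * i for i in range(p_bins)]
mag_scores = [1.0] * m_bins
probability = probability_values(p_bins)[sample_bin(prob_scores, 0.2, rng)]
magnitude_bin = sample_bin(mag_scores, 0.2, rng)
print(probability, magnitude_bin)
```

Because alpha scales the scores before the softmax, an alpha close to 0 pushes every decision towards a uniform choice (more exploration), while larger values let the controller's preferences dominate.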

history

list of policies that have been passed to self._test_autoaugment_policy, together with the accuracies they obtained

Type

list

augmentation_space

list of image functions that the user has chosen to include in the search space.

Type

list
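Together with the exclude_method parameter, augmentation_space can be pictured as the full list of operations minus the excluded names. This is a minimal sketch under that assumption; the operation names below are a subset of those listed in the AutoAugment paper, and the names autoaug actually uses may differ. Rejecting unknown names is also a choice of this sketch, made so that a typo does not silently leave an operation in the search space.

```python
# Illustrative operation names (a subset of the AutoAugment paper's list).
ALL_OPERATIONS = [
    "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
    "AutoContrast", "Invert", "Equalize", "Solarize", "Posterize",
    "Contrast", "Color", "Brightness", "Sharpness",
]


def build_augmentation_space(exclude_method=()):
    """Return the operations that remain after removing `exclude_method`."""
    excluded = set(exclude_method)
    unknown = excluded - set(ALL_OPERATIONS)
    if unknown:
        raise ValueError(f"unknown operations: {sorted(unknown)}")
    # Preserve the original order of the remaining operations.
    return [op for op in ALL_OPERATIONS if op not in excluded]


augmentation_space = build_augmentation_space(["Invert", "Solarize"])
print(len(augmentation_space))  # 12
```

The sketch uses a tuple as its default rather than the [] in the GruLearner signature: in Python a mutable default is shared between calls, which is harmless only as long as it is never modified.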

References

Ekin D. Cubuk, et al.

“AutoAugment: Learning Augmentation Policies from Data” arXiv:1805.09501

Junyoung Chung, et al.

“Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling” arXiv:1412.3555

learn(train_dataset, test_dataset, child_network_architecture, iterations=15)

Runs the main loop of finding a good policy for the given child network, training dataset, and test (validation) dataset.

This is the loop shown in Figure 1 of the AutoAugment paper:

  1. Generate (sample) a new policy.

  2. Evaluate how good that policy is.

  3. Record the result (e.g. in a list or dictionary) and, if applicable, update the controller (e.g. an RL agent).
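The three steps above can be sketched as a self-contained loop. Everything in it is a hypothetical stand-in: policies are drawn uniformly at random rather than from a GRU controller, evaluate_policy returns a random number instead of training a child network, and each history entry is assumed to be a (policy, accuracy) pair. Following the AutoAugment paper, a policy is a list of sub-policies, each made of two (operation, probability, magnitude bin) triples.

```python
import random


def generate_new_policy(rng, ops, num_sub_policies=5, p_bins=11, m_bins=10):
    """Step 1 (stand-in): a uniformly random policy of num_sub_policies sub-policies."""
    return [
        tuple(
            (rng.choice(ops), rng.randrange(p_bins) / (p_bins - 1), rng.randrange(m_bins))
            for _ in range(2)  # two operations per sub-policy
        )
        for _ in range(num_sub_policies)
    ]


def evaluate_policy(policy, rng):
    """Step 2 (stand-in): would train a child network with `policy` and
    return its validation accuracy; here it just returns a random number."""
    return rng.random()


def learn(iterations=15, seed=0):
    rng = random.Random(seed)
    ops = ["ShearX", "Rotate", "Invert", "Equalize"]  # illustrative names
    history = []
    for _ in range(iterations):
        policy = generate_new_policy(rng, ops)
        accuracy = evaluate_policy(policy, rng)
        # Step 3: record the result; a controller-based learner would also
        # update its controller here, using `accuracy` as the reward.
        history.append((policy, accuracy))
    best_policy, best_accuracy = max(history, key=lambda entry: entry[1])
    return history, best_policy, best_accuracy


history, best_policy, best_accuracy = learn(iterations=15)
print(best_accuracy)
```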

If child_network_architecture is a function (or class), we call it to create a new instance. If it is an nn.Module, we make a copy.deepcopy of it. We make a copy because we want to keep an untrained (initialized but not trained) version of the child network architecture: the child network has to be trained from scratch once for each policy.

Passing child_network_architecture as a function is potentially better than passing an nn.Module, because every new instance is initialized with different weights, so our results are less biased towards one particular initialization (https://en.wikipedia.org/wiki/Bias_(statistics)).
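The instantiate-or-copy rule can be sketched in plain Python. To keep the example free of a PyTorch dependency, the Module class below is a hypothetical stand-in for torch.nn.Module and TinyNet is a made-up network; in real code the check would be isinstance(child_network_architecture, torch.nn.Module). Note that callable() cannot tell the two cases apart, since nn.Module instances are themselves callable.

```python
import copy
import random


class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class TinyNet(Module):
    """Hypothetical network whose weights are randomly initialised."""

    def __init__(self):
        self.weights = [random.gauss(0.0, 1.0) for _ in range(4)]


def fresh_child_network(child_network_architecture):
    """Return an untrained network to be trained under one policy."""
    if isinstance(child_network_architecture, Module):
        # An instance: deep-copy it so the original stays untrained.
        # Every copy starts from the same initial weights.
        return copy.deepcopy(child_network_architecture)
    # A class or factory function: call it, so each network gets
    # its own freshly initialised weights.
    return child_network_architecture()


template = TinyNet()
copy_a = fresh_child_network(template)  # same weights as template
net_a = fresh_child_network(TinyNet)    # new random weights
net_b = fresh_child_network(TinyNet)    # different new random weights
```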

Parameters
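  • train_dataset (torchvision.datasets.vision.VisionDataset) – dataset used to train the child network.

  • test_dataset (torchvision.datasets.vision.VisionDataset) – dataset used to evaluate (validate) the child network.

  • child_network_architecture (Union[function, nn.Module]) – either a function or class that returns a new network (e.g. LeNet) or a network instance (e.g. LeNet()).

  • iterations (int) – number of different policies to test. Defaults to 15.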

Returns

None

Example code:

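Listing 10. This is how a child class might implement this method:
def learn(self, train_dataset, test_dataset, child_network_architecture, iterations=15):
    for _ in range(iterations):
        # step 1: generate a new policy
        policy = self._generate_new_policy()
        print(policy)

        # step 2: evaluate the policy by training and testing a child network with it
        reward = self._test_autoaugment_policy(policy,
                                               child_network_architecture,
                                               train_dataset,
                                               test_dataset)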