autoaug.autoaugment_learners.GruLearner
- class autoaug.autoaugment_learners.GruLearner(num_sub_policies=5, p_bins=11, m_bins=10, exclude_method=[], batch_size=8, toy_size=1, learning_rate=0.1, max_epochs=inf, early_stop_num=20, alpha=0.2, cont_mb_size=4, cont_lr=0.03)[source]
An AutoAugment learner with a GRU controller
The original AutoAugment paper (http://arxiv.org/abs/1805.09501) uses an LSTM controller updated via Proximal Policy Optimization (see Section 3 of the AutoAugment paper).
The GRU has been shown to perform comparably to the LSTM as a sequence model while training and running faster (https://arxiv.org/abs/1412.3555), which is why we replaced the LSTM with a GRU.
- Parameters
num_sub_policies (int, optional) – number of subpolicies per policy. Defaults to 5.
p_bins (int, optional) – number of bins into which the interval [0, 1] is divided for probabilities, e.g. (0.0, 0.1, …, 1.0). Defaults to 11.
m_bins (int, optional) – number of bins into which the magnitude space is divided. Defaults to 10.
exclude_method (list, optional) – list of names (str) of image operations to exclude from the search space. Defaults to [].
batch_size (int, optional) – child_network training parameter. Defaults to 8.
toy_size (int, optional) – child_network training parameter. Fraction of the original dataset used as the toy dataset. Defaults to 1.
learning_rate (float, optional) – child_network training parameter. Defaults to 1e-1.
max_epochs (Union[int, float], optional) – child_network training parameter. Defaults to float(‘inf’).
early_stop_num (int, optional) – child_network training parameter. Defaults to 20.
alpha (float, optional) – Exploration parameter. The operation tensors are multiplied by it before the softmax is applied. The lower this value, the smoother the softmax output, and hence the more exploration. Defaults to 0.2.
cont_mb_size (int, optional) – Controller minibatch size. The number of policies tested in order to calculate the PPO (Proximal Policy Optimization) gradient used to update the controller. Defaults to 4.
cont_lr (float, optional) – The learning rate used when updating the GRU controller via the Proximal Policy Optimization update. Defaults to 0.03.
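The effect of p_bins and alpha described above can be illustrated with a small numeric sketch. It is not taken from the library's source: the helper names probability_bins and scaled_softmax and the example logits are assumptions made for illustration, and the controller's real tensors are not shown in this page.

```python
import math


def probability_bins(p_bins=11):
    """Evenly spaced probabilities on [0, 1], both endpoints included."""
    return [round(i / (p_bins - 1), 10) for i in range(p_bins)]


def scaled_softmax(logits, alpha):
    """Softmax of alpha * logits; a smaller alpha flattens the distribution."""
    scaled = [alpha * x for x in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up controller outputs for three image operations (illustrative only).
logits = [2.0, 1.0, 0.0]
sharp = scaled_softmax(logits, alpha=1.0)   # favours the first operation strongly
smooth = scaled_softmax(logits, alpha=0.2)  # closer to uniform, so more exploration

print(probability_bins(11))
print(sharp)
print(smooth)
```

With the smaller alpha the largest probability shrinks toward 1/3, so sampling from the output picks the lower-scoring operations more often.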
- history
list of policies that have been passed to self._test_autoaugment_policy, along with the accuracies obtained for each
- Type
- augmentation_space
list of image functions that the user has chosen to include in the search space.
- Type
References
- Ekin D. Cubuk, et al.
“AutoAugment: Learning Augmentation Policies from Data” arXiv:1805.09501
- Junyoung Chung, et al.
“Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling” arXiv:1412.3555
- learn(train_dataset, test_dataset, child_network_architecture, iterations=15)[source]
Runs the main loop, which searches for a good policy for the given child network, training dataset, and test (validation) dataset.
This is the loop shown in Figure 1 of the AutoAugment paper:
<generate a random policy>
<see how good that policy is>
<save how good the policy is in a list/dictionary and (if applicable,) update the controller (e.g. RL agent)>
If child_network_architecture is a <function>, then we make an instance of it. If it is a <nn.Module>, we make a copy.deepcopy of it. We make a copy because we want to keep an untrained (initialized but not trained) version of the child network architecture, since it has to be trained from scratch for each policy. Keeping child_network_architecture as a <function> is potentially better than keeping it as a <nn.Module>, because every time we make a new instance the weights are initialized differently, which means that our results will be less biased (https://en.wikipedia.org/wiki/Bias_(statistics)).
- Parameters
train_dataset (torchvision.dataset.vision.VisionDataset) –
test_dataset (torchvision.dataset.vision.VisionDataset) –
child_network_architecture (Union[function, nn.Module]) – This can be either, for example, LeNet or LeNet()
iterations (int) – how many different policies to test
- Returns
None
Example code:
    for _ in range(15):
        policy = self._generate_new_policy()
        print(policy)
        reward = self._test_autoaugment_policy(policy, child_network_architecture, train_dataset, test_dataset)
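The handling of child_network_architecture described for learn (instantiate it when it is a function or class, deep-copy it when it is an already built module) might look roughly like the sketch below. This is an assumption about the logic, not the library's actual code: fresh_child_network is a hypothetical helper, and TinyNet is a plain-Python stand-in for an nn.Module such as LeNet so that the example runs without PyTorch.

```python
import copy
import inspect


class TinyNet:
    """Hypothetical stand-in for an nn.Module subclass such as LeNet."""

    def __init__(self):
        self.weights = [0.0, 0.0]
        self.trained = False


def fresh_child_network(child_network_architecture):
    """Return an untrained network without touching the caller's template."""
    # A class or factory function: call it to get a newly initialized network.
    # (callable() is avoided on purpose, since nn.Module instances are callable too.)
    if inspect.isclass(child_network_architecture) or inspect.isfunction(
        child_network_architecture
    ):
        return child_network_architecture()
    # An already constructed instance: copy it so the original stays untrained.
    return copy.deepcopy(child_network_architecture)


template = TinyNet()
net_a = fresh_child_network(TinyNet)   # corresponds to passing LeNet
net_b = fresh_child_network(template)  # corresponds to passing LeNet()
net_b.trained = True                   # simulate training the copy

print(template.trained)
```

Because the copy is trained instead of the template, the same untrained starting point can be reused for every policy that is tested.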