robustness.model_utils module
class robustness.model_utils.FeatureExtractor(submod, layers)
Bases: torch.nn.Module
Tool for extracting intermediate-layer activations from a model.
Parameters: - submod (torch.nn.Module) – model to extract activations from
- layers (list of functions) – list of functions where each function, when applied to submod, returns a desired layer. For example, one function could be lambda model: model.layer1.
Returns: - A model whose forward function returns the activations from the layers corresponding to the functions in layers (in the order that the functions were passed in the list).
forward(*args, **kwargs)
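The layers parameter suggests that the mechanism can be approximated with PyTorch forward hooks. The following is a minimal sketch, assuming only torch; the class name HookedExtractor is hypothetical and is not part of robustness. It returns the model output together with the hooked activations, which may differ from what the library's own forward() returns.

```python
import torch
import torch.nn as nn


class HookedExtractor(nn.Module):
    """Hypothetical stand-in for FeatureExtractor (not the library's code):
    runs `submod` and also returns the outputs of selected layers."""

    def __init__(self, submod, layers):
        super().__init__()
        self.submod = submod
        self.layers = layers
        self._captured = {}
        for idx, layer_fn in enumerate(layers):
            layer = layer_fn(submod)  # e.g. lambda model: model[0]

            # PyTorch calls forward hooks as hook(module, inputs, output);
            # binding idx as a default argument avoids loop late binding.
            def hook(module, inputs, output, idx=idx):
                self._captured[idx] = output

            layer.register_forward_hook(hook)

    def forward(self, *args, **kwargs):
        self._captured.clear()
        out = self.submod(*args, **kwargs)
        # Activations are returned in the same order as `layers`.
        return out, [self._captured[i] for i in range(len(self.layers))]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
extractor = HookedExtractor(model, [lambda m: m[0], lambda m: m[1]])
logits, (lin1, relu1) = extractor(torch.randn(2, 4))
print(logits.shape, lin1.shape, relu1.shape)
# torch.Size([2, 3]) torch.Size([2, 8]) torch.Size([2, 8])
```

With hooks, the wrapped model's own forward() is left untouched, so any nn.Module works as submod.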
class robustness.model_utils.DummyModel(model)
Bases: torch.nn.Module
forward(x, *args, **kwargs)
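DummyModel has no docstring here. Its forward(x, *args, **kwargs) signature and the add_custom_forward option of make_and_restore_model suggest that it may be a thin wrapper that passes only x to the wrapped model, but that reading is an assumption. The hypothetical class IgnoreExtraArgs below sketches this pattern in plain PyTorch:

```python
import torch
import torch.nn as nn


class IgnoreExtraArgs(nn.Module):
    """Hypothetical wrapper: forwards only the input tensor to the wrapped
    model and silently drops any extra positional or keyword arguments."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, *args, **kwargs):
        # Flags such as with_latent=True are accepted but ignored.
        return self.model(x)


torch.manual_seed(0)
base = nn.Linear(4, 3)  # a plain module whose forward() takes only x
wrapped = IgnoreExtraArgs(base)
x = torch.randn(2, 4)
out = wrapped(x, with_latent=False, fake_relu=False, no_relu=False)
```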
robustness.model_utils.make_and_restore_model(*_, arch, dataset, resume_path=None, parallel=False, pytorch_pretrained=False, add_custom_forward=False)
Makes a model and (optionally) restores it from a checkpoint.
Parameters: - arch (str|nn.Module) – model architecture identifier, or else the classifier itself as a torch.nn.Module instance
- dataset (Dataset class [see datasets.py]) – the dataset the model is built for
- resume_path (str) – optional path to a checkpoint saved with the robustness library (ignored if arch is not a string)
- parallel (bool) – if True, wrap the model in a DataParallel (defaults to False)
- pytorch_pretrained (bool) – if True, try to load a standard-trained checkpoint from the torchvision library (raises an error if this fails)
- add_custom_forward (bool) – ignored unless arch is an instance of nn.Module (and not a string). Normally, architectures should have a forward() function that accepts the arguments with_latent, fake_relu, and no_relu to allow for adversarial manipulation (see https://robustness.readthedocs.io/en/latest/example_usage/training_lib_part_2.html#training-with-custom-architectures for more info). If this argument is True, these options will not be passed to forward(). This is useful if you just want to train a model, don’t care about these arguments, and are passing in an arch whose forward() you don’t want to edit (e.g., a pretrained model).
Returns: A tuple consisting of the model (possibly loaded with checkpoint), and the checkpoint itself
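The add_custom_forward description refers to a forward() that accepts with_latent, fake_relu, and no_relu. The following is one plausible way a small custom architecture could honor those flags. TinyNet and PassThroughReLU are hypothetical names, and the semantics used here are an illustration, not a specification of the library's models. Specifically, this sketch assumes that with_latent also returns the penultimate features, that fake_relu applies ReLU in the forward pass but passes gradients through unchanged, and that no_relu skips the final ReLU. The linked documentation is the authoritative reference.

```python
import torch
import torch.nn as nn


class PassThroughReLU(torch.autograd.Function):
    """ReLU in the forward pass, identity in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through unchanged, even where x < 0.
        return grad_output


class TinyNet(nn.Module):
    def __init__(self, in_dim=4, hidden=8, num_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
        pre = self.fc1(x)
        if fake_relu:
            latent = PassThroughReLU.apply(pre)
        elif no_relu:
            latent = pre  # skip the final nonlinearity
        else:
            latent = torch.relu(pre)
        logits = self.head(latent)
        if with_latent:
            return logits, latent
        return logits


torch.manual_seed(0)
net = TinyNet()
x = torch.randn(5, 4)
logits, latent = net(x, with_latent=True)
_, raw = net(x, with_latent=True, no_relu=True)

# The pass-through ReLU zeroes negatives but keeps their gradient.
z = torch.tensor([-1.0, 2.0], requires_grad=True)
PassThroughReLU.apply(z).sum().backward()
```

A model whose forward() already accepts these flags can be passed with add_custom_forward=False. A model whose forward() takes only x would instead need add_custom_forward=True.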
robustness.model_utils.model_dataset_from_store(s, overwrite_params={}, which='last')
Given a store directory corresponding to a trained model, returns the original model, the dataset object, and the arguments the model was trained with.