vak.config.predict.PredictConfig
- class vak.config.predict.PredictConfig(checkpoint_path, labelmap_path, model, batch_size, dataset: DatasetConfig, trainer: TrainerConfig, frames_standardizer_path=None, num_workers=2, annot_csv_filename=None, output_dir=<current working directory>, min_segment_dur=None, majority_vote=True, save_net_outputs=False)
Bases: object
Class that represents the [vak.predict] table of a configuration file.
- checkpoint_path : str
Path to directory with checkpoint files saved by Torch, used to reload the model.
- labelmap_path : str
Path to ‘labelmap.json’ file.
- model : vak.config.ModelConfig
The model to use: its name, and the parameters to configure it. Must be an instance of vak.config.ModelConfig.
- batch_size : int
Number of samples per batch presented to models during training.
- dataset : vak.config.DatasetConfig
The dataset to use: the path to it, and optionally a path to a file representing splits, and the name, if it is a built-in dataset. Must be an instance of vak.config.DatasetConfig.
- trainer : vak.config.TrainerConfig
Configuration for lightning.pytorch.Trainer. Must be an instance of vak.config.TrainerConfig.
- num_workers : int
Number of processes to use for parallel loading of data. Argument to torch.utils.data.DataLoader. Default is 2.
- frames_standardizer_path : str
Path to a saved vak.transforms.FramesStandardizer object used to standardize (normalize) frames. If spectrograms were normalized and this path is not provided, the results will be incorrect. Default is None.
- annot_csv_filename : str
Name of .csv file containing predicted annotations. Default is None, in which case the name of the dataset .csv is used, with ‘.annot.csv’ appended to it.
- output_dir : str
Path to the location where the .csv containing predicted annotations should be saved. Defaults to the current working directory.
- min_segment_dur : float
Minimum duration of a segment, in seconds. If specified, then any segment with a duration less than min_segment_dur is removed from lbl_tb, the vector of predicted labels for each time bin. Default is None, in which case no segments are removed.
- majority_vote : bool
If True, transform segments containing multiple labels into segments with a single label by taking a “majority vote”, i.e., assign all time bins in the segment the most frequently occurring label in the segment. This transform can only be applied if the labelmap contains an ‘unlabeled’ label, because unlabeled segments make it possible to identify the labeled segments. Default is True.
- save_net_outputs : bool
If True, save ‘raw’ outputs of neural networks before they are converted to annotations. Default is False. Typically the output will be “logits” to which a softmax transform might be applied. For each item in the dataset (each row in the dataset_path .csv), the output will be saved in a separate file in output_dir, with the extension {MODEL_NAME}.output.npz. E.g., if the input is a spectrogram with spect_path filename gy6or6_032312_081416.npz, and the network is TweetyNet, then the net output file will be gy6or6_032312_081416.tweetynet.output.npz.
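The min_segment_dur and majority_vote options both act on the per-time-bin labels predicted by the network. Below is a minimal plain-Python sketch of that post-processing, written under several assumptions: labels are strings (vak itself works with integer labels mapped through the labelmap), "removing" a short segment means relabeling its time bins as 'unlabeled', and the minimum-duration filter runs before the majority vote. The names find_segments and postprocess are hypothetical and are not part of vak's API.

```python
from collections import Counter
from typing import List, Optional, Tuple

UNLABELED = "unlabeled"  # assumed name of the 'unlabeled' class in the labelmap


def find_segments(lbl_tb: List[str], unlabeled: str = UNLABELED) -> List[Tuple[int, int]]:
    """Return (start, stop) pairs (stop exclusive) for runs of non-unlabeled time bins."""
    segments = []
    start = None
    for i, lbl in enumerate(lbl_tb):
        if lbl != unlabeled and start is None:
            start = i  # a labeled segment begins
        elif lbl == unlabeled and start is not None:
            segments.append((start, i))  # an unlabeled bin ends the segment
            start = None
    if start is not None:
        segments.append((start, len(lbl_tb)))  # segment runs to the end
    return segments


def postprocess(
    lbl_tb: List[str],
    timebin_dur: float,
    min_segment_dur: Optional[float] = None,
    majority_vote: bool = True,
    unlabeled: str = UNLABELED,
) -> List[str]:
    """Hypothetical sketch: drop short segments, then apply a majority vote."""
    lbl_tb = list(lbl_tb)  # work on a copy
    if min_segment_dur is not None:
        for start, stop in find_segments(lbl_tb, unlabeled):
            if (stop - start) * timebin_dur < min_segment_dur:
                # assumption: a "removed" segment becomes unlabeled
                lbl_tb[start:stop] = [unlabeled] * (stop - start)
    if majority_vote:
        # segments are only identifiable because 'unlabeled' bins separate them
        for start, stop in find_segments(lbl_tb, unlabeled):
            most_common = Counter(lbl_tb[start:stop]).most_common(1)[0][0]
            lbl_tb[start:stop] = [most_common] * (stop - start)
    return lbl_tb


labels = ["unlabeled", "a", "a", "b", "unlabeled", "c", "unlabeled"]
print(postprocess(labels, timebin_dur=0.001, min_segment_dur=0.0025))
# ['unlabeled', 'a', 'a', 'a', 'unlabeled', 'unlabeled', 'unlabeled']
```

Here the mixed 'a'/'b' segment (3 ms) survives the 2.5 ms threshold and is voted to 'a', while the single-bin 'c' segment (1 ms) is dropped. With min_segment_dur=None, the 'c' segment would be kept unchanged.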
- __init__(checkpoint_path, labelmap_path, model, batch_size, dataset: DatasetConfig, trainer: TrainerConfig, frames_standardizer_path=None, num_workers=2, annot_csv_filename=None, output_dir=<current working directory>, min_segment_dur=None, majority_vote=True, save_net_outputs=False) -> None
Method generated by attrs for class PredictConfig.
Methods
- __init__(checkpoint_path, labelmap_path, ...): Method generated by attrs for class PredictConfig.
- from_config_dict(config_dict): Return a PredictConfig instance from a dict.
Attributes
- checkpoint_path
- labelmap_path
- model
- batch_size
- dataset
- trainer
- frames_standardizer_path
- num_workers
- annot_csv_filename
- output_dir
- min_segment_dur
- majority_vote
- save_net_outputs
- classmethod from_config_dict(config_dict: dict) -> PredictConfig
Return a PredictConfig instance from a dict.
The dict passed in should be the one obtained by loading a valid configuration TOML file with vak.config.parse.from_toml_path() and then selecting the key predict, i.e., PredictConfig.from_config_dict(config_dict['predict']).