Quickly Master Hook Functions in Python
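The example pipeline has two steps:
1. input_filter_fn: filter the content before it is inserted into the queue
2. insert_queue: insert the content into the queue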
    class ContentStash(object):
        """
        content stash for online operation
        pipeline is
        1. input_filter: filter out contents that are of no use to the user
        2. insert_queue(redis or other broker): insert useful content to queue
        """

        def __init__(self):
            self.input_filter_fn = None
            self.broker = []

        def register_input_filter_hook(self, input_filter_fn):
            """
            register input filter function, parameter is content dict
            Args:
                input_filter_fn: input filter function

            Returns:

            """
            self.input_filter_fn = input_filter_fn

        def insert_queue(self, content):
            """
            insert content to queue
            Args:
                content: dict

            Returns:

            """
            self.broker.append(content)

        def input_pipeline(self, content, use=False):
            """
            pipeline of input for content stash
            Args:
                use: is use, default False
                content: dict

            Returns:

            """
            if not use:
                return

            # input filter: a truthy result from the hook means "filter this content out";
            # with no hook registered, nothing is filtered
            _filter = None
            if self.input_filter_fn:
                _filter = self.input_filter_fn(content)

            # insert to queue
            if not _filter:
                self.insert_queue(content)


    # test

    ## Implement whatever hook you need: for example, filter out the content
    ## if it contains 'time', otherwise insert it into the queue
    def input_filter_hook(content):
        """
        test input filter hook
        Args:
            content: dict

        Returns: None or content

        """
        if content.get('time') is None:
            return
        else:
            return content


    # the original program
    content = {'filename': 'test.jpg', 'b64_file': "#test", 'data': {"result": "cat", "probability": 0.9}}
    # __init__ takes no arguments, so none are passed here
    content_stash = ContentStash()

    # Attach the hook function. Different hook implementations are possible, but the
    # function's input and output must stay consistent with what the original program
    # expects (here, the content dict)
    content_stash.register_input_filter_hook(input_filter_hook)

    # Run the pipeline; use=True is required, otherwise input_pipeline returns immediately
    content_stash.input_pipeline(content, use=True)

3. Application of hooks in open source frameworks

3.1 keras

Hook functions show up most clearly in deep learning training. A training run (excluding data preparation) iterates over the training set multiple times. Each pass is called an epoch, and each epoch is split into multiple batches. The process breaks down into steps such as the point after an epoch has been trained and the end of training. Custom logic can be attached at these steps, for example using the best model to measure performance on the test set.
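The decomposition described above can be sketched without any deep learning framework. This is only an illustration under stated assumptions: `MiniCallback`, `BestLossTracker`, and `fit` are hypothetical names invented for this sketch, and the "loss" is a made-up number rather than the result of real training.

```python
class MiniCallback:
    """Hypothetical base class: every hook is a no-op until a subclass overrides it."""

    def on_train_begin(self):
        pass

    def on_epoch_begin(self, epoch):
        pass

    def on_batch_begin(self, batch):
        pass

    def on_batch_end(self, batch, logs):
        pass

    def on_epoch_end(self, epoch, logs):
        pass

    def on_train_end(self):
        pass


class BestLossTracker(MiniCallback):
    """Overrides two hooks: remember the best epoch, then flag the end of training."""

    def __init__(self):
        self.best_epoch = None
        self.best_loss = float("inf")
        self.finished = False

    def on_epoch_end(self, epoch, logs):
        if logs["loss"] < self.best_loss:
            self.best_epoch, self.best_loss = epoch, logs["loss"]

    def on_train_end(self):
        self.finished = True


def fit(batches, epochs, callbacks):
    """Toy loop; the 'loss' is a stand-in computed as batch / (epoch + 1)."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        losses = []
        for index, batch in enumerate(batches):
            for cb in callbacks:
                cb.on_batch_begin(index)
            loss = batch / (epoch + 1)  # placeholder for a real training step
            losses.append(loss)
            for cb in callbacks:
                cb.on_batch_end(index, {"loss": loss})
        epoch_logs = {"loss": sum(losses) / len(losses)}
        for cb in callbacks:
            cb.on_epoch_end(epoch, epoch_logs)
    for cb in callbacks:
        cb.on_train_end()


tracker = BestLossTracker()
fit(batches=[6.0, 3.0], epochs=3, callbacks=[tracker])
print(tracker.best_epoch, tracker.best_loss)  # 2 1.5
```

The loop owns the order of the steps, while each callback decides what, if anything, happens at a step. Adding a new behavior means writing another subclass, not editing `fit`.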
In Keras, these hooks are defined as methods of the Callback base class:

    @keras_export('keras.callbacks.Callback')
    class Callback(object):
      """Abstract base class used to build new callbacks.

      Attributes:
          params: Dict. Training parameters
              (eg. verbosity, batch size, number of epochs...).
          model: Instance of `keras.models.Model`.
              Reference of the model being trained.

      The `logs` dictionary that callback methods
      take as argument will contain keys for quantities relevant to
      the current batch or epoch (see method-specific docstrings).
      """

      def __init__(self):
        self.validation_data = None  # pylint: disable=g-missing-from-attributes
        self.model = None
        # Whether this Callback should only run on the chief worker in a
        # Multi-Worker setting.
        # TODO(omalleyt): Make this attr public once solution is stable.
        self._chief_worker_only = None
        self._supports_tf_logs = False

      def set_params(self, params):
        self.params = params

      def set_model(self, model):
        self.model = model

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_batch_begin(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_begin`."""

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_batch_end(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_end`."""

      @doc_controls.for_subclass_implementers
      def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.

        Subclasses should override for any actions to run. This function should only
        be called during TRAIN mode.

        Arguments:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this method
              but that may change in the future.
        """

      @doc_controls.for_subclass_implementers
      def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        Subclasses should override for any actions to run. This function should only
        be called during TRAIN mode.

        Arguments:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
              validation epoch if validation is performed. Validation result keys
              are prefixed with `val_`.
        """

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_train_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step`. Typically,
              the values of the `Model`'s metrics are returned.  Example:
              `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        # For backwards compatibility.
        self.on_batch_begin(batch, logs=logs)

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_train_batch_end(self, batch, logs=None):
        """Called at the end of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        # For backwards compatibility.
        self.on_batch_end(batch, logs=logs)

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_test_batch_begin(self, batch, logs=None):
        """Called at the beginning of a batch in `evaluate` methods.

        Also called at the beginning of a validation batch in the `fit`
        methods, if validation data is provided.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.test_step`. Typically,
              the values of the `Model`'s metrics are returned.  Example:
              `{'loss': 0.2, 'accuracy': 0.7}`.
        """

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_test_batch_end(self, batch, logs=None):
        """Called at the end of a batch in `evaluate` methods.

        Also called at the end of a validation batch in the `fit`
        methods, if validation data is provided.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_predict_batch_begin(self, batch, logs=None):
        """Called at the beginning of a batch in `predict` methods.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.predict_step`,
              it typically returns a dict with a key 'outputs' containing
              the model's outputs.
        """

      @doc_controls.for_subclass_implementers
      @generic_utils.default
      def on_predict_batch_end(self, batch, logs=None):
        """Called at the end of a batch in `predict` methods.

        Subclasses should override for any actions to run.

        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """

      @doc_controls.for_subclass_implementers
      def on_train_begin(self, logs=None):
        """Called at the beginning of training.

        Subclasses should override for any actions to run.

        Arguments:
            logs: Dict. Currently no data is passed to this argument for this method
              but that may change in the future.
        """

      @doc_controls.for_subclass_implementers
      def on_train_end(self, logs=None):
        """Called at the end of training.

        Subclasses should override for any actions to run.

        Arguments:
            logs: Dict. Currently the output of the last call to `on_epoch_end()`
              is passed to this argument for this method but that may change in
              the future.
        """

      @doc_controls.for_subclass_implementers
      def on_test_begin(self, logs=None):
        """Called at the beginning of evaluation or validation.

        Subclasses should override for any actions to run.

        Arguments:
            logs: Dict. Currently no data is passed to this argument for this method
              but that may change in the future.
        """

      @doc_controls.for_subclass_implementers
      def on_test_end(self, logs=None):
        """Called at the end of evaluation or validation.

        Subclasses should override for any actions to run.

        Arguments:
            logs: Dict. Currently the output of the last call to
              `on_test_batch_end()` is passed to this argument for this method
              but that may change in the future.
        """

      @doc_controls.for_subclass_implementers
      def on_predict_begin(self, logs=None):
        """Called at the beginning of prediction.

        Subclasses should override for any actions to run.

        Arguments:
            logs: Dict. Currently no data is passed to this argument for this method
              but that may change in the future.
        """

      @doc_controls.for_subclass_implementers
      def on_predict_end(self, logs=None):
        """Called at the end of prediction.

        Subclasses should override for any actions to run.

        Arguments:
            logs: Dict. Currently no data is passed to this argument for this method
              but that may change in the future.
        """

      def _implements_train_batch_hooks(self):
        """Determines if this Callback should be called for each train batch."""
        return (not generic_utils.is_default(self.on_batch_begin) or
                not generic_utils.is_default(self.on_batch_end) or
                not generic_utils.is_default(self.on_train_batch_begin) or
                not generic_utils.is_default(self.on_train_batch_end))

The program that calls these hooks is the model training flow itself.
Keras source code location: tensorflow/python/keras/engine/training.py. Part of the code is excerpted below (hook call sites are marked with ## I am hook):
    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    ## I am hook
    callbacks.on_train_begin()
    training_logs = None
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (  # pylint: disable=protected-access
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    for epoch, iterator in data_handler.enumerate_epochs():
      self.reset_metrics()
      callbacks.on_epoch_begin(epoch)
      with data_handler.catch_stop_iteration():
        for step in data_handler.steps():
          with trace.Trace(
              'TraceContext',
              graph_type='train',
              epoch_num=epoch,
              step_num=step,
              batch_size=batch_size):
            ## I am hook
            callbacks.on_train_batch_begin(step)
            tmp_logs = train_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            logs = tmp_logs  # No error, now safe to assign to logs.
            end_step = step + data_handler.step_increment
            callbacks.on_train_batch_end(end_step, logs)
      epoch_logs = copy.copy(logs)

      # Run validation.

      ## I am hook
      callbacks.on_epoch_end(epoch, epoch_logs)

3.2 mmdetection

mmdetection is an open source object detection framework (PyTorch version) that integrates many deep learning detection algorithms, such as Faster R-CNN, FPN, and RetinaNet. It also uses hooks extensively to expose specific parts of the execution flow to the application. See https://github.com/open-mmlab/mmdetection for more details.
Here is an example of a training call (excerpt) (https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py):
    def train_detector(model,
                       dataset,
                       cfg,
                       distributed=False,
                       validate=False,
                       timestamp=None,
                       meta=None):
        logger = get_root_logger(cfg.log_level)

        # prepare data loaders

        # put model on gpus

        # build runner
        optimizer = build_optimizer(model, cfg.optimizer)
        runner = EpochBasedRunner(
            model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta)
        # an ugly workaround to make .log and .log.json filenames the same
        runner.timestamp = timestamp

        # fp16 setting

        # register hooks
        runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                       cfg.checkpoint_config, cfg.log_config,
                                       cfg.get('momentum_config', None))
        if distributed:
            runner.register_hook(DistSamplerSeedHook())

        # register eval hooks
        if validate:
            # Support batch_size > 1 in validation
            eval_cfg = cfg.get('evaluation', {})
            eval_hook = DistEvalHook if distributed else EvalHook
            runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

        # user-defined hooks
        if cfg.get('custom_hooks', None):
            custom_hooks = cfg.custom_hooks
            assert isinstance(custom_hooks, list), \
                f'custom_hooks expect list type, but got {type(custom_hooks)}'
            for hook_cfg in cfg.custom_hooks:
                assert isinstance(hook_cfg, dict), \
                    'Each item in custom_hooks expects dict type, but got ' \
                    f'{type(hook_cfg)}'
                hook_cfg = hook_cfg.copy()
                priority = hook_cfg.pop('priority', 'NORMAL')
                hook = build_from_cfg(hook_cfg, HOOKS)
                runner.register_hook(hook, priority=priority)
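The user-defined hook loop in the excerpt above (pop a priority from each config dict, build the hook, register it) can be imitated in a few lines of plain Python. The sketch below is not the mmdetection or mmcv implementation: `MiniRunner`, `LogHook`, `build_hook`, the contents of `HOOKS`, and the priority numbers are all assumptions made up for illustration.

```python
class MiniRunner:
    """Hypothetical runner that keeps hooks ordered by priority."""

    # Assumption of this sketch: a smaller number runs earlier.
    PRIORITIES = {"HIGH": 30, "NORMAL": 50, "LOW": 70}

    def __init__(self):
        self.epoch = 0
        self._hooks = []  # entries are (priority value, registration index, hook)

    def register_hook(self, hook, priority="NORMAL"):
        value = self.PRIORITIES[priority]
        self._hooks.append((value, len(self._hooks), hook))
        # Order by priority first, then by registration order.
        self._hooks.sort(key=lambda item: item[:2])

    def call_hook(self, stage):
        # A hook only needs to define the stages it cares about.
        for _, _, hook in self._hooks:
            method = getattr(hook, stage, None)
            if method is not None:
                method(self)

    def run(self, epochs):
        self.call_hook("before_run")
        for epoch in range(epochs):
            self.epoch = epoch
            self.call_hook("after_train_epoch")
        self.call_hook("after_run")


class LogHook:
    """Hypothetical hook that records when it was called."""

    def __init__(self, tag, log):
        self.tag = tag
        self.log = log

    def after_train_epoch(self, runner):
        self.log.append(f"{self.tag}:epoch{runner.epoch}")


# Hypothetical registry mapping a config 'type' to a hook class.
HOOKS = {"LogHook": LogHook}


def build_hook(cfg):
    """Build a hook from a config dict whose 'type' key names the class."""
    cfg = cfg.copy()
    hook_cls = HOOKS[cfg.pop("type")]
    return hook_cls(**cfg)


log = []
custom_hooks = [
    {"type": "LogHook", "tag": "late", "log": log, "priority": "LOW"},
    {"type": "LogHook", "tag": "early", "log": log, "priority": "HIGH"},
]

runner = MiniRunner()
for hook_cfg in custom_hooks:
    hook_cfg = hook_cfg.copy()
    priority = hook_cfg.pop("priority", "NORMAL")
    runner.register_hook(build_hook(hook_cfg), priority=priority)

runner.run(epochs=2)
print(log)  # ['early:epoch0', 'late:epoch0', 'early:epoch1', 'late:epoch1']
```

Because hooks are described by plain dicts, a caller can add behavior to the training flow from a config file alone, and the priority decides the order when several hooks listen to the same stage.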
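This article introduced the concept and applications of hooks and gave implementation details in Python. Hopefully it helps. In summary:
1. A hook function is a predefined step in a flow that has no implementation of its own.
2. Once a hook function is attached or registered, the flow runs it when execution reaches that step.
3. Callback functions and hook functions are functionally the same.
4. The hook design brings flexibility: if a flow has a step that you want the caller to implement, you can use a hook function.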