import datetime
import logging
import time
from contextlib import contextmanager

import torch


def inference_on_dataset(model, data_loader):
    """
    Run the model on every input produced by `data_loader` and log inference
    timing statistics. The model will be used in eval mode.

    Args:
        model (nn.Module): a module which accepts an object from
            `data_loader` and returns some outputs. It will be temporarily set to `eval` mode.

            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
    """
    num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length
    logging_interval = 50
    num_warmup = min(5, logging_interval - 1, total - 1)
    start_time = time.time()
    total_compute_time = 0
    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            # Reset timers once warmup iterations are done, so they are
            # excluded from the reported statistics.
            if idx == num_warmup:
                start_time = time.time()
                total_compute_time = 0

            start_compute_time = time.time()
            outputs = model(inputs)
            if torch.cuda.is_available():
                # wait for queued GPU work to finish so the timing is accurate
                torch.cuda.synchronize()
            total_compute_time += time.time() - start_compute_time
            if (idx + 1) % logging_interval == 0:
                duration = time.time() - start_time
                seconds_per_img = duration / (idx + 1 - num_warmup)
                eta = datetime.timedelta(
                    seconds=int(seconds_per_img * (total - num_warmup) - duration)
                )
                logger.info(
                    "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    )
                )

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )


@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    Args:
        model: a torch Module
    """
    training_mode = model.training
    model.eval()
    yield
    model.train(training_mode)
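

# The block below is a minimal, hypothetical usage sketch, not part of the
# original module. It assumes a toy nn.Module and uses a plain Python list as
# the "data loader", which satisfies the only requirements stated above: an
# iterable with a fixed length whose elements are valid inputs to the model.
if __name__ == "__main__":
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    # Dummy model and inputs, chosen only so the example runs end to end.
    dummy_model = nn.Sequential(nn.Flatten(), nn.Linear(16, 4))
    dummy_loader = [torch.randn(1, 4, 4) for _ in range(200)]  # 200 fake "images"

    # Logs per-image latency and an ETA every `logging_interval` iterations,
    # then the total and pure-compute time at the end.
    inference_on_dataset(dummy_model, dummy_loader)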