import datetime
import logging
import time
import torch

from contextlib import contextmanager


def inference_on_dataset(model, data_loader):
"""
|
|
Run model on the data_loader and extract the features with extractor.
|
|
The model will be used in eval mode.
|
|
|
|
Args:
|
|
model (nn.Module): a module which accepts an object from
|
|
`data_loader` and returns some outputs. It will be temporarily set to `eval` mode.
|
|
|
|
If you wish to extract a model in `training` mode instead, you can
|
|
wrap the given model and override its behavior of `.eval()` and `.train()`.
|
|
data_loader: an iterable object with a length.
|
|
The elements it generates will be the inputs to the model.
|
|
evaluator (DatasetEvaluator): the evaluator to run. Use
|
|
:class:`DatasetEvaluators([])` if you only want to benchmark, but
|
|
don't want to do any evaluation.
|
|
|
|
Returns:
|
|
The return value of `evaluator.evaluate()`
|
|
"""
|
|
    num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length

    logging_interval = 50
    num_warmup = min(5, logging_interval - 1, total - 1)
    start_time = time.time()
    total_compute_time = 0
    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            if idx == num_warmup:
                # Reset the clock after the warmup iterations so that startup
                # overhead does not skew the reported per-image timings.
                start_time = time.time()
                total_compute_time = 0

            start_compute_time = time.time()
            outputs = model(inputs)
            # Guard the synchronize call so the benchmark also runs on CPU-only setups.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.time() - start_compute_time
            if (idx + 1) % logging_interval == 0:
                duration = time.time() - start_time
                seconds_per_img = duration / (idx + 1 - num_warmup)
                eta = datetime.timedelta(
                    seconds=int(seconds_per_img * (total - num_warmup) - duration)
                )
                logger.info(
                    "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    )
                )

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )


@contextmanager
def inference_context(model):
"""
|
|
A context where the model is temporarily changed to eval mode,
|
|
and restored to previous mode afterwards.
|
|
|
|
Args:
|
|
model: a torch Module
|
|
"""
|
|
training_mode = model.training
|
|
model.eval()
|
|
yield
|
|
model.train(training_mode) |
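

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a toy `nn.Linear` model and a plain list of tensors standing in
# for a real data loader, since `inference_on_dataset` only needs a sized
# iterable whose elements can be fed directly to the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hypothetical toy model and "data loader" used purely for demonstration.
    toy_model = torch.nn.Linear(8, 4)
    toy_loader = [torch.randn(2, 8) for _ in range(64)]

    inference_on_dataset(toy_model, toy_loader)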