fujie_code/ad_train.py

47 lines
1.1 KiB
Python
Raw Normal View History

2024-07-04 17:03:29 +08:00
from torch import optim
class BaseConfig:
    """Baseline settings that every patch-training config starts from.

    Concrete configs subclass this, call ``super().__init__()`` and then
    overwrite only the attributes they need to change.
    """

    def __init__(self):
        """Populate the instance with the default settings."""
        # Input data, network definition and printability table (relative paths).
        self.img_dir = "inria/Train/pos"
        self.lab_dir = "inria/Train/pos/yolo-labels"
        self.cfgfile = "cfg/yolo.cfg"
        self.weightfile = "weights/yolo.weights"
        self.printfile = "non_printability/30values.txt"

        # Patch geometry, initial learning rate and output name.
        self.patch_size = 300
        self.start_learning_rate = 0.03
        self.patch_name = 'base'

        # Given an optimizer, build a ReduceLROnPlateau scheduler in 'min'
        # mode with patience=50.
        self.scheduler_factory = (
            lambda x: optim.lr_scheduler.ReduceLROnPlateau(x, 'min', patience=50)
        )
        # NOTE(review): presumably an upper bound on the total-variation
        # term; 0 here. Confirm against the training loop.
        self.max_tv = 0
        self.batch_size = 20

        # Combine the object score and class score into the value the
        # training code targets; the default is their product.
        self.loss_target = lambda obj, cls: obj * cls
class ReproducePaperObj(BaseConfig):
    """Config that reproduces the paper's object-score-only patch.

    Starts from the ``BaseConfig`` defaults and overrides the batch size,
    patch name and ``max_tv``. It targets the object score alone instead of
    combining it with the class score, i.e. the patch is trained to minimise
    the object score.
    """

    def __init__(self):
        super().__init__()
        self.patch_name = 'ObjectOnlyPaper'
        self.batch_size = 8
        # Explicitly pinned; presumably equal to the inherited default.
        self.patch_size = 300
        # NOTE(review): presumably the total-variation cap; confirm in training code.
        self.max_tv = 0.165
        # The class score is accepted but ignored; only objectness is targeted.
        self.loss_target = lambda obj, cls: obj