# fujie_code/patch_config.py
# (Python source; original listing: 136 lines, 2.8 KiB)

from torch import optim
class BaseConfig:
    """
    Default parameters shared by every patch configuration.

    Subclasses call ``super().__init__()`` first and then overwrite only the
    attributes they want to change.
    """

    def __init__(self):
        """
        Populate the instance with the default settings.
        """
        # Image and label directories. The earlier default was the INRIA set:
        #   images: inria/Train/pos, labels: inria/Train/pos/yolo-labels
        self.img_dir = "cctsdb/Train/pos"
        self.lab_dir = "cctsdb/Train/labels"

        # Paths to the YOLO network definition and its weights file.
        self.cfgfile = "cfg/yolo.cfg"
        self.weightfile = "weights/yolo.weights"

        # Path to the non-printability values file (contents/format are defined
        # by whatever code reads it — not visible here).
        self.printfile = "non_printability/30values.txt"

        self.patch_size = 300
        self.start_learning_rate = 0.03
        self.patch_name = 'base'

        def make_scheduler(optimizer):
            # Wrap the optimizer in ReduceLROnPlateau, mode 'min', patience=50.
            return optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50)

        self.scheduler_factory = make_scheduler
        self.max_tv = 0
        self.batch_size = 20

        def product_of_scores(obj, cls):
            # Default objective: the product of the two scores.
            return obj * cls

        self.loss_target = product_of_scores
class Experiment1(BaseConfig):
    """
    Base configuration with a total-variation bound applied.

    ``max_tv`` is raised from the base value of 0 to 0.165. According to the
    original note, total variation is not allowed to go below this value.
    """

    def __init__(self):
        """
        Start from the base defaults, then apply the Experiment1 overrides.
        """
        super().__init__()
        # Bound on total variation (see class docstring).
        self.max_tv = 0.165
        self.patch_name = 'Experiment1'
class Experiment2HighRes(Experiment1):
    """
    Experiment1 with a larger patch (400 instead of the base 300).
    """

    def __init__(self):
        """
        Apply the high-resolution overrides on top of Experiment1.
        """
        super().__init__()
        self.patch_name = 'Exp2HighRes'
        self.patch_size = 400
        # Same value Experiment1 already sets; restated explicitly here.
        self.max_tv = 0.165
class Experiment3LowRes(Experiment1):
    """
    Experiment1 with a smaller patch (100 instead of the base 300).
    """

    def __init__(self):
        """
        Apply the low-resolution overrides on top of Experiment1.
        """
        super().__init__()
        self.patch_name = "Exp3LowRes"
        self.patch_size = 100
        # Same value Experiment1 already sets; restated explicitly here.
        self.max_tv = 0.165
class Experiment4ClassOnly(Experiment1):
    """
    Experiment1 variant whose objective uses only the class score.
    """

    def __init__(self):
        """
        Apply the class-only overrides on top of Experiment1.
        """
        super().__init__()

        def class_score_only(obj, cls):
            # The object score is accepted for signature compatibility but ignored.
            return cls

        self.patch_name = 'Experiment4ClassOnly'
        self.loss_target = class_score_only
class Experiment1Desktop(Experiment1):
    """
    Experiment1 variant with a smaller batch (8) and a larger patch (400).

    NOTE(review): ``patch_name`` is not overridden, so it stays 'Experiment1'.
    Confirm whether this variant is meant to share that name.
    """

    def __init__(self):
        """
        Apply the batch-size and patch-size overrides on top of Experiment1.
        """
        super().__init__()
        batch, size = 8, 400
        self.batch_size = batch
        self.patch_size = size
class ReproducePaperObj(BaseConfig):
    """
    Reproduce the paper's results: build a patch that minimises the object score.

    Uses the same total-variation bound as Experiment1 (0.165). The objective
    returns the object score alone and ignores the class score.
    """

    def __init__(self):
        """
        Apply the paper-reproduction overrides on top of the base defaults.
        """
        super().__init__()

        def object_score_only(obj, cls):
            # The class score is accepted for signature compatibility but ignored.
            return obj

        self.batch_size = 8
        # Same as the base default; stated explicitly for this reproduction.
        self.patch_size = 300
        self.patch_name = 'ObjectOnlyPaper'
        self.max_tv = 0.165
        self.loss_target = object_score_only
# Lookup table from a short string key to the config class it selects.
# The values are classes, not instances; instantiate after lookup.
patch_configs = dict(
    base=BaseConfig,
    exp1=Experiment1,
    exp1_des=Experiment1Desktop,
    exp2_high_res=Experiment2HighRes,
    exp3_low_res=Experiment3LowRes,
    exp4_class_only=Experiment4ClassOnly,
    paper_obj=ReproducePaperObj,
)