-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdeep_dream.py
More file actions
96 lines (78 loc) · 2.79 KB
/
deep_dream.py
File metadata and controls
96 lines (78 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Deep Dream Algorithm.
"""
from time import time
import logging
import torch
from algorithm.googlenet import DeepGoogLeNet
import algorithm.im_utils as utils
from utils.exception import (
UnknownStyle,
PreProcessingError,
PostProcessingError
)
# Module-wide logger; it is fetched by the fixed name "DEEP_API" rather than
# __name__, so handlers/levels configured for that name apply here.
logger = logging.getLogger(name="DEEP_API")
class DeepDream:
    """
    Deep Dream Algorithm.

    Amplifies the activations of the configured layers of interest by
    running gradient ascent directly on the input image's pixels.

    Args:
        cfg: configuration object exposing ``STYLES_CFG``, a mapping from
            style name to a style object with ``loi``, ``size``, ``epochs``
            and ``learning_rate`` attributes.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        # Prefer the GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device('cuda'
                                   if torch.cuda.is_available()
                                   else 'cpu')
        logger.info('Deep Dream Class Initialized succesfully')

    def deep_dream_loss(self, model, target):
        """
        Deep Dream Loss.

        Runs ``target`` through ``model`` and returns the sum, over every
        tensor in ``model.features``, of that tensor's mean value.
        ``model.features`` is expected to be filled by the hooks set up on
        the model's layers of interest during the forward pass.

        Args:
            model: the hooked network; called as ``model(target)``.
            target: the image tensor being optimised.

        Returns:
            A scalar tensor: sum of the per-feature means.
        """
        # Only the side effect (refreshing model.features) is needed here.
        model(target)
        losses = [torch.mean(feat) for feat in model.features]
        return torch.stack(losses, dim=0).sum()

    def __call__(self, image, style_name):
        """
        Deep Dream Main Algorithm.

        Passes the given image through the model and updates the image with
        normalised gradient ascent on ``deep_dream_loss``.

        Args:
            image: input image, presumably a PIL image (``.size`` and
                ``.resize`` are used). It is resized to
                ``(style.size, style.size)`` for processing and the result is
                resized back to the original size.
            style_name: key into ``cfg.STYLES_CFG``.

        Returns:
            The dreamed image produced by ``utils.postprocess``, resized to
            the input's original size.

        Raises:
            UnknownStyle: ``style_name`` is not a configured style.
            PreProcessingError: resizing/preprocessing the input failed.
            PostProcessingError: converting the result back failed.
        """
        style = self._get_style(style_name)
        logger.info(f'Running deep dream algorithm using style: {style_name}')

        model = self._build_model(style)
        original_size, target = self._prepare_target(image, style)

        start = time()
        for epoch in range(style.epochs):
            epoch_start = time()
            self._gradient_ascent_step(model, target, style.learning_rate)
            logger.debug(f'Epoch {epoch + 1}/{style.epochs} '
                         f'took: {time() - epoch_start:.2f}')
        logger.info('Deep Dream with style: '
                    f'{style_name} took: {time() - start:.2f}')

        return self._to_image(target, original_size)

    def _get_style(self, style_name):
        """Return the configured style, mapping a missing key to UnknownStyle."""
        try:
            return self.cfg.STYLES_CFG[style_name]
        except KeyError as err:
            raise UnknownStyle() from err

    def _build_model(self, style):
        """Build the hooked GoogLeNet for ``style`` on ``self.device``."""
        model = DeepGoogLeNet(loi=style.loi)
        # The input tensor lives on self.device; the weights must too, or the
        # forward pass fails with a device mismatch when CUDA is available.
        model.to(self.device)
        model.eval()
        # Only the image is optimised: freezing the weights avoids computing
        # and accumulating a gradient for every network parameter each step.
        model.requires_grad_(False)
        return model

    def _prepare_target(self, image, style):
        """
        Resize and preprocess ``image`` into the tensor to optimise.

        Returns:
            ``(original_size, target)`` where ``target`` is a leaf tensor on
            ``self.device`` with ``requires_grad`` enabled.

        Raises:
            PreProcessingError: any failure while resizing/preprocessing.
        """
        try:
            original_size = image.size
            resized = image.resize((style.size, style.size))
            target = utils.preprocess(resized).to(self.device)
        except Exception as err:
            raise PreProcessingError() from err
        # Moving a grad-requiring tensor across devices yields a non-leaf copy
        # whose .grad is never populated; detaching makes it a fresh leaf.
        return original_size, target.detach().requires_grad_(True)

    def _gradient_ascent_step(self, model, target, learning_rate):
        """Apply one normalised gradient-ascent update to ``target`` in place."""
        if target.grad is not None:
            target.grad.zero_()
        loss = self.deep_dream_loss(model, target)
        # NOTE(review): retain_graph is kept from the original; it is only
        # required if model.features can hold tensors from earlier passes.
        loss.backward(retain_graph=True)
        # Standardise the gradient so the step size does not depend on its
        # scale; the epsilon guards against division by zero.
        grad = target.grad.data / (torch.std(target.grad.data) + 1e-8)
        target.data = target.data + grad * learning_rate
        # Keep pixel values inside the range utils.clip enforces.
        target.data = utils.clip(target.data)

    def _to_image(self, target, original_size):
        """
        Convert the optimised tensor back to an image of ``original_size``.

        Raises:
            PostProcessingError: any failure while post-processing.
        """
        try:
            dream = target.detach().cpu().clone().squeeze(0)
            return utils.postprocess(dream).resize(original_size)
        except Exception as err:
            raise PostProcessingError() from err