-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathScript
More file actions
293 lines (239 loc) · 11.6 KB
/
Script
File metadata and controls
293 lines (239 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import cv2
import numpy as np
from skimage import exposure
import matplotlib.pyplot as plt
from skimage.exposure import match_histograms
from PIL import ImageEnhance, Image
import os
class ImageStyleTransfer:
    """
    Extract image features, transfer the colour style of a reference image onto
    a target image via per-channel histogram matching, and enhance the result.

    Typical use: call ``extract_features`` on the reference image, store the
    result with ``save_reference_features``, then call ``apply_style`` on a
    target image path and optionally ``enhance_image`` on its output.
    """

    def __init__(self):
        # Feature dict from extract_features(); set via save_reference_features().
        self.reference_features = None

    @staticmethod
    def _load_rgb(image_path):
        """
        Read an image file with OpenCV and return it as an RGB array.

        Parameters:
        - image_path (str): Path to the image file.

        Returns:
        - np.array: The image converted from OpenCV's BGR order to RGB.

        Raises:
        - ValueError: If image_path is None or OpenCV cannot decode the file.
        - FileNotFoundError: If image_path does not exist.
        """
        if image_path is None:
            raise ValueError("Either image_path or image_rgb must be provided.")
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"Unable to read image: {image_path}")
        # OpenCV loads BGR; the rest of this class works in RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _calculate_ccv(channel, threshold=10):
        """
        Count pixels of one channel that differ from a 5x5 Gaussian-blurred copy
        of that channel by more than ``threshold``.

        Parameters:
        - channel (np.array): Single-channel image.
        - threshold (int): Absolute difference above which a pixel is counted
          in the first bucket.

        Returns:
        - tuple(int, int): (pixels above threshold, all remaining pixels).
        """
        blurred = cv2.GaussianBlur(channel, (5, 5), 0)
        # Widen to a signed type first: for uint8 inputs `channel - blurred`
        # wraps around (e.g. 3 - 5 == 254), which made np.abs() a no-op and
        # counted every "negative" difference as a large one.
        diff = np.abs(channel.astype(np.int16) - blurred.astype(np.int16))
        # NOTE(review): pixels that differ from their neighbourhood are usually
        # called "incoherent" in CCV terminology; the labels below keep the
        # original ordering for compatibility — confirm the intended meaning.
        coherent = int(np.count_nonzero(diff > threshold))
        incoherent = channel.size - coherent
        return coherent, incoherent

    def extract_features(self, image_path=None, image_rgb=None):
        """
        Extract colour and tone features from an image.

        Parameters:
        - image_path (str): Path to the image file; only used when image_rgb is None.
        - image_rgb (np.array): Optional. RGB image array; when given, image_path
          is ignored.

        Returns:
        - dict: Keys 'image', 'histograms' (per-channel 256-bin counts),
          'brightness', 'contrast', 'hue', 'saturation', 'vibrance',
          'color_moments' (per-channel Hu moments) and
          'color_coherence_vector' (per-channel count pairs).

        Raises:
        - ValueError: If neither argument is given or the file cannot be decoded.
        - FileNotFoundError: If image_path does not exist.
        """
        if image_rgb is None:
            image_rgb = self._load_rgb(image_path)
        image_hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)

        channels = dict(zip(('red', 'green', 'blue'), cv2.split(image_rgb)))

        histograms = {
            name: np.histogram(channel, bins=256, range=(0, 256))[0]
            for name, channel in channels.items()
        }

        brightness = np.mean(image_rgb)
        contrast = np.std(image_rgb)
        # For uint8 input OpenCV stores hue in [0, 179]. Hue is circular, so this
        # arithmetic mean is only a rough indicator (reds near 0 and 179 average
        # to a mid-range hue).
        hue = np.mean(image_hsv[:, :, 0])
        saturation = np.mean(image_hsv[:, :, 1])
        # Mean absolute deviation of each pixel from its channel's mean.
        vibrance = np.mean(np.abs(image_rgb - np.mean(image_rgb, axis=(0, 1))))

        # The seven Hu moment invariants of each channel treated as an intensity
        # image (not mean/variance/skewness), kept under the 'color_moments' key.
        color_moments = {
            name: cv2.HuMoments(cv2.moments(channel)).flatten()
            for name, channel in channels.items()
        }

        ccv = {name: self._calculate_ccv(channel) for name, channel in channels.items()}

        return {
            'image': image_rgb,
            'histograms': histograms,
            'brightness': brightness,
            'contrast': contrast,
            'hue': hue,
            'saturation': saturation,
            'vibrance': vibrance,
            'color_moments': color_moments,
            'color_coherence_vector': ccv,
        }

    def save_reference_features(self, features):
        """
        Store the extracted features of the reference image for apply_style().

        Parameters:
        - features (dict): Output of extract_features() for the reference image.
        """
        self.reference_features = features

    def apply_style(self, target_image_path, normalize=False):
        """
        Apply the reference image's per-channel histograms to a target image.

        Parameters:
        - target_image_path (str): Path to the target image file.
        - normalize (bool): If True, min-max stretch the matched image to the
          full 0-255 range (the previous unconditional behaviour). Off by
          default because the stretch overwrites the tonal range that was just
          copied from the reference (e.g. lifted blacks or soft highlights).

        Returns:
        - np.array: uint8 RGB image with the reference style applied.

        Raises:
        - ValueError: If reference features are not set or the image cannot be decoded.
        - FileNotFoundError: If target_image_path does not exist.
        """
        if self.reference_features is None:
            raise ValueError("Reference features not set.")
        target_image_rgb = self._load_rgb(target_image_path)
        reference_rgb = self.reference_features['image']

        # Match each RGB channel independently against the same reference channel.
        matched_channels = [
            match_histograms(target_image_rgb[:, :, c], reference_rgb[:, :, c])
            for c in range(3)
        ]
        matched_image = cv2.merge(matched_channels)

        if normalize:
            matched_image = cv2.normalize(matched_image, None, alpha=0, beta=255,
                                          norm_type=cv2.NORM_MINMAX)
        # Round rather than truncate when converting back to 8-bit.
        return np.clip(np.rint(matched_image), 0, 255).astype(np.uint8)

    def enhance_image(self, image, brightness=1.0, contrast=1.0, color=1.0):
        """
        Adjust brightness, then contrast, then colour saturation with PIL.

        Parameters:
        - image (np.array): Image to enhance; presumably an RGB uint8 array such
          as the output of apply_style (required by Image.fromarray for RGB).
        - brightness (float): Brightness factor (1.0 leaves it unchanged).
        - contrast (float): Contrast factor (1.0 leaves it unchanged).
        - color (float): Colour factor (1.0 leaves it unchanged).

        Returns:
        - np.array: Enhanced image.
        """
        pil_image = Image.fromarray(image)
        image_enhanced = ImageEnhance.Brightness(pil_image).enhance(brightness)
        image_enhanced = ImageEnhance.Contrast(image_enhanced).enhance(contrast)
        image_enhanced = ImageEnhance.Color(image_enhanced).enhance(color)
        return np.array(image_enhanced)

    @staticmethod
    def _plot_row(row_axes, image, features, label):
        """Fill one 4-panel row: image, scalar features, RGB histograms, mean Hu moments."""
        image_ax, features_ax, hist_ax, moments_ax = row_axes

        image_ax.imshow(image)
        image_ax.set_title(f'{label} Image')
        image_ax.axis('off')

        feature_names = ['Brightness', 'Contrast', 'Hue', 'Saturation', 'Vibrance']
        features_ax.bar(feature_names, [features[name.lower()] for name in feature_names])
        features_ax.set_title(f'{label} Image Features')

        for channel, colour in (('red', 'r'), ('green', 'g'), ('blue', 'b')):
            hist_ax.plot(features['histograms'][channel], color=colour,
                         label=channel.capitalize())
        hist_ax.set_title(f'{label} Image Histograms')
        hist_ax.legend()

        moments = features['color_moments']
        moments_ax.bar(['Red', 'Green', 'Blue'],
                       [np.mean(moments[channel]) for channel in ('red', 'green', 'blue')])
        moments_ax.set_title(f'{label} Image Color Moments')

    def plot_results(self, reference_image, reference_features, original_image, original_features, processed_image,
                     processed_features, title):
        """
        Plot reference, original and processed images with their features.

        Each row shows the image, a bar chart of scalar features, the RGB
        histograms, and the mean Hu moment per channel (previously only the
        processed row had this panel, leaving two empty axes).

        Parameters:
        - reference_image (np.array): Reference image.
        - reference_features (dict): Extracted features of the reference image.
        - original_image (np.array): Original target image.
        - original_features (dict): Extracted features of the original target image.
        - processed_image (np.array): Processed target image.
        - processed_features (dict): Extracted features of the processed target image.
        - title (str): Title for the plot.
        """
        fig, axs = plt.subplots(3, 4, figsize=(25, 20))
        fig.suptitle(title)
        rows = (
            ('Reference', reference_image, reference_features),
            ('Original', original_image, original_features),
            ('Processed', processed_image, processed_features),
        )
        for row_axes, (label, image, features) in zip(axs, rows):
            self._plot_row(row_axes, image, features, label)
        plt.tight_layout()
        plt.show()
def main(reference_image_path='/Users/godsonjohnson/Desktop/_DSC0819.jpg',
         target_image_path='/Users/godsonjohnson/Desktop/_DSC0720.JPG'):
    """
    Run the style-transfer demo and plot the results.

    The steps are:
    1. Extract and save the reference image's features.
    2. Extract the target image's features.
    3. Apply the reference style to the target and enhance the result.
    4. Plot reference, original and processed images side by side.

    Parameters:
    - reference_image_path (str): Image whose colour style is copied. Defaults
      to the original hard-coded path.
    - target_image_path (str): Image the style is applied to. Defaults to the
      original hard-coded path.

    Any FileNotFoundError/ValueError from loading or processing is printed to
    stderr and the function returns early.
    """
    import sys

    image_style_transfer = ImageStyleTransfer()
    try:
        reference_features = image_style_transfer.extract_features(reference_image_path)
        image_style_transfer.save_reference_features(reference_features)
        print("Reference image features extracted and style saved.")

        original_features = image_style_transfer.extract_features(target_image_path)

        styled_image = image_style_transfer.apply_style(target_image_path)
        enhanced_image = image_style_transfer.enhance_image(styled_image, brightness=1.1, contrast=1.1, color=1.0)
        # extract_features(image_rgb=...) already stores enhanced_image under 'image'.
        processed_features = image_style_transfer.extract_features(image_rgb=enhanced_image)

        image_style_transfer.plot_results(reference_features['image'], reference_features,
                                          original_features['image'], original_features,
                                          enhanced_image, processed_features, 'Image Style Transfer Results')
    except (FileNotFoundError, ValueError) as e:
        print(e, file=sys.stderr)
        return
if __name__ == '__main__':
    # Script entry point: run the demo only when this file is executed
    # directly, never as a side effect of importing it.
    main()