-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathshow.py
More file actions
25 lines (21 loc) · 1.06 KB
/
show.py
File metadata and controls
25 lines (21 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import matplotlib.pyplot as plt
import numpy as np
import torchvision
def imshow(img, filename):
# img是一个PyTorch Tensor,我们需要将其转换为numpy数组
# 并且,我们需要将通道从PyTorch的[channels, height, width]转换为matplotlib期望的[height, width, channels]
img = img.numpy().transpose((1, 2, 0))
# 对图像进行反标准化处理 (如果你在预处理时进行了标准化)
# mean = np.array([0.485, 0.456, 0.406])
# std = np.array([0.229, 0.224, 0.225])
# img = std * img + mean
img = np.clip(img, 0, 1) # 修正可能出现的任何超出[0,1]范围的值
plt.imshow(img)
plt.axis('off') # 不显示坐标轴
plt.savefig(filename, bbox_inches='tight', pad_inches=0.0) # 保存图像文件
plt.close() # 关闭图形,防止再次显示
def visualize_image(x, filename):
x = x.cpu()
# 使用torchvision的make_grid函数来创建一个网格布局的图像
grid_img = torchvision.utils.make_grid(x, nrow=4) # nrow是每行显示的图像数量
imshow(grid_img, filename)