-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
102 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import onnx | ||
import numpy as np | ||
from onnx import numpy_helper | ||
from printk import print_colored_box, print_colored_box_line | ||
|
||
|
||
def expand_dim(model_path): | ||
# 加载原始模型 | ||
onnx_model = onnx.load(model_path) | ||
|
||
# 获取模型的输出张量 | ||
output_tensor = None | ||
for node in onnx_model.graph.output: | ||
if node.name == "output": # 将 "output" 替换模型中实际的输出张量名称 | ||
output_tensor = node | ||
break | ||
print(output_tensor) | ||
# 修改输出张量的维度 | ||
if output_tensor is not None: | ||
output_tensor.type.tensor_type.shape.dim.insert(0, onnx.TensorShapeProto.Dimension(dim_value=1)) | ||
|
||
# 保存修改后的模型 | ||
onnx.save(onnx_model, "./modified_model.onnx") | ||
print_colored_box("The modified ONNX model is saved in ./modified_model.onnx") | ||
|
||
|
||
if __name__ == "__main__": | ||
model_path="/Users/bruce/Downloads/Chip_test_models/backbone_224_224/regnety_002.onnx" | ||
expand_dim(model_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import torch | ||
import torchvision.transforms as T | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
import numpy as np | ||
|
||
def rotate_with_grid_sample(image, angle): | ||
# 将PIL图像转换为torch张量 | ||
transform = T.Compose([T.ToTensor()]) | ||
tensor = transform(image).unsqueeze(0) # 添加batch维度 | ||
|
||
# 创建旋转网格 | ||
theta = torch.tensor([ | ||
[np.cos(np.radians(angle)), np.sin(-np.radians(angle)), 0], | ||
[np.sin(np.radians(angle)), np.cos(np.radians(angle)), 0] | ||
], dtype=torch.float) | ||
|
||
grid = torch.nn.functional.affine_grid(theta.unsqueeze(0), tensor.size(), align_corners=False) | ||
|
||
# 使用grid_sample进行采样 | ||
rotated_tensor = torch.nn.functional.grid_sample(tensor, grid, align_corners=False) | ||
|
||
# 将结果张量转换回PIL图像 | ||
rotated_image = T.ToPILImage()(rotated_tensor.squeeze(0)) | ||
return rotated_image | ||
|
||
|
||
def visualize_image_operations(image_path): | ||
# 加载图像 | ||
img = Image.open(image_path) | ||
|
||
# 裁剪图像(例如,裁剪中心区域) | ||
width, height = img.size | ||
new_width, new_height = width // 2, height // 2 | ||
left = (width - new_width) // 2 | ||
top = (height - new_height) // 2 | ||
right = (width + new_width) // 2 | ||
bottom = (height + new_height) // 2 | ||
cropped_img = img.crop((left, top, right, bottom)) | ||
|
||
# 调整大小(例如,将图像缩小一半) | ||
resized_img = img.resize((width // 2, height // 2)) | ||
|
||
# 网格采样:这里我们使用旋转作为示例 | ||
angle = 45 # 旋转角度 | ||
rotated_image = rotate_with_grid_sample(img, angle) | ||
|
||
# 可视化 | ||
plt.figure(figsize=(10, 8)) | ||
|
||
plt.subplot(2, 2, 1) | ||
plt.imshow(img) | ||
plt.title('Original Image') | ||
|
||
plt.subplot(2, 2, 2) | ||
plt.imshow(cropped_img) | ||
plt.title('Cropped Image') | ||
|
||
plt.subplot(2, 2, 3) | ||
plt.imshow(resized_img) | ||
plt.title('Resized Image') | ||
|
||
plt.subplot(2, 2, 4) | ||
plt.imshow(rotated_image) | ||
plt.title(f'Rotated Image by {angle} Degrees') | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
|
||
|
||
# 替换以下路径为你的图片路径 | ||
image_path = "/Users/bruce/PycharmProjects/Pytorch_learning/Deploy/face_torch_3.png" | ||
visualize_image_operations(image_path) |