
Commit

I did not realize that reading the ground truth image had a bug. This is fixed!
MrNeRF committed Aug 6, 2023
1 parent 05f73d4 commit cc9e4ff
Showing 6 changed files with 107 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,3 +6,5 @@ external/libtorch/
.idea
*.pt
.venv/
*.png
test.py
41 changes: 40 additions & 1 deletion debug/compare.py
@@ -1,9 +1,12 @@
import torch
import numpy as np
from torchvision import transforms
from PIL import Image

# Specification for tensors
tensor_specs = {
"image": {"dims": 3, "shape": [3, 546, 979], "type": float},
"gt_image": {"dims": 3, "shape": [3, 546, 979], "type": float},
"means3D": {"dims": 2, "shape": [136029, 3], "type": float},
"sh": {"dims": 3, "shape": [136029, 16, 3], "type": float},
"colors_precomp": {"dims": 1, "shape": [0], "type": float},
@@ -29,6 +32,28 @@
"max_radii2D_masked": {"dims": 1, "shape": [136029], "type": float},
}

def save_tensor(filename, tensor, tensor_spec):
with open(filename, 'wb') as f:
# Write dims
f.write(len(tensor_spec["shape"]).to_bytes(4, 'little'))

# Write shape
for dim in tensor_spec["shape"]:
f.write(dim.to_bytes(8, 'little'))

# Convert tensor to numpy
data_np = tensor.numpy()

# Write data to file
data_type = tensor_spec["type"]
if data_type == bool:
data_np.astype(np.bool_).tofile(f)
elif data_type == float:
data_np.astype(np.float32).tofile(f)
elif data_type == int:
data_np.astype(np.int32).tofile(f)
else:
data_np.astype(np.int64).tofile(f)

def load_tensor(filename, tensor_spec):
with open(filename, 'rb') as f:
@@ -66,7 +91,7 @@ def load_tensor(filename, tensor_spec):
name: load_tensor(f"libtorch_{name}.pt", tensor_specs[name]) for name in tensor_specs.keys()
}

tolerance = 1e-5
tolerance = 1e-3

for name, tensor in py_tensors.items():
print(f"======= Comparing {name} =======")
@@ -76,6 +101,20 @@ def load_tensor(filename, tensor_spec):
print(f"{name}: libtorch tensor is None!")
continue

if name == "image" or name == "gt_image":
# Convert to PIL image
tensor_spec = {
"dims": len(tensor.shape),
"shape": tuple(tensor.shape),
"type": float
}
save_tensor(f"pytorch_{name}_libtorch.pt", tensor, tensor_spec)
pytorch_img = transforms.ToPILImage()(tensor)
libtorch_img = transforms.ToPILImage()(libtorch_tensor)

pytorch_img.save(f"pytorch_{name}.png")
libtorch_img.save(f"libtorch_{name}.png")

if tensor is None:
print(f"{name}: pytorch tensor is None!")
print(f"{name} (libtorch): Shape: {libtorch_tensor.shape}")
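
The branch above also writes both renders out as PNGs so they can be inspected side by side. As a quick numeric follow-up, a minimal sketch (not part of this commit, and assuming the comparison script has already written the PNG files into the working directory) could report the worst per-pixel deviation:

import numpy as np
from PIL import Image

for name in ("image", "gt_image"):
    a = np.asarray(Image.open(f"pytorch_{name}.png"), dtype=np.float32)
    b = np.asarray(Image.open(f"libtorch_{name}.png"), dtype=np.float32)
    print(name, "max abs pixel diff:", np.abs(a - b).max())
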
31 changes: 31 additions & 0 deletions includes/debug_utils.cuh
@@ -38,4 +38,35 @@ namespace ts {

outfile.close();
}

inline torch::Tensor load_my_tensor(const std::string& filename) {
std::ifstream infile(filename, std::ios::binary);
if (!infile.is_open()) {
throw std::runtime_error("Failed to open file " + filename);
}

// Read tensor dimensions
int dims;
infile.read(reinterpret_cast<char*>(&dims), sizeof(int));

// Read tensor sizes
std::vector<int64_t> sizes(dims);
infile.read(reinterpret_cast<char*>(sizes.data()), dims * sizeof(int64_t));

// Determine the size of the tensor data
int64_t numel = 1;
for (int i = 0; i < dims; ++i) {
numel *= sizes[i];
}

torch::Tensor tensor;

// We assume the stored data is float32 here
std::vector<float> data(numel);
infile.read(reinterpret_cast<char*>(data.data()), numel * sizeof(float));
tensor = torch::tensor(data).reshape(sizes);

infile.close();
return tensor;
}
} // namespace ts
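
For context, load_my_tensor expects the same on-disk layout that save_tensor in debug/compare.py produces: a 4-byte dimension count, one 8-byte size per dimension, then the raw float32 values. A minimal Python sketch of a writer in that layout (the function name and the example call are illustrative, not part of the commit; native little-endian byte order is assumed):

import numpy as np
import torch

def write_for_libtorch(filename, tensor):
    # int32 dim count, int64 sizes, then raw float32 data
    data = tensor.detach().cpu().numpy().astype(np.float32)
    with open(filename, "wb") as f:
        f.write(np.int32(data.ndim).tobytes())
        f.write(np.array(data.shape, dtype=np.int64).tobytes())
        data.tofile(f)

# e.g. write_for_libtorch("gt_image_dump.pt", torch.rand(3, 546, 979))
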
44 changes: 32 additions & 12 deletions src/camera.cu
@@ -1,6 +1,7 @@
#include "camera.cuh"
#include "camera_info.cuh"
#include "camera_utils.cuh"
#include "debug_utils.cuh"
#include "parameters.cuh"
#include <eigen3/Eigen/Dense>
#include <torch/torch.h>
@@ -39,18 +40,37 @@ Camera::Camera(int imported_colmap_id,
// TODO: I have skipped the resolution for now.
Camera loadCam(const ModelParameters& params, int id, CameraInfo& cam_info) {
// Create a torch::Tensor from the image data
torch::Tensor original_image_tensor = torch::from_blob(cam_info._img_data, {static_cast<long>(cam_info._img_h), static_cast<long>(cam_info._img_w), 3}, torch::kU8).to(torch::kFloat) / 255.f;

// Change the view to be {height * width, 3}
// original_image_tensor = original_image_tensor.view({static_cast<long>(cam_info._image_height) * static_cast<long>(cam_info._image_width), 3});
// TODO: Check if this is correct
original_image_tensor = original_image_tensor.permute({2, 0, 1});
free_image(cam_info._img_data); // we no longer need the image here.
cam_info._img_data = nullptr; // Ensure that we don't use the image data anymore.

if (original_image_tensor.size(0) > 3) {
original_image_tensor = original_image_tensor.slice(0, 0, 3);
throw std::runtime_error("Image has more than 3 channels. This is not supported.");
torch::Tensor original_image_tensor;
if (id != 0) {
std::vector<uint8_t> r;
std::vector<uint8_t> g;
std::vector<uint8_t> b;
for (int i = 0; i < cam_info._img_h; i++) {
for (int j = 0; j < cam_info._img_w; j++) {
r.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 0]);
g.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 1]);
b.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 2]);
}
}
// concat the vectors
std::vector<uint8_t> rgb;
rgb.insert(rgb.end(), r.begin(), r.end());
rgb.insert(rgb.end(), g.begin(), g.end());
rgb.insert(rgb.end(), b.begin(), b.end());
// to torch tensor
original_image_tensor = torch::from_blob(rgb.data(), {static_cast<long>(cam_info._img_h), static_cast<long>(cam_info._img_w), 3}, torch::kU8).clone().to(torch::kFloat) / 255.f;
original_image_tensor = original_image_tensor.permute({2, 0, 1});

free_image(cam_info._img_data); // we no longer need the image here.
cam_info._img_data = nullptr; // Ensure that we don't use the image data anymore.

if (original_image_tensor.size(0) > 3) {
original_image_tensor = original_image_tensor.slice(0, 0, 3);
throw std::runtime_error("Image has more than 3 channels. This is not supported.");
}
} else {
original_image_tensor = ts::load_my_tensor("/home/paja/projects/gaussian_splatting_cuda/cmake-build-debug/pytorch_gt_image.tp");
std::cout << "shape: " << original_image_tensor.sizes() << std::endl;
}

return Camera(cam_info._camera_ID, cam_info._R, cam_info._T, cam_info._fov_x, cam_info._fov_y, original_image_tensor,
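
The id != 0 branch above splits the interleaved image buffer into separate R, G and B planes before building the tensor. For reference on the Python side, a small sketch of turning an interleaved HWC uint8 buffer into the planar CHW float tensor that debug/compare.py works with (identifiers are illustrative; the buffer is assumed to be a standard interleaved image as produced by stb-style loaders):

import numpy as np
import torch

def to_chw_float(hwc: np.ndarray) -> torch.Tensor:
    rgb = hwc[..., :3]                                  # keep at most 3 channels
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))  # HWC -> CHW (planar)
    return torch.from_numpy(chw).float() / 255.0

# usage: chw = to_chw_float(interleaved_uint8_image)
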
3 changes: 2 additions & 1 deletion src/main.cu
@@ -49,6 +49,7 @@ int main(int argc, char* argv[]) {
// Loss Computations
ts::save_my_tensor(image, "libtorch_image.pt");
auto gt_image = cam.Get_original_image().to(torch::kCUDA);
ts::save_my_tensor(gt_image, "libtorch_gt_image.pt");
auto l1l = gaussian_splatting::l1_loss(image, gt_image);
auto loss = (1.0 - optimParams.lambda_dssim) * l1l + optimParams.lambda_dssim * (1.0 - gaussian_splatting::ssim(image, gt_image));
std::cout << "Iteration: " << iter << " Loss: " << loss.item<float>() << std::endl;
@@ -66,7 +67,7 @@ int main(int argc, char* argv[]) {
auto max_radii = torch::max(visible_max_radii, visible_radii);
gaussians._max_radii2D.masked_scatter_(visibility_filter, max_radii);
ts::save_my_tensor(gaussians._max_radii2D, "libtorch_max_radii2D_masked.pt");
if (iter == 2) {
if (iter == 100) {
exit(0);
}

1 change: 0 additions & 1 deletion src/rasterize_points.cu
@@ -98,7 +98,6 @@ RasterizeGaussiansCUDA(
if (sh.size(0) != 0) {
M = sh.size(1);
}

print_tensor_info(background, "background");
print_tensor_info(means3D, "means3D");
print_tensor_info(colors, "colors");
