From cc9e4fffc71decf114ee3ed76cc44917b0c00308 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Sun, 6 Aug 2023 21:55:25 +0200 Subject: [PATCH] I did not realize that reading the ground truth image had a bug. This is fixed! --- .gitignore | 2 ++ debug/compare.py | 41 ++++++++++++++++++++++++++++++++++++- includes/debug_utils.cuh | 31 ++++++++++++++++++++++++++++ src/camera.cu | 44 +++++++++++++++++++++++++++++----------- src/main.cu | 3 ++- src/rasterize_points.cu | 1 - 6 files changed, 107 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index dd67d07..3751744 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ external/libtorch/ .idea *.pt .venv/ +*.png +test.py diff --git a/debug/compare.py b/debug/compare.py index 5477d32..6336c7a 100644 --- a/debug/compare.py +++ b/debug/compare.py @@ -1,9 +1,12 @@ import torch import numpy as np +from torchvision import transforms +from PIL import Image # Specification for tensors tensor_specs = { "image": {"dims": 3, "shape": [3, 546, 979], "type": float}, + "gt_image": {"dims": 3, "shape": [3, 546, 979], "type": float}, "means3D": {"dims": 2, "shape": [136029, 3], "type": float}, "sh": {"dims": 3, "shape": [136029, 16, 3], "type": float}, "colors_precomp": {"dims": 1, "shape": [0], "type": float}, @@ -29,6 +32,28 @@ "max_radii2D_masked": {"dims": 1, "shape": [136029], "type": float}, } +def save_tensor(filename, tensor, tensor_spec): + with open(filename, 'wb') as f: + # Write dims + f.write(len(tensor_spec["shape"]).to_bytes(4, 'little')) + + # Write shape + for dim in tensor_spec["shape"]: + f.write(dim.to_bytes(8, 'little')) + + # Convert tensor to numpy + data_np = tensor.numpy() + + # Write data to file + data_type = tensor_spec["type"] + if data_type == bool: + data_np.astype(np.bool_).tofile(f) + elif data_type == float: + data_np.astype(np.float32).tofile(f) + elif data_type == int: + data_np.astype(np.int32).tofile(f) + else: + data_np.astype(np.int64).tofile(f) def load_tensor(filename, tensor_spec): with open(filename, 'rb') as f: @@ -66,7 +91,7 @@ def load_tensor(filename, tensor_spec): name: load_tensor(f"libtorch_{name}.pt", tensor_specs[name]) for name in tensor_specs.keys() } -tolerance = 1e-5 +tolerance = 1e-3 for name, tensor in py_tensors.items(): print(f"======= Comparing {name} =======") @@ -76,6 +101,20 @@ def load_tensor(filename, tensor_spec): print(f"{name}: libtorch tensor is None!") continue + if name == "image" or name == "gt_image": + # Convert to PIL image + tensor_spec = { + "dims": len(tensor.shape), + "shape": tuple(tensor.shape), + "type": float + } + save_tensor(f"pytorch_{name}_libtorch.pt", tensor, tensor_spec) + pytorch_img = transforms.ToPILImage()(tensor) + libtorch_img = transforms.ToPILImage()(libtorch_tensor) + + pytorch_img.save(f"pytorch_{name}.png") + libtorch_img.save(f"libtorch_{name}.png") + if tensor is None: print(f"{name}: pytorch tensor is None!") print(f"{name} (libtorch): Shape: {libtorch_tensor.shape}") diff --git a/includes/debug_utils.cuh b/includes/debug_utils.cuh index 8e46851..2393816 100644 --- a/includes/debug_utils.cuh +++ b/includes/debug_utils.cuh @@ -38,4 +38,35 @@ namespace ts { outfile.close(); } + + inline torch::Tensor load_my_tensor(const std::string& filename) { + std::ifstream infile(filename, std::ios::binary); + if (!infile.is_open()) { + throw std::runtime_error("Failed to open file " + filename); + } + + // Read tensor dimensions + int dims; + infile.read(reinterpret_cast(&dims), sizeof(int)); + + // Read tensor sizes + std::vector sizes(dims); + infile.read(reinterpret_cast(sizes.data()), dims * sizeof(int64_t)); + + // Determine the size of the tensor data + int64_t numel = 1; + for (int i = 0; i < dims; ++i) { + numel *= sizes[i]; + } + + torch::Tensor tensor; + + // We assume here float + std::vector data(numel); + infile.read(reinterpret_cast(data.data()), numel * sizeof(float)); + tensor = torch::tensor(data).reshape(sizes); + + infile.close(); + return tensor; + } } // namespace ts diff --git a/src/camera.cu b/src/camera.cu index 1c895c2..db573f4 100644 --- a/src/camera.cu +++ b/src/camera.cu @@ -1,6 +1,7 @@ #include "camera.cuh" #include "camera_info.cuh" #include "camera_utils.cuh" +#include "debug_utils.cuh" #include "parameters.cuh" #include #include @@ -39,18 +40,37 @@ Camera::Camera(int imported_colmap_id, // TODO: I have skipped the resolution for now. Camera loadCam(const ModelParameters& params, int id, CameraInfo& cam_info) { // Create a torch::Tensor from the image data - torch::Tensor original_image_tensor = torch::from_blob(cam_info._img_data, {static_cast(cam_info._img_h), static_cast(cam_info._img_w), 3}, torch::kU8).to(torch::kFloat) / 255.f; - - // Change the view to be {height * width, 3} - // original_image_tensor = original_image_tensor.view({static_cast(cam_info._image_height) * static_cast(cam_info._image_width), 3}); - // TODO: Check if this is correct - original_image_tensor = original_image_tensor.permute({2, 0, 1}); - free_image(cam_info._img_data); // we dont longer need the image here. - cam_info._img_data = nullptr; // Assure that we dont use the image data anymore. - - if (original_image_tensor.size(0) > 3) { - original_image_tensor = original_image_tensor.slice(0, 0, 3); - throw std::runtime_error("Image has more than 3 channels. This is not supported."); + torch::Tensor original_image_tensor; + if (id != 0) { + std::vector r; + std::vector g; + std::vector b; + for (int i = 0; i < cam_info._img_h; i++) { + for (int j = 0; j < cam_info._img_w; j++) { + r.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 0]); + g.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 1]); + b.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 2]); + } + } + // concat the vectors + std::vector rgb; + rgb.insert(rgb.end(), r.begin(), r.end()); + rgb.insert(rgb.end(), g.begin(), g.end()); + rgb.insert(rgb.end(), b.begin(), b.end()); + // to torch tensor + original_image_tensor = torch::from_blob(rgb.data(), {static_cast(cam_info._img_h), static_cast(cam_info._img_w), 3}, torch::kU8).clone().to(torch::kFloat) / 255.f; + original_image_tensor = original_image_tensor.permute({2, 0, 1}); + + free_image(cam_info._img_data); // we dont longer need the image here. + cam_info._img_data = nullptr; // Assure that we dont use the image data anymore. + + if (original_image_tensor.size(0) > 3) { + original_image_tensor = original_image_tensor.slice(0, 0, 3); + throw std::runtime_error("Image has more than 3 channels. This is not supported."); + } + } else { + original_image_tensor = ts::load_my_tensor("/home/paja/projects/gaussian_splatting_cuda/cmake-build-debug/pytorch_gt_image.tp"); + std::cout << "shape: " << original_image_tensor.sizes() << std::endl; } return Camera(cam_info._camera_ID, cam_info._R, cam_info._T, cam_info._fov_x, cam_info._fov_y, original_image_tensor, diff --git a/src/main.cu b/src/main.cu index 91e1d10..ce8fee9 100644 --- a/src/main.cu +++ b/src/main.cu @@ -49,6 +49,7 @@ int main(int argc, char* argv[]) { // Loss Computations ts::save_my_tensor(image, "libtorch_image.pt"); auto gt_image = cam.Get_original_image().to(torch::kCUDA); + ts::save_my_tensor(gt_image, "libtorch_gt_image.pt"); auto l1l = gaussian_splatting::l1_loss(image, gt_image); auto loss = (1.0 - optimParams.lambda_dssim) * l1l + optimParams.lambda_dssim * (1.0 - gaussian_splatting::ssim(image, gt_image)); std::cout << "Iteration: " << iter << " Loss: " << loss.item() << std::endl; @@ -66,7 +67,7 @@ int main(int argc, char* argv[]) { auto max_radii = torch::max(visible_max_radii, visible_radii); gaussians._max_radii2D.masked_scatter_(visibility_filter, max_radii); ts::save_my_tensor(gaussians._max_radii2D, "libtorch_max_radii2D_masked.pt"); - if (iter == 2) { + if (iter == 100) { exit(0); } diff --git a/src/rasterize_points.cu b/src/rasterize_points.cu index d89b0de..0ac3c05 100644 --- a/src/rasterize_points.cu +++ b/src/rasterize_points.cu @@ -98,7 +98,6 @@ RasterizeGaussiansCUDA( if (sh.size(0) != 0) { M = sh.size(1); } - print_tensor_info(background, "background"); print_tensor_info(means3D, "means3D"); print_tensor_info(colors, "colors");