Finally it works. The problem was the conversion by torch.
MrNeRF committed Aug 7, 2023
1 parent 72d39ce commit fd30423
Showing 5 changed files with 44 additions and 47 deletions.
2 changes: 1 addition & 1 deletion includes/camera.cuh
@@ -23,7 +23,7 @@ public:
float Get_FoVx() const { return static_cast<float>(_FoVx); }
float Get_FoVy() const { return static_cast<float>(_FoVy); }
std::string Get_image_name() const { return _image_name; }
torch::Tensor& Get_original_image() { return _original_image; }
const torch::Tensor& Get_original_image() { return _original_image; }
int Get_image_width() const { return _image_width; }
int Get_image_height() const { return _image_height; }
double Get_zfar() const { return _zfar; }
2 changes: 1 addition & 1 deletion includes/scene.cuh
@@ -12,7 +12,7 @@ class Camera;
class Scene {
public:
Scene(GaussianModel& gaussians, const ModelParameters& params);
[[nodiscard]] int Get_camera_count() const { return _scene_infos->_cameras.size(); }
[[nodiscard]] int Get_camera_count() const { return _cameras.size(); }
Camera& Get_training_camera(int i) { return _cameras[i]; }
[[nodiscard]] float Get_cameras_extent() const { return static_cast<float>(_scene_infos->_nerf_norm_radius); }

51 changes: 24 additions & 27 deletions src/camera.cu
@@ -4,6 +4,7 @@
#include "debug_utils.cuh"
#include "parameters.cuh"
#include <eigen3/Eigen/Dense>
#include <string>
#include <torch/torch.h>

Camera::Camera(int imported_colmap_id,
@@ -40,36 +41,32 @@ Camera::Camera(int imported_colmap_id,
// TODO: I have skipped the resolution for now.
Camera loadCam(const ModelParameters& params, int id, CameraInfo& cam_info) {
// Create a torch::Tensor from the image data
torch::Tensor original_image_tensor;
if (id != 0) {
std::vector<uint8_t> r;
std::vector<uint8_t> g;
std::vector<uint8_t> b;
for (int i = 0; i < cam_info._img_h; i++) {
for (int j = 0; j < cam_info._img_w; j++) {
r.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 0]);
g.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 1]);
b.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 2]);
}
// TODO: optimize
std::vector<float> r;
std::vector<float> g;
std::vector<float> b;
for (int i = 0; i < cam_info._img_h; i++) {
for (int j = 0; j < cam_info._img_w; j++) {
r.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 0] / 255.f);
g.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 1] / 255.f);
b.push_back(cam_info._img_data[i * cam_info._img_w * cam_info._channels + j * cam_info._channels + 2] / 255.f);
}
// concat the vectors
std::vector<uint8_t> rgb;
rgb.insert(rgb.end(), r.begin(), r.end());
rgb.insert(rgb.end(), g.begin(), g.end());
rgb.insert(rgb.end(), b.begin(), b.end());
// to torch tensor
original_image_tensor = torch::from_blob(rgb.data(), {static_cast<long>(cam_info._img_h), static_cast<long>(cam_info._img_w), 3}, torch::kUInt8).clone().to(torch::kFloat) / 255.f;
original_image_tensor = original_image_tensor.permute({2, 0, 1});
}
// concat the vectors
std::vector<float> rgb;
rgb.insert(rgb.end(), r.begin(), r.end());
rgb.insert(rgb.end(), g.begin(), g.end());
rgb.insert(rgb.end(), b.begin(), b.end());
// to torch tensor
torch::Tensor original_image_tensor = torch::from_blob(rgb.data(), {static_cast<long>(cam_info._img_h), static_cast<long>(cam_info._img_w), 3}, torch::kFloat32).clone();
original_image_tensor = original_image_tensor.permute({2, 0, 1});

free_image(cam_info._img_data); // we no longer need the image here.
cam_info._img_data = nullptr; // Ensure that we don't use the image data anymore.
free_image(cam_info._img_data); // we no longer need the image here.
cam_info._img_data = nullptr; // Ensure that we don't use the image data anymore.

if (original_image_tensor.size(0) > 3) {
original_image_tensor = original_image_tensor.slice(0, 0, 3);
throw std::runtime_error("Image has more than 3 channels. This is not supported.");
}
} else {
original_image_tensor = ts::load_my_tensor("/home/paja/projects/gaussian_splatting_cuda/cmake-build-debug/pytorch_gt_image.tp");
if (original_image_tensor.size(0) > 3) {
original_image_tensor = original_image_tensor.slice(0, 0, 3);
throw std::runtime_error("Image has more than 3 channels. This is not supported.");
}

return Camera(cam_info._camera_ID, cam_info._R, cam_info._T, cam_info._fov_x, cam_info._fov_y, original_image_tensor,
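For context, the heart of the change above is the dtype handling around torch::from_blob: the old path wrapped the raw bytes as an 8-bit tensor and let torch convert and normalize, while the new path normalizes to float on the CPU and hands from_blob a kFloat32 buffer directly. Below is a minimal side-by-side sketch of the two approaches, assuming a plain interleaved 8-bit RGB buffer such as the one returned by stbi_load; the helper names and the interleaved layout are illustrative only, not the exact per-channel layout used in loadCam.

#include <torch/torch.h>
#include <cstdint>
#include <vector>

// Old-style path: wrap the raw bytes as uint8, then convert and normalize via torch.
torch::Tensor to_chw_float_via_u8(const uint8_t* data, int64_t h, int64_t w) {
    auto t = torch::from_blob(const_cast<uint8_t*>(data), {h, w, 3}, torch::kUInt8)
                 .clone()
                 .to(torch::kFloat) / 255.f;
    return t.permute({2, 0, 1}); // HWC -> CHW
}

// New-style path: normalize to [0, 1] floats on the CPU, then wrap as kFloat32.
torch::Tensor to_chw_float_direct(const uint8_t* data, int64_t h, int64_t w) {
    std::vector<float> rgb(static_cast<size_t>(h * w * 3));
    for (size_t i = 0; i < rgb.size(); ++i) {
        rgb[i] = data[i] / 255.f;
    }
    // clone() copies the data out of the local vector before it goes out of scope.
    auto t = torch::from_blob(rgb.data(), {h, w, 3}, torch::kFloat32).clone();
    return t.permute({2, 0, 1}); // HWC -> CHW
}

Both helpers should yield the same CHW float tensor for the same input buffer.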
2 changes: 1 addition & 1 deletion src/camera_utils.cu
@@ -108,7 +108,7 @@ Eigen::Quaterniond rotmat2qvec(const Eigen::Matrix3d& R) {

std::tuple<unsigned char*, int, int, int> read_image(std::filesystem::path image_path) {
int width, height, channels;
unsigned char* img = stbi_load(image_path.string().c_str(), &width, &height, &channels, 0);
unsigned char* img = stbi_load(image_path.string().c_str(), &width, &height, &channels, 3);
if (img == nullptr) {
throw std::runtime_error("Could not load image: " + image_path.string());
}
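The last argument of stbi_load is the number of components to force in the returned buffer: with 0 the buffer keeps the file's native channel count, while with 3 stb_image converts grayscale or RGBA input to plain RGB, so the rest of the pipeline can rely on exactly three channels. A small sketch of that contract; the load_rgb helper is illustrative and not part of the repository:

#include <stdexcept>
#include <string>
#include "stb_image.h"

// Returns an interleaved RGB buffer of size width * height * 3.
// 'channels_in_file' still reports how many channels the file itself contained.
unsigned char* load_rgb(const std::string& path, int& width, int& height, int& channels_in_file) {
    unsigned char* img = stbi_load(path.c_str(), &width, &height, &channels_in_file, /*desired_channels=*/3);
    if (img == nullptr) {
        throw std::runtime_error("Could not load image: " + path);
    }
    return img; // release with stbi_image_free(img) when done
}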
34 changes: 17 additions & 17 deletions src/main.cu
@@ -8,6 +8,15 @@
#include <random>
#include <torch/torch.h>


std::vector<int> get_random_indices(int max_index) {
std::vector<int> indices(max_index);
std::iota(indices.begin(), indices.end(), 0);
// Shuffle the vector
std::shuffle(indices.begin(), indices.end(), std::default_random_engine());
return indices;
}

int main(int argc, char* argv[]) {

if (argc != 2) {
@@ -31,19 +40,19 @@ int main(int argc, char* argv[]) {
auto background = modelParams.white_background ? torch::tensor({1.f, 1.f, 1.f}) : torch::tensor({0.f, 0.f, 0.f}, pointType).to(torch::kCUDA);

const int camera_count = scene.Get_camera_count();
// Initialize random engine
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(0, camera_count - 1);
// training loop
for (int iter = 1; iter < optimParams.iterations; ++iter) {
std::vector<int> indices;
for (int iter = 1; iter < optimParams.iterations + 1; ++iter) {
if (iter % 1000 == 0) {
gaussians.One_up_sh_degree();
}

if (indices.empty()) {
indices = get_random_indices(camera_count);
}
const int camera_index = indices.back();
indices.pop_back(); // remove last element to iterate over all cameras randomly
auto& cam = scene.Get_training_camera(camera_index);
// Render
// const int random_index = dis(gen);
auto& cam = scene.Get_training_camera(0);
auto [image, viewspace_point_tensor, visibility_filter, radii] = render(cam, gaussians, pipelineParams, background);

// Loss Computations
@@ -57,19 +66,10 @@ int main(int argc, char* argv[]) {
loss.backward();
{
torch::NoGradGuard no_grad;
// Keep track of max radii in image-space for pruning
// ts::print_debug_info(gaussians._max_radii2D, "max_radii2D");
// ts::print_debug_info(visibility_filter, "visibility_filter");
// ts::print_debug_info(radii, "radii");
// ts::print_debug_info(viewspace_point_tensor, "viewspace_point_tensor");
auto visible_max_radii = gaussians._max_radii2D.masked_select(visibility_filter);
auto visible_radii = radii.masked_select(visibility_filter);
auto max_radii = torch::max(visible_max_radii, visible_radii);
gaussians._max_radii2D.masked_scatter_(visibility_filter, max_radii);
// ts::print_debug_info(gaussians._max_radii2D, "max_radii2D_masked");
// if (iter == 701) {
// exit(0);
// }

// TODO: support saving
// if (iteration in saving_iterations):
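The get_random_indices / pop_back pattern in the training loop above samples cameras without replacement: the index vector is refilled with a freshly shuffled permutation once it has been exhausted, so every camera is visited exactly once per pass. A self-contained sketch of just that pattern (the camera count and iteration bound are made up for illustration); note that std::default_random_engine is default-constructed here, so the shuffle order repeats across runs unless an explicit seed is supplied:

#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

std::vector<int> get_random_indices(int max_index) {
    std::vector<int> indices(max_index);
    std::iota(indices.begin(), indices.end(), 0); // 0, 1, ..., max_index - 1
    std::shuffle(indices.begin(), indices.end(), std::default_random_engine());
    return indices;
}

int main() {
    const int camera_count = 5;
    std::vector<int> indices;
    for (int iter = 1; iter < 13; ++iter) {
        if (indices.empty()) {
            indices = get_random_indices(camera_count); // start a new pass over all cameras
        }
        const int camera_index = indices.back();
        indices.pop_back(); // consume one camera per iteration
        std::cout << "iter " << iter << " -> camera " << camera_index << '\n';
    }
    return 0;
}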
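The masked_select / masked_scatter_ lines in the last hunk keep a running per-Gaussian maximum of the screen-space radius, updated only for Gaussians that are visible in the current view. A tiny standalone example of the same update pattern, with made-up values:

#include <torch/torch.h>
#include <iostream>

int main() {
    // Running maximum per element, updated in place only where 'visible' is true.
    auto max_radii = torch::zeros({6});
    auto radii = torch::tensor({3.f, 1.f, 4.f, 1.f, 5.f, 9.f});
    auto visible = torch::tensor({true, false, true, false, true, false});

    auto visible_max = max_radii.masked_select(visible); // current maxima of the visible elements
    auto visible_new = radii.masked_select(visible);     // radii computed this iteration
    auto updated = torch::max(visible_max, visible_new); // element-wise maximum
    max_radii.masked_scatter_(visible, updated);         // write the result back in place

    std::cout << max_radii << std::endl; // 3, 0, 4, 0, 5, 0
    return 0;
}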
