From cd94ab97c93152fd98aca6474c40e8196bcc8a24 Mon Sep 17 00:00:00 2001 From: rsy6318 Date: Fri, 23 Feb 2024 14:40:05 +0800 Subject: [PATCH 1/2] Update INSTALL.md --- INSTALL.md | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index 5439a4ed..1bacbdb0 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,11 +1,10 @@ # Installation - ## Requirements ### Core library -The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/PyTorch. It is advised to use PyTorch3D with GPU support in order to use all the features. +The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/PyTorch. It is advised to use PyTorch3D with GPU support in order to use all the features. - Linux or macOS or Windows - Python 3.8, 3.9 or 3.10 @@ -18,6 +17,7 @@ The core library is written in PyTorch. Several components have underlying imple - If CUDA older than 11.7 is to be used and you are building from source, the CUB library must be available. We recommend version 1.10.0. The runtime dependencies can be installed by running: + ``` conda create -n pytorch3d python=3.9 conda activate pytorch3d @@ -26,12 +26,15 @@ conda install -c fvcore -c iopath -c conda-forge fvcore iopath ``` For the CUB build time dependency, which you only need if you have CUDA older than 11.7, if you are using conda, you can continue with + ``` conda install -c bottler nvidiacub ``` + Otherwise download the CUB library from https://github.com/NVIDIA/cub/releases and unpack it to a folder of your choice. Define the environment variable CUB_HOME before building and point it to the directory that contains `CMakeLists.txt` for CUB. For example on Linux/Mac, + ``` curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz tar xzf 1.10.0.tar.gz @@ -41,6 +44,7 @@ export CUB_HOME=$PWD/cub-1.10.0 ### Tests/Linting and Demos For developing on top of PyTorch3D or contributing, you will need to run the linter and tests. If you want to run any of the notebook tutorials as `docs/tutorials` or the examples in `docs/examples` you will also need matplotlib and OpenCV. + - scikit-image - black - usort @@ -53,6 +57,7 @@ For developing on top of PyTorch3D or contributing, you will need to run the lin - opencv-python These can be installed by running: + ``` # Demos and examples conda install jupyter @@ -63,6 +68,7 @@ pip install black usort flake8 flake8-bugbear flake8-comprehensions ``` ## Installing prebuilt binaries for PyTorch3D + After installing the above dependencies, run one of the following commands: ### 1. Install with CUDA support from Anaconda Cloud, on Linux only @@ -73,21 +79,25 @@ conda install pytorch3d -c pytorch3d ``` Or, to install a nightly (non-official, alpha) build: + ``` # Anaconda Cloud conda install pytorch3d -c pytorch3d-nightly ``` ### 2. Install wheels for Linux + We have prebuilt wheels with CUDA for Linux for PyTorch 1.11.0, for each of the supported CUDA versions, for Python 3.8 and 3.9. This is for ease of use on Google Colab. These are installed in a special way. For example, to install for Python 3.8, PyTorch 1.11.0 and CUDA 11.3 + ``` pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html ``` In general, from inside IPython, or in Google Colab or a jupyter notebook, you can install with + ``` import sys import torch @@ -102,14 +112,18 @@ version_str="".join([ ``` ## Building / installing from source. + CUDA support will be included if CUDA is available in pytorch or if the environment variable `FORCE_CUDA` is set to `1`. ### 1. Install from GitHub + ``` pip install "git+https://github.com/facebookresearch/pytorch3d.git" ``` + To install using the code of the released version instead of from the main branch, use the following instead. + ``` pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" ``` @@ -118,18 +132,22 @@ For CUDA builds with versions earlier than CUDA 11, set `CUB_HOME` before buildi **Install from Github on macOS:** Some environment variables should be provided, like this. + ``` MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install "git+https://github.com/facebookresearch/pytorch3d.git" ``` ### 2. Install from a local clone + ``` git clone https://github.com/facebookresearch/pytorch3d.git cd pytorch3d && pip install -e . ``` + To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then `pip install -e .`. You often need to rebuild pytorch3d after reinstalling PyTorch. For CUDA builds with versions earlier than CUDA 11, set `CUB_HOME` before building as described above. **Install from local clone on macOS:** + ``` MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install -e . ``` @@ -139,12 +157,14 @@ MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install -e . Depending on the version of PyTorch, changes to some PyTorch headers may be needed before compilation. These are often discussed in issues in this repository. After any necessary patching, you can go to "x64 Native Tools Command Prompt for VS 2019" to compile and install + ``` cd pytorch3d python3 setup.py install ``` After installing, you can run **unit tests** + ``` python3 -m unittest discover -v -s tests -t . ``` From ef43b8ddbf6480c9cf49d0f076fb7d12c0ecf285 Mon Sep 17 00:00:00 2001 From: rsy6318 Date: Fri, 23 Feb 2024 14:41:35 +0800 Subject: [PATCH 2/2] add DirDist --- .circleci/regenerate.py | 0 dev/linter.sh | 0 docs/examples/pulsar_basic.py | 0 docs/examples/pulsar_basic_unified.py | 0 docs/examples/pulsar_cam.py | 0 docs/examples/pulsar_cam_unified.py | 0 docs/examples/pulsar_multiview.py | 0 docs/examples/pulsar_optimization.py | 0 docs/examples/pulsar_optimization_unified.py | 0 docs/tutorials/data/cow_mesh/cow.mtl | 0 docs/tutorials/data/cow_mesh/cow.obj | 0 docs/tutorials/data/cow_mesh/cow_texture.png | Bin packaging/build_wheel.sh | 0 packaging/conda/build_pytorch3d.sh | 0 packaging/conda/switch_cuda_version.sh | 0 projects/implicitron_trainer/experiment.py | 0 .../closest_point_on_surface.cu | 269 +++++++++++++++++ .../closest_point_on_surface.h | 20 ++ .../closest_point_on_surface_cpu.cpp | 27 ++ pytorch3d/csrc/ext.cpp | 4 + pytorch3d/loss/__init__.py | 2 +- pytorch3d/loss/dirdist.py | 282 ++++++++++++++++++ pytorch3d/renderer/opengl/opengl_utils.py | 0 scripts/parse_tutorials.py | 0 setup.py | 0 tests/benchmarks/bm_main.py | 0 tests/benchmarks/bm_pulsar.py | 0 tests/data/missing_usemtl/cow.mtl | 0 tests/data/missing_usemtl/cow.obj | 0 tests/data/missing_usemtl/cow_texture.png | Bin 30 files changed, 603 insertions(+), 1 deletion(-) mode change 100755 => 100644 .circleci/regenerate.py mode change 100755 => 100644 dev/linter.sh mode change 100755 => 100644 docs/examples/pulsar_basic.py mode change 100755 => 100644 docs/examples/pulsar_basic_unified.py mode change 100755 => 100644 docs/examples/pulsar_cam.py mode change 100755 => 100644 docs/examples/pulsar_cam_unified.py mode change 100755 => 100644 docs/examples/pulsar_multiview.py mode change 100755 => 100644 docs/examples/pulsar_optimization.py mode change 100755 => 100644 docs/examples/pulsar_optimization_unified.py mode change 100755 => 100644 docs/tutorials/data/cow_mesh/cow.mtl mode change 100755 => 100644 docs/tutorials/data/cow_mesh/cow.obj mode change 100755 => 100644 docs/tutorials/data/cow_mesh/cow_texture.png mode change 100755 => 100644 packaging/build_wheel.sh mode change 100755 => 100644 packaging/conda/build_pytorch3d.sh mode change 100755 => 100644 packaging/conda/switch_cuda_version.sh mode change 100755 => 100644 projects/implicitron_trainer/experiment.py create mode 100644 pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.cu create mode 100644 pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.h create mode 100644 pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface_cpu.cpp create mode 100644 pytorch3d/loss/dirdist.py mode change 100755 => 100644 pytorch3d/renderer/opengl/opengl_utils.py mode change 100755 => 100644 scripts/parse_tutorials.py mode change 100755 => 100644 setup.py mode change 100755 => 100644 tests/benchmarks/bm_main.py mode change 100755 => 100644 tests/benchmarks/bm_pulsar.py mode change 100755 => 100644 tests/data/missing_usemtl/cow.mtl mode change 100755 => 100644 tests/data/missing_usemtl/cow.obj mode change 100755 => 100644 tests/data/missing_usemtl/cow_texture.png diff --git a/.circleci/regenerate.py b/.circleci/regenerate.py old mode 100755 new mode 100644 diff --git a/dev/linter.sh b/dev/linter.sh old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_basic.py b/docs/examples/pulsar_basic.py old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_basic_unified.py b/docs/examples/pulsar_basic_unified.py old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_cam.py b/docs/examples/pulsar_cam.py old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_cam_unified.py b/docs/examples/pulsar_cam_unified.py old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_multiview.py b/docs/examples/pulsar_multiview.py old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_optimization.py b/docs/examples/pulsar_optimization.py old mode 100755 new mode 100644 diff --git a/docs/examples/pulsar_optimization_unified.py b/docs/examples/pulsar_optimization_unified.py old mode 100755 new mode 100644 diff --git a/docs/tutorials/data/cow_mesh/cow.mtl b/docs/tutorials/data/cow_mesh/cow.mtl old mode 100755 new mode 100644 diff --git a/docs/tutorials/data/cow_mesh/cow.obj b/docs/tutorials/data/cow_mesh/cow.obj old mode 100755 new mode 100644 diff --git a/docs/tutorials/data/cow_mesh/cow_texture.png b/docs/tutorials/data/cow_mesh/cow_texture.png old mode 100755 new mode 100644 diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh old mode 100755 new mode 100644 diff --git a/packaging/conda/build_pytorch3d.sh b/packaging/conda/build_pytorch3d.sh old mode 100755 new mode 100644 diff --git a/packaging/conda/switch_cuda_version.sh b/packaging/conda/switch_cuda_version.sh old mode 100755 new mode 100644 diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py old mode 100755 new mode 100644 diff --git a/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.cu b/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.cu new file mode 100644 index 00000000..13158b34 --- /dev/null +++ b/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.cu @@ -0,0 +1,269 @@ +#include +#include +#include + +#include + + +__device__ float clamp(const float input, const float min_value, const float max_value) +{ + float output=input; + if (input>max_value) + { + output=max_value; + } + else if(input tmp0 ) + { + float numer = tmp1 - tmp0; + float denom = a-2*b+c; + s = clamp( numer/denom, 0.f, 1.f ); + t = 1-s; + } + else + { + t = clamp( -e/c, 0.f, 1.f ); + s = 0.f; + } + } + else if ( t < 0.f ) + { + if ( a+d > b+e ) + { + float numer = c+e-b-d; + float denom = a-2*b+c; + s = clamp( numer/denom, 0.f, 1.f ); + t = 1-s; + } + else + { + s = clamp( -e/c, 0.f, 1.f ); + t = 0.f; + } + } + else + { + float numer = c+e-b-d; + float denom = a-2*b+c; + s = clamp( numer/denom, 0.f, 1.f ); + t = 1.f - s; + } + } + result_s=s; + result_t=t; + result_dist=a*s*s+2*b*s*t+c*t*t+2*d*s+2*e*t+f; + + //return 1; +} + + + +__global__ void closestPointonSurface_kernel( + int n, //number of points + const float* points, + int m, //number of triangles + const float* f1, + const float* f2, + const float* f3, + float* w1, + float* w2, + float* w3, + int* indexes +) +{ + const int batch=1024; + float dist_temp; + float s_temp; + float t_temp; + + __shared__ float f1_buff[batch*3]; + __shared__ float f2_buff[batch*3]; + __shared__ float f3_buff[batch*3]; + + float f1_x; + float f1_y; + float f1_z; + float f2_x; + float f2_y; + float f2_z; + float f3_x; + float f3_y; + float f3_z; + + int tid=threadIdx.x+blockIdx.x*blockDim.x; + for (int i =tid;i=dist_temp) + { + best_dist=dist_temp; + best_idx=j+start; + best_s=s_temp; + best_t=t_temp; + } + } + __syncthreads(); + } + + w1[i]=1-best_s-best_t; + w2[i]=best_s; + w3[i]=best_t; + indexes[i]=best_idx; + } + +} + + + + +std::vector closestPointonSurface_cuda_forward( + at::Tensor points, + at::Tensor f1, + at::Tensor f2, + at::Tensor f3 +) +{ + const int n=points.size(0); + const int m=f1.size(0); + //printf("%d",m); + at::Tensor w1=torch::zeros({n},torch::CUDA(torch::kFloat)); + at::Tensor w2=torch::zeros({n},torch::CUDA(torch::kFloat)); + at::Tensor w3=torch::zeros({n},torch::CUDA(torch::kFloat)); + at::Tensor indexes=torch::zeros({n},torch::CUDA(torch::kInt)); + + closestPointonSurface_kernel<<<32768,1024>>>( + n,points.data_ptr(),m, + f1.data_ptr(),f2.data_ptr(),f3.data_ptr(), + w1.data_ptr(),w2.data_ptr(),w3.data_ptr(), + indexes.data_ptr() + ); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("Error in closestPointonSurface_cuda_forward: %s\n", cudaGetErrorString(err)); + } + + return {indexes,w1,w2,w3}; + +} \ No newline at end of file diff --git a/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.h b/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.h new file mode 100644 index 00000000..f9eea913 --- /dev/null +++ b/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +#include "utils/pytorch3d_cutils.h" + + +//std::vector closestPointonSurface_cuda_forward( + // at::Tensor points, + // at::Tensor f1, + // at::Tensor f2, + // at::Tensor f3 +//); + +std::vector closestPointonSurface_forward( + at::Tensor points, + at::Tensor f1, + at::Tensor f2, + at::Tensor f3 + ); \ No newline at end of file diff --git a/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface_cpu.cpp b/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface_cpu.cpp new file mode 100644 index 00000000..39cc3158 --- /dev/null +++ b/pytorch3d/csrc/closest_point_on_surface/closest_point_on_surface_cpu.cpp @@ -0,0 +1,27 @@ +#include +#include + +std::vector closestPointonSurface_cuda_forward( + at::Tensor points, + at::Tensor f1, + at::Tensor f2, + at::Tensor f3 +); + +std::vector closestPointonSurface_forward( + at::Tensor points, + at::Tensor f1, + at::Tensor f2, + at::Tensor f3 + ) +{ + return closestPointonSurface_cuda_forward(points,f1,f2,f3); +} + + + +//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +//{ +// m.def("forward", &closestPointonSurface_forward, "forward (CUDA)"); + //m.def("backward", &closest_point_on_surface_backward, "backward (CUDA)"); +//} \ No newline at end of file diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 6a17dbb0..bd5a93b3 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -31,6 +31,7 @@ #include "rasterize_points/rasterize_points.h" #include "sample_farthest_points/sample_farthest_points.h" #include "sample_pdf/sample_pdf.h" +#include "closest_point_on_surface/closest_point_on_surface.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals_forward", &FaceAreasNormalsForward); @@ -42,6 +43,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifdef WITH_CUDA m.def("knn_check_version", &KnnCheckVersion); #endif + + m.def("closest_point_on_surface_forward", &closestPointonSurface_forward); + m.def("knn_points_idx", &KNearestNeighborIdx); m.def("knn_points_backward", &KNearestNeighborBackward); m.def("ball_query", &BallQuery); diff --git a/pytorch3d/loss/__init__.py b/pytorch3d/loss/__init__.py index 2b8d10de..8255d091 100644 --- a/pytorch3d/loss/__init__.py +++ b/pytorch3d/loss/__init__.py @@ -10,6 +10,6 @@ from .mesh_laplacian_smoothing import mesh_laplacian_smoothing from .mesh_normal_consistency import mesh_normal_consistency from .point_mesh_distance import point_mesh_edge_distance, point_mesh_face_distance - +from .dirdist import DirDist_P2P,DirDist_M2P,DirDist_M2M __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/loss/dirdist.py b/pytorch3d/loss/dirdist.py new file mode 100644 index 00000000..115ab088 --- /dev/null +++ b/pytorch3d/loss/dirdist.py @@ -0,0 +1,282 @@ +import torch +from pytorch3d import _C +from pytorch3d.ops import knn_points + +class P2F_dist(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self,points, vertices, faces): + #points: N,3 + #vertices: M,3 + #faces: M,3 + f1_idx,f2_idx,f3_idx=faces[:,0],faces[:,1],faces[:,2] + f1=vertices[f1_idx] + f2=vertices[f2_idx] + f3=vertices[f3_idx] + indexes,w1,w2,w3=_C.closest_point_on_surface_forward(points,f1,f2,f3) + return indexes,w1,w2,w3 + + +def face_area_normals(faces,vs): + #faces [M,3] + #vs [B,N,3] + face_normals=torch.cross(vs[:,faces[:,1],:]-vs[:,faces[:,0],:], + vs[:,faces[:,2],:]-vs[:,faces[:,0],:],dim=2) + face_areas=torch.norm(face_normals,dim=2) + face_normals=face_normals/face_areas[:,:,None] + face_areas=0.5*face_areas + return face_areas,face_normals + +def sampl_surface(faces,vs,count): + #faces: [M,3] + #vs: [N,3] + vs=vs.unsqueeze(0) + bsize,nvs,_=vs.shape + weights,normal=face_area_normals(faces,vs) + weights_sum = torch.sum(weights, dim=1) + dist = torch.distributions.categorical.Categorical(probs=weights / weights_sum[:, None]) + face_index = dist.sample((count,)) + + tri_origins = vs[:, faces[:, 0], :] + tri_vectors = vs[:, faces[:, 1:], :].clone() + tri_vectors -= tri_origins.repeat(1, 1, 2).reshape((bsize, len(faces), 2, 3)) + + # pull the vectors for the faces we are going to sample from + face_index = face_index.transpose(0, 1) + face_index = face_index[:, :, None].expand((bsize, count, 3)) + tri_origins = torch.gather(tri_origins, dim=1, index=face_index) + face_index2 = face_index[:, :, None, :].expand((bsize, count, 2, 3)) + tri_vectors = torch.gather(tri_vectors, dim=1, index=face_index2) + + # randomly generate two 0-1 scalar components to multiply edge vectors by + random_lengths = torch.rand(count, 2, 1, device=vs.device, dtype=tri_vectors.dtype) + + # points will be distributed on a quadrilateral if we use 2x [0-1] samples + # if the two scalar components sum less than 1.0 the point will be + # inside the triangle, so we find vectors longer than 1.0 and + # transform them to be inside the triangle + random_test = random_lengths.sum(dim=1).reshape(-1) > 1.0 + random_lengths[random_test] -= 1.0 + random_lengths = torch.abs(random_lengths) + + # multiply triangle edge vectors by the random lengths and sum + sample_vector = (tri_vectors * random_lengths[None, :]).sum(dim=2) + + # finally, offset by the origin to generate + # (n,3) points in space on the triangle + samples = sample_vector + tri_origins + + normals = torch.gather(normal, dim=1, index=face_index) + + return samples.squeeze(0), normals.squeeze(0) + + +def get_face_center(vertices,faces): + return (vertices[faces[:,0]]+vertices[faces[:,1]]+vertices[faces[:,2]])/3 + + +class DirDist_M2M(torch.nn.Module): + def __init__(self,num_query=20000,std=0.05): + super().__init__() + self.num_query=num_query + self.std=std + + + def forward(self,src_v,src_f,tgt_v,tgt_f): + #src_v, tgt_v [N,3] + #src_f, tgt_g [M,3], long + + ''' + Note: You could also choose 'pytorch3d.structures.Meshes' to represent the two meshes + src_v=src_mesh.verts_packed() + src_f=src_mesh.faces_packed() + + tgt_v=tgt_mesh.verts_packed() + tgt_f=tgt_mesh.faces_packed()''' + + query_points,_=sampl_surface(tgt_f,tgt_v,self.num_query) + noise_offset=torch.randn_like(query_points)*self.std + query_points=query_points+noise_offset + + src_f1=src_v[src_f[:,0]] + src_f2=src_v[src_f[:,1]] + src_f3=src_v[src_f[:,2]] + + tgt_f1=tgt_v[tgt_f[:,0]] + tgt_f2=tgt_v[tgt_f[:,1]] + tgt_f3=tgt_v[tgt_f[:,2]] + + src_center=(src_f1+src_f2+src_f3)/3 + tgt_center=(tgt_f1+tgt_f2+tgt_f3)/3 + + query_points=torch.cat([query_points.detach(),src_center.detach()],dim=0) + + src_indexes,src_w1,src_w2,src_w3=_C.closest_point_on_surface_forward(query_points,src_f1,src_f2,src_f3) + tgt_indexes,tgt_w1,tgt_w2,tgt_w3=_C.closest_point_on_surface_forward(query_points,tgt_f1,tgt_f2,tgt_f3) + + sel_src_f1=src_f1[src_indexes.long()] + sel_src_f2=src_f2[src_indexes.long()] + sel_src_f3=src_f3[src_indexes.long()] + + sel_tgt_f1=tgt_f1[tgt_indexes.long()] + sel_tgt_f2=tgt_f2[tgt_indexes.long()] + sel_tgt_f3=tgt_f3[tgt_indexes.long()] + + closest_src=src_w1[:,None]*sel_src_f1+src_w2[:,None]*sel_src_f2+src_w3[:,None]*sel_src_f3 + closest_tgt=tgt_w1[:,None]*sel_tgt_f1+tgt_w2[:,None]*sel_tgt_f2+tgt_w3[:,None]*sel_tgt_f3 + + dir_src=query_points-closest_src + udf_src=torch.norm(dir_src+1e-10,dim=-1,keepdim=True) + geo_src=torch.cat([dir_src,udf_src],dim=1) + + + dir_tgt=query_points-closest_tgt + udf_tgt=torch.norm(dir_tgt+1e-10,dim=-1,keepdim=True) + geo_tgt=torch.cat([dir_tgt,udf_tgt],dim=1) + + return torch.mean(torch.abs(geo_src-geo_tgt))*4 + + +class DirDist_P2P(torch.nn.Module): + def __init__(self,up_ratio=10,K=5,std=0.05,weighted_query=True,beta=3): + super().__init__() + self.K=K + self.up_ratio=up_ratio + self.std=std + self.weighted_query=weighted_query + self.beta=beta + + def cal_udf_weights(self,x,query): + #x: (B,N,3) + #query=self.grid_flatten.to(x).unsqueeze(0).repeat(x.size(0),1,1) + + dists,idx,knn_pc=knn_points(query,x,K=self.K,return_nn=True,return_sorted=True) #(B,N,K) (B,N,K) (B,N,K,3) + + dir=query.unsqueeze(2)-knn_pc #(B,N,K,3) + + #weights=torch.softmax(-dists.sqrt(),dim=2) #(B,N,K) weight more, dist small + #weights=torch.softmax(-dists,dim=2) #(B,N,K) weight more, dist small + #weights=torch.softmax(-dists/torch.min(dists,dim=2,keepdim=True)[0],dim=2) #(B,N,K) weight more, dist small + + norm = torch.sum(1.0 / (dists + 1e-8), dim = 2, keepdim = True) + weights = (1.0 / (dists.detach() + 1e-8)) / norm.detach() + + + #print(weights) + #assert False + + #udf=torch.sum((dists+1e-10).sqrt()*weights,dim=2) #(B,N) + #udf=torch.sum(dists*weights,dim=2) #(B,N) + + udf_grad=torch.sum(dir*weights.unsqueeze(-1),dim=2) #(B,N,3) + udf=torch.norm(udf_grad+1e-10,dim=-1) + + return udf,udf_grad,weights + + def cal_udf(self,x,weights,query): + #query=self.grid_flatten.to(x).unsqueeze(0).repeat(x.size(0),1,1) + + dists,idx,knn_pc=knn_points(query,x,K=self.K,return_nn=True,return_sorted=True) #(B,N,K) (B,N,K) (B,N,K,3) + dir=query.unsqueeze(2)-knn_pc #(B,N,K,3) + #udf=torch.sum((dists+1e-10).sqrt()*weights,dim=2) #(B,N) + #udf=torch.sum(dists*weights,dim=2) #(B,N) + + udf_grad=torch.sum(dir*weights.unsqueeze(-1),dim=2) #(B,N,3) + udf=torch.norm(udf_grad+1e-10,dim=-1) + return udf,udf_grad + + def forward(self,src,tgt): + #src: target (B,N,3) + #tgt: source (B,N,3) + + with torch.no_grad(): + + std=self.std + noise_offset=torch.randn(tgt.size(0),tgt.size(1),self.up_ratio,3).to(tgt).float() * std + + + query=tgt.unsqueeze(2)+noise_offset + query=query.reshape(tgt.size(0),-1,3).detach() + + + query=torch.cat((query,src.detach()),dim=1) + + udf_tgt,udf_grad_tgt,_=self.cal_udf_weights(tgt,query) + udf_src,udf_grad_src,_=self.cal_udf_weights(src,query) + + + udf_error=torch.abs(udf_tgt-udf_src) #(B,M) + + udf_grad_error=torch.sum(torch.abs(udf_grad_src-udf_grad_tgt),axis=-1) #(B,M) + #udf_grad_loss=torch.mean(torch.square(udf_grad_src-udf_grad_tgt)) + + if self.weighted_query: + + with torch.no_grad(): + query_weights=torch.exp(-udf_error.detach()*self.beta)*torch.exp(-udf_grad_error.detach()*self.beta) + return torch.sum((udf_error+udf_grad_error)*query_weights.detach())/query.size(0)/query.size(1) + + else: + query_weights=1 + return torch.sum((udf_error+udf_grad_error)*query_weights)/query.size(0)/query.size(1) + + +class DirDist_M2P(torch.nn.Module): + def __init__(self,up_ratio=3,beta=0,K=5,std=0.05): + super().__init__() + self.up_ratio=up_ratio + self.beta=beta + self.K=K + self.std=std + + def forward(self,src_v,src_f,tgt_points): + #src_v [N,3] + #src_f [F,3] long + #tgt_points [M,3] + + src_f1=src_v[src_f[:,0]] + src_f2=src_v[src_f[:,1]] + src_f3=src_v[src_f[:,2]] + + src_center=(src_f1+src_f2+src_f3)/3 + + query_points=tgt_points.unsqueeze(1)+self.std*torch.randn(tgt_points.size(0),self.up_ratio,tgt_points.size(1)).to(tgt_points) + query_points=query_points.reshape(-1,3) + + query_points=torch.cat([query_points.detach(),src_center.detach()],dim=0) + + src_indexes,src_w1,src_w2,src_w3=_C.closest_point_on_surface_forward(query_points,src_f1,src_f2,src_f3) + + sel_src_f1=src_f1[src_indexes.long()] + sel_src_f2=src_f2[src_indexes.long()] + sel_src_f3=src_f3[src_indexes.long()] + + closest_src=src_w1[:,None]*sel_src_f1+src_w2[:,None]*sel_src_f2+src_w3[:,None]*sel_src_f3 + + dir_src=query_points-closest_src + udf_src=torch.norm(dir_src+1e-10,dim=-1,keepdim=True) + geo_src=torch.cat([dir_src,udf_src],dim=1) + + dists,_,knn_pc=knn_points(query_points.unsqueeze(0),tgt_points.unsqueeze(0),K=self.K,return_nn=True,return_sorted=True) #(B,N,K) (B,N,K) (B,N,K,3) + + dir=query_points.unsqueeze(0).unsqueeze(2)-knn_pc #(B,N,K,3) + + norm = torch.sum(1.0 / (dists + 1e-8), dim = 2, keepdim = True) + weights = (1.0 / (dists.detach() + 1e-8)) / norm.detach() + + + dir_tgt=torch.sum(dir*weights.unsqueeze(-1),dim=2).squeeze(0) #(N,3) + udf_tgt=torch.norm(dir_tgt+1e-10,dim=1) + + geo_tgt=torch.cat([dir_tgt,udf_tgt.unsqueeze(1)],dim=-1) + + errors=torch.sum(torch.abs(geo_src-geo_tgt),dim=-1) + + query_weights=torch.exp(-errors*self.beta).detach() + + return torch.mean(errors*query_weights) + + + + diff --git a/pytorch3d/renderer/opengl/opengl_utils.py b/pytorch3d/renderer/opengl/opengl_utils.py old mode 100755 new mode 100644 diff --git a/scripts/parse_tutorials.py b/scripts/parse_tutorials.py old mode 100755 new mode 100644 diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 diff --git a/tests/benchmarks/bm_main.py b/tests/benchmarks/bm_main.py old mode 100755 new mode 100644 diff --git a/tests/benchmarks/bm_pulsar.py b/tests/benchmarks/bm_pulsar.py old mode 100755 new mode 100644 diff --git a/tests/data/missing_usemtl/cow.mtl b/tests/data/missing_usemtl/cow.mtl old mode 100755 new mode 100644 diff --git a/tests/data/missing_usemtl/cow.obj b/tests/data/missing_usemtl/cow.obj old mode 100755 new mode 100644 diff --git a/tests/data/missing_usemtl/cow_texture.png b/tests/data/missing_usemtl/cow_texture.png old mode 100755 new mode 100644