diff --git a/pytorch3d/ops/mesh_curvature.py b/pytorch3d/ops/mesh_curvature.py new file mode 100644 index 000000000..99285d07d --- /dev/null +++ b/pytorch3d/ops/mesh_curvature.py @@ -0,0 +1,106 @@ +import torch +import pytorch3d +# from pytorch3d.ops import cot_laplacian +# from pytorch3d.structures import Meshes + +def one_hot_sparse(A,num_classes,value=None): + A = A.int() + B = torch.arange(A.shape[0]).to(A.device) + if value==None: + C = torch.ones_like(B) + else: + C = value + return torch.sparse_coo_tensor(torch.stack([B,A]),C,size=(A.shape[0],num_classes)) + +def faces_angle(meshs: pytorch3d.structures.Meshes)->torch.Tensor: + """ + Compute the angle of each face in a mesh + Args: + meshs: Meshes object + Returns: + angles: Tensor of shape (N,3) where N is the number of faces + """ + Face_coord = meshs.verts_packed()[meshs.faces_packed()] + A = Face_coord[:,1,:] - Face_coord[:,0,:] + B = Face_coord[:,2,:] - Face_coord[:,1,:] + C = Face_coord[:,0,:] - Face_coord[:,2,:] + angle_0 = torch.arccos(-torch.sum(A*C,dim=1)/torch.norm(A,dim=1)/torch.norm(C,dim=1)) + angle_1 = torch.arccos(-torch.sum(A*B,dim=1)/torch.norm(A,dim=1)/torch.norm(B,dim=1)) + angle_2 = torch.arccos(-torch.sum(B*C,dim=1)/torch.norm(B,dim=1)/torch.norm(C,dim=1)) + angles = torch.stack([angle_0,angle_1,angle_2],dim=1) + return angles + +def dual_area_weights_on_faces(Surfaces: pytorch3d.structures.Meshes)->torch.Tensor: + """ + Compute the dual area weights of 3 vertices of each triangles in a mesh + Args: + Surfaces: Meshes object + Returns: + dual_area_weight: Tensor of shape (N,3) where N is the number of triangles + the dual area of a vertices in a triangles is defined as the area of the sub-quadrilateral divided by three perpendicular bisectors + """ + angles = faces_angle(Surfaces) + sin2angle = torch.sin(2*angles) + dual_area_weight = torch.ones_like(Surfaces.faces_packed())*(torch.sum(sin2angle,dim=1).view(-1,1).repeat(1,3)) + for i in range(3): + j,k = (i+1)%3, (i+2)%3 + dual_area_weight[:,i] = 0.5*(sin2angle[:,j]+sin2angle[:,k])/dual_area_weight[:,i] + return dual_area_weight + + +def Dual_area_for_vertices(Surfaces: pytorch3d.structures.Meshes)->torch.Tensor: + """ + Compute the dual area of each vertices in a mesh + Args: + Surfaces: Meshes object + Returns: + dual_area_vertex: Tensor of shape (N,1) where N is the number of vertices + the dual area of a vertices is defined as the sum of the dual area of the triangles that contains this vertices + """ + + dual_area_weight = dual_area_weights_on_faces(Surfaces) + dual_area_faces = Surfaces.faces_areas_packed().view(-1,1).repeat(1,3)*dual_area_weight + face_vertices_to_idx = one_hot_sparse(Surfaces.faces_packed().view(-1),num_classes=Surfaces.num_verts_per_mesh().sum()) + dual_area_vertex = torch.sparse.mm(face_vertices_to_idx.float().T,dual_area_faces.view(-1,1)).T + return dual_area_vertex + + +def Gaussian_curvature(Surfaces: pytorch3d.structures.Meshes,return_topology=False)->torch.Tensor: + """ + Compute the gaussian curvature of each vertices in a mesh by local Gauss-Bonnet theorem + Args: + Surfaces: Meshes object + return_topology: bool, if True, return the Euler characteristic and genus of the mesh + Returns: + gaussian_curvature: Tensor of shape (N,1) where N is the number of vertices + the gaussian curvature of a vertices is defined as the sum of the angles of the triangles that contains this vertices minus 2*pi and divided by the dual area of this vertices + """ + + face_vertices_to_idx = one_hot_sparse(Surfaces.faces_packed().view(-1),num_classes=Surfaces.num_verts_per_mesh().sum()) + vertices_to_meshid = one_hot_sparse(Surfaces.verts_packed_to_mesh_idx(),num_classes=Surfaces.num_verts_per_mesh().shape[0]) + sum_angle_for_vertices = torch.sparse.mm(face_vertices_to_idx.float().T,faces_angle(Surfaces).view(-1,1)).T + # Euler_chara = torch.sparse.mm(vertices_to_meshid.float().T,(2*torch.pi - sum_angle_for_vertices).T).T/torch.pi/2 + # Euler_chara = torch.round(Euler_chara) + # print('Euler_characteristic:',Euler_chara) + # Genus = (2-Euler_chara)/2 + #print('Genus:',Genus) + gaussian_curvature = (2*torch.pi - sum_angle_for_vertices)/Dual_area_for_vertices(Surfaces) + if return_topology: + Euler_chara = torch.sparse.mm(vertices_to_meshid.float().T,(2*torch.pi - sum_angle_for_vertices).T).T/torch.pi/2 + Euler_chara = torch.round(Euler_chara) + return gaussian_curvature, Euler_chara, Genus + return gaussian_curvature + +def Average_from_verts_to_face(Surfaces: pytorch3d.structures.Meshes, vect_verts: torch.Tensor)->torch.Tensor: + """ + Compute the average of feature vectors defined on vertices to faces by dual area weights + Args: + Surfaces: Meshes object + vect_verts: Tensor of shape (N,C) where N is the number of vertices, C is the number of feature channels + Returns: + vect_faces: Tensor of shape (F,C) where F is the number of faces + """ + assert vect_verts.shape[0] == Surfaces.verts_packed().shape[0] + dual_weight = dual_area_weights_on_faces(Surfaces).view(-1) + wg = one_hot_sparse(Surfaces.faces_packed().view(-1),num_classes=Surfaces.num_verts_per_mesh().sum(),value=dual_weight).float() + return torch.sparse.mm(wg,vect_verts).view(-1,3).sum(dim=1) diff --git a/pytorch3d/ops/pcl_curvature.py b/pytorch3d/ops/pcl_curvature.py new file mode 100644 index 000000000..798b849b4 --- /dev/null +++ b/pytorch3d/ops/pcl_curvature.py @@ -0,0 +1,113 @@ +import torch +import pytorch3d +from pytorch3d.ops import knn_points,knn_gather +# from pytorch3d.structures import Meshes + +def Weingarten_maps(pointscloud:torch.Tensor, k=50)->torch.Tensor: + """ + Compute the Weingarten maps of a point cloud + Args: + pointscloud: Tensor of shape (B,N,3) where N is the number of points in each batch + k: int, number of neighbors + Returns: + Weingarten_fields: Tensor of shape (B,N,2,2) where N is the number of points + normals_field: Tensor of shape (B,N,3) where N is the number of points + tangent1_field: Tensor of shape (B,N,3) + tangent2_field: Tensor of shape (B,N,3) + """ + + + pointscloud_shape = pointscloud.shape + batch_size = pointscloud_shape[0] + num_points = torch.LongTensor([pointscloud_shape[1]]*batch_size) + + # undo global mean for stability + pointscloud_centered = pointscloud - pointscloud.mean(-2).view(batch_size,1,3) + knn_info = knn_points(pointscloud_centered,pointscloud_centered,lengths1=num_points,lengths2=num_points,K=k,return_nn=True) + + # compute knn & covariance matrix + knn_point_centered = knn_info.knn - knn_info.knn.mean(-2).view(batch_size,-1,1,3) + covs_field = torch.matmul(knn_point_centered.transpose(-1,-2),knn_point_centered) /(knn_point_centered.shape[-1]-1) + frames_field = torch.linalg.eigh(covs_field).eigenvectors + + normals_field = frames_field[:,:,:,0] + tangent1_field = frames_field[:,:,:,1] + tangent2_field = frames_field[:,:,:,2] + + + local_pt_difference = knn_info.knn[:,:,1:k,:]- pointscloud_centered[:,:,None,:] # B x N x K x 3 + + # Disambiguates normals by checking the sign of the projection of the + proj = (normals_field[:, :, None] * local_pt_difference).sum(-1) + # check how many projections are positive + n_pos = (proj > 0).type_as(knn_info.knn).sum(-1, keepdim=True) + # flip the principal directions where number of positive correlations for + flip = (n_pos < (0.5 * (k-1))).type_as(knn_info.knn) + + normals_field = (1.0 - 2.0 * flip) * normals_field + + # local normals difference + local_normals_difference = knn_gather(normals_field,knn_info.idx,lengths=num_points)[:,:,1:k,:] - normals_field[:,:,None,:] + + # project the difference onto the tangent plane, getting the differential of the gaussian map + local_dpt_tangent1 = (local_pt_difference * tangent1_field[:,:,None,:]).sum(-1,keepdim=True) + local_dpt_tangent2 = (local_pt_difference * tangent2_field[:,:,None,:]).sum(-1,keepdim=True) + local_dpt_tangent = torch.cat((local_dpt_tangent1,local_dpt_tangent2),dim=-1) + local_dnormals_tangent1 = (local_normals_difference * tangent1_field[:,:,None,:]).sum(-1,keepdim=True) + local_dnormals_tangent2 = (local_normals_difference * tangent2_field[:,:,None,:]).sum(-1,keepdim=True) + local_dnormals_tangent = torch.cat((local_dnormals_tangent1,local_dnormals_tangent2),dim=-1) + + + # estimate the weingarten map by solving a least squares problem: W = Dn^T Dp (Dp^T Dp)^-1 + XXT = torch.matmul(local_dpt_tangent.transpose(-1,-2),local_dpt_tangent) + YXT = torch.matmul(local_dnormals_tangent.transpose(-1,-2),local_dpt_tangent) + XYT = torch.matmul(local_dpt_tangent.transpose(-1,-2),local_dnormals_tangent) + #Weingarten_fields_0 = torch.matmul(YXT,torch.inverse(XXT+1e-8*torch.eye(2).type_as(XXT))) ## the unsymetric version + + + # solve the sylvester equation to get the shape operator (symmetric version) + S = YXT + XYT + + XXT_eig = torch.linalg.eigh(XXT) + Q = XXT_eig.eigenvectors + #D = torch.diag_embed(XXT_eig.eigenvalues) + # XX^T = Q^T D Q + Q_TSQ = torch.matmul(Q.transpose(-1,-2),torch.matmul(S,Q)) + + a = XXT_eig.eigenvalues[:,:,0] + b = XXT_eig.eigenvalues[:,:,1] + a_b = a+b + a2_a_b = torch.stack((2*a,a_b),dim=-1).view(batch_size,-1,1,2) + a_b_b2 = torch.stack((a_b,2*b),dim=-1).view(batch_size,-1,1,2) + c = torch.stack((a2_a_b,a_b_b2),dim=-2).view(batch_size,-1,2,2) + + E = (1/c+1e-6) * Q_TSQ + Weingarten_fields = torch.matmul(Q,torch.matmul(E,Q.transpose(-1,-2))) + + + return Weingarten_fields, normals_field, tangent1_field, tangent2_field + +def Curvature_pcl(pointscloud, k=50, return_princpals=False): + """ + Compute the gaussian curvature of point clouds + pointscloud: B x N x 3 + k: int, number of neighbors + return_princpals: bool,if True, return principal curvature and principal directions + if False, return gaussian curvature, mean curvature only + """ + + pointscloud_shape = pointscloud.shape + batch_size = pointscloud_shape[0] + num_points = torch.LongTensor([pointscloud_shape[1]]*batch_size) + Weingarten_fields, normals_field, tangent1_field, tangent2_field = Weingarten_maps(pointscloud, k=k) + tangent_space = tangent_space = torch.cat((tangent1_field.view(batch_size,-1,1,3),tangent2_field.view(batch_size,-1,1,3)),dim=-2) + if return_princpals: + principal_curvature , principal_direction_local = torch.linalg.eigh(Weingarten_fields) + principal_direction_global = torch.matmul(principal_direction_local.transpose(-1,-2),tangent_space) + return principal_curvature, principal_direction_global, normals_field + else: + gaussian_curvature_pcl = torch.det(Weingarten_fields) + mean_curvature_pcl = Weingarten_fields.diagonal(offset=0, dim1=-1, dim2=-2).mean(-1) + return gaussian_curvature_pcl, mean_curvature_pcl + +