-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5d7a46f
commit e4f9bf8
Showing
15 changed files
with
287 additions
and
410 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,8 @@ venv | |
|
||
**/__pycache__ | ||
*.png | ||
*.coverage | ||
*.egg-info | ||
**/.ipynb_checkpoints | ||
*.sif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
plt.rcParams['text.usetex'] = True | ||
plt.rc('text.latex', preamble=r'\usepackage{bm}') | ||
|
||
|
||
def x_param_curve(t): | ||
return np.stack((t, np.sqrt(1-t**3)), axis=1) | ||
|
||
|
||
def x_param_derivative(t): | ||
return np.stack((np.ones_like(t), -3/2*t**2/np.sqrt(1-t**3)), axis=1) | ||
|
||
|
||
def manifold_density_reweight_desired(xy): | ||
return ((3*xy[:, 0]**2)**2 + (xy[:, 1])**2)**(-1/2) | ||
|
||
|
||
def manifold_density_reweight_actual_x_param(xy): | ||
return (np.ones_like(xy[:, 0])**2 + (3/2*xy[:, 0]**2/np.sqrt(1-xy[:, 0]**3))**2)**(-1/2) | ||
|
||
|
||
def manifold_density_reweight_actual_y_param(xy): | ||
return ((2/3 * xy[:, 1]/((1-xy[:, 1]**2)**(2))**(1/3))**2 + np.ones_like(xy[:, 1])**2)**(-1/2) | ||
|
||
|
||
def length_param_curve(min_val=-1.5, max_val=1, n=100): | ||
t_actual = np.linspace(min_val, max_val, n + 2)[1:-1] | ||
norm_derivative_t_actual = np.linalg.norm(x_param_derivative(t_actual), axis=1) | ||
s_actual = np.cumsum(norm_derivative_t_actual)/np.sum(norm_derivative_t_actual) | ||
s_wanted = np.linspace(0, 1, n) | ||
t_wanted = np.interp(s_wanted, s_actual, t_actual) | ||
xy_wanted = x_param_curve(t_wanted) | ||
return xy_wanted | ||
|
||
|
||
xy = length_param_curve(min_val=-2, n=50000) | ||
x_max = np.max(xy[:, 0]) | ||
x_min = np.min(-xy[:, 0]) | ||
y_max = np.max(xy[:, 1]) | ||
y_min = np.min(-xy[:, 1]) | ||
density_funcs = { | ||
r"just manifold": None, | ||
r"manifold prior density": lambda a: np.exp(-1/2*np.linalg.norm(a, axis=1)**2), | ||
r"manifold actual density": lambda a: np.exp(-1/2*np.linalg.norm(a, axis=1)**2) * manifold_density_reweight_desired(a), | ||
r"manifold proj theta_1": lambda a: np.exp(-1/2*a[:, 0]**2) * manifold_density_reweight_actual_x_param(a), | ||
r"manifold proj theta_2": lambda a: np.exp(-1/2*a[:, 1]**2) * manifold_density_reweight_actual_y_param(a), | ||
r"manifold proj theta_12": lambda a: np.exp(-1/2*a[:, 1]**2) * (manifold_density_reweight_actual_x_param(a) + manifold_density_reweight_actual_y_param(a)) | ||
} | ||
general_max = np.maximum(y_max, x_max) | ||
general_min = np.minimum(y_min, x_min) | ||
for name, density_func in density_funcs.items(): | ||
if density_func is None: | ||
plt.scatter(xy[:,0], xy[:,1], s=1, color=plt.get_cmap('Blues')(0.99)) | ||
plt.scatter(xy[:,0], -xy[:,1], s=1, color=plt.get_cmap('Blues')(0.99)) | ||
else: | ||
density = density_func(xy) | ||
density = (density - density.min())/(density.max() - density.min()) | ||
density = density*(2-density) | ||
plt.scatter(xy[:, 0], xy[:, 1], s=1, c=density, cmap='Blues') | ||
plt.scatter(xy[:, 0], -xy[:, 1], s=1, c=density, cmap='Blues') | ||
plt.xlim(general_min, general_max) | ||
plt.ylim(general_min, general_max) | ||
plt.xticks(fontsize=18) | ||
plt.yticks(fontsize=18) | ||
plt.xlabel(r'$\theta_1$', loc='center', fontsize=18) | ||
plt.ylabel(r'$\theta_2$', loc='center', fontsize=18) | ||
plt.gca().set_aspect("equal", adjustable='box') | ||
plt.tight_layout() | ||
plt.savefig(f"{name}.png") | ||
plt.show() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import numpy as np | ||
|
||
from matplotlib import pyplot as plt | ||
plt.rcParams['text.usetex'] = True | ||
|
||
|
||
def f_(x, y): | ||
return x**3 + y**2 | ||
|
||
|
||
def manifold_density_reweight_desired_(x, y): | ||
return ((3*x**2)**2 + (y)**2)**(-1/2) | ||
|
||
|
||
x_max = 2 | ||
x_min = -2 | ||
y_max = np.sqrt(1-(-2)**3) | ||
y_min = -np.sqrt(1-(-2)**3) | ||
|
||
general_max = np.maximum(y_max, x_max) | ||
general_min = np.minimum(y_min, x_min) | ||
|
||
n = 500 | ||
x = np.linspace(general_min, general_max, n) | ||
y = np.linspace(general_min, general_max, n) | ||
X, Y = np.meshgrid(x, y) | ||
Z = manifold_density_reweight_desired_(X, Y) | ||
Z = (Z - Z.min()) / (Z.max() - Z.min()) | ||
Z = Z*(2-Z) | ||
|
||
mask = f_(X, Y) | ||
|
||
for Zm, name in [(Z, 'prior'), (np.ma.masked_where((1.3 < mask) | (mask < 0.7), Z), 'posterior_trivial')]: | ||
plt.pcolormesh(X, Y, Zm, shading='gouraud', color='blue', cmap='Blues') | ||
plt.xlim(general_min, general_max) | ||
plt.ylim(general_min, general_max) | ||
plt.xticks(fontsize=18) | ||
plt.yticks(fontsize=18) | ||
plt.xlabel(r'$\theta_1$', loc='center', fontsize=18) | ||
plt.ylabel(r'$\theta_2$', loc='center', fontsize=18) | ||
plt.gca().set_aspect("equal", adjustable='box') | ||
plt.tight_layout() | ||
plt.savefig(f'{name}.png') | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.