import torch
import numpy as np
import matplotlib.pyplot as plt
# Half-widths of the square plotting window [-xlim, xlim] x [-ylim, ylim].
xlim = ylim = 4
def density(z):
    """Return the ring-shaped potential energy U(z) for each point in ``z``.

    U(z) = 0.5 * ((||z|| - 2) / 0.4) ** 2
           - log(exp(-0.5 * ((z_0 - 2) / 0.6) ** 2) + exp(-0.5 * ((z_0 + 2) / 0.6) ** 2))

    Despite the name this is an energy; the unnormalised density is exp(-U(z)).

    The log of the sum of the two Gaussian bumps is computed with
    ``torch.logsumexp``. The naive ``log(exp(a) + exp(b))`` underflows to
    ``log(0) = -inf`` once |z_0| - 2 exceeds roughly 9 (float32) or 23
    (float64), which would make U infinite.

    Args:
        z: Tensor of points; the norm is taken over dim 1 and column 0 is the
            first coordinate (expected shape ``(N, 2)``).

    Returns:
        Tensor of shape ``(N,)`` holding the energy of each point.
    """
    radial = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
    log_bumps = torch.stack(
        (
            -0.5 * ((z[:, 0] - 2) / 0.6) ** 2,
            -0.5 * ((z[:, 0] + 2) / 0.6) ** 2,
        )
    )
    return radial - torch.logsumexp(log_bumps, dim=0)
# Evaluate the unnormalised density exp(-U(z)) on a 300 x 300 grid covering
# [-xlim, xlim] x [-ylim, ylim].
x = np.linspace(-xlim, xlim, 300)
y = np.linspace(-ylim, ylim, 300)
X, Y = np.meshgrid(x, y)
shape = X.shape
X_flatten, Y_flatten = np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))
# One (x, y) row per grid node -> tensor of shape (300 * 300, 2), float64.
Z = torch.from_numpy(np.concatenate([X_flatten, Y_flatten], 1))

# Reuse `density` instead of re-deriving the same energy inline.
u = density(Z)
U = torch.exp(-u).reshape(shape)
print(U.size())
# Python plotting tests
def U_1(z):
    """Return the ring-shaped potential energy U_1(z) for each point in ``z``.

    U_1(z) = 0.5 * ((||z|| - 2) / 0.4) ** 2
             - log(exp(-0.5 * ((z_0 - 2) / 0.6) ** 2) + exp(-0.5 * ((z_0 + 2) / 0.6) ** 2))

    The unnormalised density is exp(-U_1(z)). The two-bump log-sum is computed
    with ``torch.logsumexp``. The naive form underflows to ``-inf`` far from
    z_0 = +/-2, which would make the energy infinite.

    Args:
        z: Tensor of points; the norm is taken over dim 1 (expected shape ``(N, 2)``).

    Returns:
        Tensor of shape ``(N,)`` with one energy per point.
    """
    radial = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
    log_bumps = torch.stack(
        (
            -0.5 * ((z[:, 0] - 2) / 0.6) ** 2,
            -0.5 * ((z[:, 0] + 2) / 0.6) ** 2,
        )
    )
    return radial - torch.logsumexp(log_bumps, dim=0)
def plot_density(density, xlim=4, ylim=4, ax=None, cmap="Purples", n_points=300):
    """Plot the unnormalised density ``exp(-density(z))`` on a regular 2-D grid.

    Args:
        density: Callable mapping an ``(N, 2)`` float64 tensor of points to a
            tensor of ``N`` energies; the plotted value is ``exp(-energy)``.
        xlim: Half-width of the plotted x range ``[-xlim, xlim]``.
        ylim: Half-height of the plotted y range ``[-ylim, ylim]``. Previously
            ignored, with ``xlim`` used for both axes.
        ax: Matplotlib axes to draw into; a new 7x7-inch figure is created when None.
        cmap: Colormap name passed to ``pcolormesh``.
        n_points: Number of grid nodes along each axis.

    Returns:
        The axes that were drawn into.
    """
    xs = np.linspace(-xlim, xlim, n_points)
    ys = np.linspace(-ylim, ylim, n_points)
    X, Y = np.meshgrid(xs, ys)
    # One (x, y) row per grid node, in the same row-major order as X / Y.
    grid = torch.from_numpy(np.stack([X.ravel(), Y.ravel()], axis=1))

    # no_grad + numpy conversion so that energies produced by trainable
    # modules (outputs requiring grad) can still be handed to matplotlib.
    with torch.no_grad():
        U = torch.exp(-density(grid)).reshape(X.shape).cpu().numpy()

    if ax is None:
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)

    ax.set_xlim(-xlim, xlim)
    ax.set_ylim(-ylim, ylim)
    ax.set_aspect(1)

    ax.pcolormesh(X, Y, U, cmap=cmap, rasterized=True)
    # Hide every tick and tick label: the plot is purely a density picture.
    ax.tick_params(
        axis="both",
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False,
    )
    return ax
# Render the U_1 target density with the default window and colormap.
plot_density(density=U_1)
def U_1(z):
    """Return the ring-shaped potential energy U_1(z) for each point in ``z``.

    U_1(z) = 0.5 * ((||z|| - 2) / 0.4) ** 2
             - log(exp(-0.5 * ((z_0 - 2) / 0.6) ** 2) + exp(-0.5 * ((z_0 + 2) / 0.6) ** 2))

    The two-bump log-sum uses ``torch.logsumexp``, so points far from
    z_0 = +/-2 get a large finite energy instead of ``inf`` from ``log(0)``.

    Args:
        z: Tensor of points; the norm is taken over dim 1 (expected shape ``(N, 2)``).

    Returns:
        Tensor of shape ``(N,)`` with one energy per point.
    """
    radial = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
    log_bumps = torch.stack(
        (
            -0.5 * ((z[:, 0] - 2) / 0.6) ** 2,
            -0.5 * ((z[:, 0] + 2) / 0.6) ** 2,
        )
    )
    return radial - torch.logsumexp(log_bumps, dim=0)
# Smoke-test U_1 on a batch of 128 two-dimensional points. U_1 reads z[:, 0]
# and takes the norm over dim 1, so its input should be (N, 2); the previous
# (128, 2, 2) input produced a meaningless (128, 2) result.
test = U_1(torch.randn(128, 2))
print(test.shape)
import torch
def interpolate_tensor(tensor, coords):
    """Bilinearly sample a 2-D tensor at continuous (x, y) positions.

    Args:
        tensor: 2-D tensor of shape ``(height, width)``.
        coords: ``(N, 2)`` tensor; column 0 is x (column index), column 1 is
            y (row index), both in pixel units.

    Returns:
        Tensor of shape ``(N,)`` with one interpolated value per coordinate.
        Coordinates outside the tensor are clamped to the nearest edge.
    """
    height, width = tensor.shape

    # Clamp to the valid pixel range first. Without this, negative coordinates
    # floor to negative indices that silently wrap to the opposite edge, and
    # coordinates past the last pixel get extrapolation weights outside [0, 1].
    x = coords[:, 0].clamp(0, width - 1)
    y = coords[:, 1].clamp(0, height - 1)

    # Indices of the four surrounding pixels; the +1 neighbour is clamped so a
    # coordinate exactly on the last row/column stays in bounds (weight is 0).
    x1 = x.floor().long()
    y1 = y.floor().long()
    x2 = (x1 + 1).clamp(max=width - 1)
    y2 = (y1 + 1).clamp(max=height - 1)

    # Fractional distances from the lower neighbour give the blend weights.
    weight_x2 = x - x1.to(x.dtype)
    weight_x1 = 1 - weight_x2
    weight_y2 = y - y1.to(y.dtype)
    weight_y1 = 1 - weight_y2

    return (
        tensor[y1, x1] * weight_x1 * weight_y1
        + tensor[y1, x2] * weight_x2 * weight_y1
        + tensor[y2, x1] * weight_x1 * weight_y2
        + tensor[y2, x2] * weight_x2 * weight_y2
    )
# Create a sample 3x3 PyTorch tensor.
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# Coordinates as (x, y) pairs: column 0 indexes columns, column 1 indexes rows.
coords = torch.tensor([[1.5, 2.2],
                       [0.2, 1.7]])

# Interpolate at the continuous coordinates.
interpolated_values = interpolate_tensor(tensor, coords)
print(interpolated_values)
# NOTE(review): `torch_posterior` is not defined in this part of the file;
# presumably it is created by an earlier cell.
unique_values, counts = torch.unique(torch_posterior, return_counts=True)

# Print the unique values and their counts.
for value, count in zip(unique_values, counts):
    print(f"Value: {value}, Count: {count}")

fig, ax = plt.subplots()

# Plot the tensor values.
image = ax.imshow(torch_posterior.cpu(), cmap='viridis')

# Attach the colorbar to this image explicitly: `plt.colorbar()` with no
# mappable relies on pyplot's "current image", which `ax.imshow` does not set,
# so it raises RuntimeError.
fig.colorbar(image, ax=ax)

plt.show()
# Two batches of 2-D Gaussian draws: N(0, 1) and N(150, 30^2) per coordinate.
# NOTE(review): `batch_size` is not defined in this part of the file; presumably
# it comes from an earlier cell.
batch = torch.zeros(size=(batch_size, 2)).normal_(mean=0, std=1)
batch2 = torch.zeros(size=(batch_size, 2)).normal_(mean=150, std=30)
def two_moons_density(z):
    """Return exp(-0.2 * r) * cos(4 * pi * r), where r is each point's radius.

    NOTE(review): despite the name, this is not the usual two-moons target. It
    is radially symmetric (a damped cosine in r, i.e. concentric rings), and it
    is negative wherever cos(4 * pi * r) < 0, so it is not a valid density as
    written. Confirm the intended function.

    Args:
        z: Tensor of points; columns 0 and 1 are the x and y coordinates.

    Returns:
        Tensor with one value per row of ``z``.
    """
    x = z[:, 0]
    y = z[:, 1]
    radius = torch.sqrt(x**2 + y**2)
    return torch.exp(-0.2 * radius) * torch.cos(4 * np.pi * radius)
def ring_density(z):
    """Return a ring-shaped potential energy (radius 4) for each point in ``z``.

    U(z) = 0.5 * ((||z|| - 4) / 0.4) ** 2
           - log(exp(-0.5 * ((z_0 - 2) / 0.8) ** 2)
                 + exp(-0.5 * ((z_0 + 2) / 0.8) ** 2) + 1e-6)

    Despite the name, the return value is an energy (low where the density is
    high), not a density.

    Args:
        z: Tensor of points; the norm is taken over dim 1 and column 0 is the
            first coordinate (expected shape ``(N, 2)``).

    Returns:
        Tensor of shape ``(N,)`` with one energy per point.
    """
    exp1 = torch.exp(-0.5 * ((z[:, 0] - 2) / 0.8) ** 2)
    exp2 = torch.exp(-0.5 * ((z[:, 0] + 2) / 0.8) ** 2)
    u = 0.5 * ((torch.norm(z, 2, dim=1) - 4) / 0.4) ** 2
    # The 1e-6 keeps the log finite where both bumps underflow to 0, at the
    # cost of a small bias in the energy.
    u = u - torch.log(exp1 + exp2 + 1e-6)
    return u
# Range checks of the test functions on `batch`.
print(two_moons_density(batch).min())
ring_values = ring_density(batch)
print(ring_values.max())
print(ring_values.min())
def interpolate_tensor(tensor, z, scale=150.0, shift=150.0):
    """Bilinearly sample ``tensor`` at points given in standardised coordinates.

    Each row of ``z`` is mapped to pixel units as ``x = z[:, 0] * scale + shift``
    (column) and ``y = z[:, 1] * scale + shift`` (row), clamped to the image.
    The four surrounding pixels are then blended bilinearly.

    Args:
        tensor: Image of shape ``(H, W)`` or ``(H, W, *channels)``.
        z: ``(N, 2)`` tensor of points; column 0 maps to x (width axis),
            column 1 to y (height axis).
        scale: Multiplier from ``z`` to pixel units.
        shift: Offset added after scaling. With the defaults (150, 150),
            ``z = 0`` lands on pixel 150; presumably the centre of a roughly
            300-pixel image (confirm).

    Returns:
        Tensor of shape ``(N,)`` or ``(N, *channels)``.
    """
    height, width = tensor.shape[:2]

    # Keep the mapped coordinates as floats. Casting to integers here would
    # discard the fractional parts, zero every interpolation weight, and
    # reduce this to a lookup of tensor[floor(y), floor(x)].
    x = (z[:, 0] * scale + shift).clamp(0, width - 1)
    y = (z[:, 1] * scale + shift).clamp(0, height - 1)

    # Indices of the four surrounding pixels (upper neighbours clamped in-bounds).
    x1 = x.floor().long()
    y1 = y.floor().long()
    x2 = (x1 + 1).clamp(max=width - 1)
    y2 = (y1 + 1).clamp(max=height - 1)

    # Bilinear weights from the fractional offsets.
    weight_x2 = x - x1.to(x.dtype)
    weight_x1 = 1 - weight_x2
    weight_y2 = y - y1.to(y.dtype)
    weight_y1 = 1 - weight_y2

    # Per-sample weights are (N,); add singleton axes so they broadcast over
    # any trailing channel dimensions of the gathered pixels.
    expand = (-1,) + (1,) * (tensor.dim() - 2)
    weight_x1, weight_x2 = weight_x1.reshape(expand), weight_x2.reshape(expand)
    weight_y1, weight_y2 = weight_y1.reshape(expand), weight_y2.reshape(expand)

    return (
        tensor[y1, x1] * weight_x1 * weight_y1
        + tensor[y1, x2] * weight_x2 * weight_y1
        + tensor[y2, x1] * weight_x1 * weight_y2
        + tensor[y2, x2] * weight_x2 * weight_y2
    )
# NOTE(review): `batch_size` and `torch_posterior` are not defined in this part
# of the file; presumably they come from earlier cells.
batch = torch.zeros(size=(batch_size, 2)).normal_(mean=0, std=1)
posterior_at_batch = interpolate_tensor(torch_posterior, batch)

fig, ax = plt.subplots()
# ring_density returns one value per sample (a 1-D tensor), which imshow cannot
# render as an image. Instead, scatter the samples coloured by that value.
ring_values = ring_density(batch).cpu()
points = batch.cpu()
scatter = ax.scatter(points[:, 0], points[:, 1], c=ring_values, cmap='viridis')
# Pass the mappable explicitly; `plt.colorbar()` cannot find one created via `ax.*`.
fig.colorbar(scatter, ax=ax)
plt.show()
from torch.distributions import MultivariateNormal

# NOTE(review): interpolate_tensor treats column 0 of the coordinates as x
# (column index) and column 1 as y (row index), so [[81, 156.5]] samples
# row 156.5, column 81, whereas torch_posterior[81, 157] reads row 81,
# column 157. Also, which interpolate_tensor definition is active depends on
# cell order: the (tensor, z) variant rescales inputs by *150 + 150, so
# pixel-space values like these are pushed out of range and clamped. Confirm
# which ordering and definition were intended.
print(interpolate_tensor(torch_posterior, torch.tensor([[81, 156.5]])))
print(torch_posterior[81, 157])

# Locations of all non-zero entries of the posterior.
indices = torch.nonzero(torch_posterior != 0)
print(len(indices))

# Standard 2-D normal; draw one sample as a sanity check.
test = MultivariateNormal(torch.zeros(2), torch.eye(2))
print(test.sample())
import os
import imageio
def make_gif_from_train_plots(fname: str, png_dir: str = "", duration: float = 0.05) -> None:
    """Assemble saved training-plot PNGs into an animated GIF under ``gifs/``.

    Args:
        fname: Output file name, written as ``gifs/<fname>``.
        png_dir: Directory holding the PNG frames. It is left blank in the
            committed code (the real path was hidden when committing), so
            callers must supply it.
        duration: Per-frame duration passed to ``imageio.mimsave``.

    Raises:
        FileNotFoundError: If ``png_dir`` is empty or does not exist.
    """
    if not png_dir:
        raise FileNotFoundError(
            "png_dir is empty; pass the directory that holds the training-plot PNGs"
        )

    # NOTE(review): the first sorted directory entry is always skipped (even if
    # it is not a PNG), as in the original; confirm this is intentional. The
    # sort is lexicographic, so frames need zero-padded names ("010.png" rather
    # than "10.png") to stay in order.
    file_names = sorted(os.listdir(png_dir))[1:]
    images = [
        imageio.imread(os.path.join(png_dir, file_name))
        for file_name in file_names
        if file_name.endswith(".png")
    ]
    imageio.mimsave(os.path.join("gifs", fname), images, duration=duration)
"32_layer.gif") make_gif_from_train_plots(