import torch
import numpy as np
import matplotlib.pyplot as plt
# Half-extent of the plotting window: the grid built further down spans
# [-xlim, xlim]. Both limits share the same value.
xlim = 4
ylim = xlim
def density(z):
    """Return the potential u(z) of a two-mode ring distribution.

    The value is used elsewhere in this file as ``exp(-density(z))``, i.e.
    it is the negative log of an unnormalised density.

    Args:
        z: Tensor of points. The Euclidean norm is taken over ``dim=1`` and
            ``z[:, 0]`` is read as the first coordinate (the grid built in
            this file has shape ``(N, 2)``).

    Returns:
        Tensor of potentials, one per point.
    """
    first = z[:, 0]
    # Quadratic penalty pulling points towards the circle of radius 2.
    ring_term = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
    # Log-exponents of the two Gaussian bumps centred at x = +2 and x = -2.
    right_mode = -0.5 * ((first - 2) / 0.6) ** 2
    left_mode = -0.5 * ((first + 2) / 0.6) ** 2
    # logaddexp evaluates log(exp(a) + exp(b)) without underflowing to
    # log(0) = -inf when a point is far from both bumps.
    return ring_term - torch.logaddexp(right_mode, left_mode)
# Evaluate exp(-density(z)) on a 300 x 300 grid covering [-xlim, xlim] in
# both directions.
x = np.linspace(-xlim, xlim, 300)
y = x
X, Y = np.meshgrid(x, y)
shape = X.shape
# Flatten each coordinate grid into a column and pair them up as (x, y) rows.
X_flatten = X.reshape(-1, 1)
Y_flatten = Y.reshape(-1, 1)
Z = torch.from_numpy(np.hstack((X_flatten, Y_flatten)))
# Inline copy of the potential that density() computes; `u` is not used by
# the statements below.
u = ((torch.norm(Z, 2, dim=1) - 2) / 0.4) ** 2 * 0.5
u = u - torch.log(
    torch.exp(-0.5 * ((Z[:, 0] - 2) / 0.6) ** 2)
    + torch.exp(-0.5 * ((Z[:, 0] + 2) / 0.6) ** 2)
)
# Unnormalised density, folded back into the shape of the grid.
U = torch.exp(-density(Z)).reshape(shape)
# Notebook-style inspection of the result.
U.size()
U
# Python Plotting Tests
def U_1(z):
    """Return the potential of a ring-shaped distribution with two modes.

    Args:
        z: Tensor of points. The Euclidean norm is taken over ``dim=1`` and
            ``z[:, 0]`` is read as the first coordinate.

    Returns:
        Tensor of potentials; ``exp(-U_1(z))`` is what ``plot_density``
        draws when given this function.
    """
    first = z[:, 0]
    # Quadratic penalty pulling points towards the circle of radius 2.
    ring_term = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
    # Log-exponents of the two Gaussian bumps centred at x = +2 and x = -2.
    right_mode = -0.5 * ((first - 2) / 0.6) ** 2
    left_mode = -0.5 * ((first + 2) / 0.6) ** 2
    # logaddexp avoids the log(0) = -inf that the naive
    # log(exp(a) + exp(b)) produces once both exponentials underflow.
    return ring_term - torch.logaddexp(right_mode, left_mode)
def plot_density(density, xlim=4, ylim=4, ax=None, cmap="Purples"):
    """Draw ``exp(-density(z))`` as a colour mesh over a rectangular window.

    Args:
        density: Callable taking an ``(N, 2)`` tensor of grid points and
            returning one potential value per point.
        xlim: Half-width of the window; x runs over ``[-xlim, xlim]``.
        ylim: Half-height of the window; y runs over ``[-ylim, ylim]``.
        ax: Axes to draw on. A new 7 x 7 inch figure is created when omitted.
        cmap: Colormap name passed to ``pcolormesh``.

    Returns:
        The axes that were drawn on.
    """
    x = np.linspace(-xlim, xlim, 300)
    # The y axis follows `ylim`; previously it silently reused `xlim`, so
    # the `ylim` argument had no effect.
    y = np.linspace(-ylim, ylim, 300)
    X, Y = np.meshgrid(x, y)
    shape = X.shape
    # One (x, y) row per grid node, in the layout `density` expects.
    X_flatten, Y_flatten = X.reshape(-1, 1), Y.reshape(-1, 1)
    Z = torch.from_numpy(np.concatenate([X_flatten, Y_flatten], 1))
    U = torch.exp(-density(Z)).reshape(shape)
    if ax is None:
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
    ax.set_xlim(-xlim, xlim)
    ax.set_ylim(-ylim, ylim)
    ax.set_aspect(1)
    ax.pcolormesh(X, Y, U, cmap=cmap, rasterized=True)
    # Hide every tick mark and tick label so only the density is visible.
    ax.tick_params(
        axis="both",
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False,
    )
    return ax
# Render the potential defined above.
plot_density(U_1)


def U_1(z):
    """Return the potential of a ring-shaped distribution with two modes.

    Args:
        z: Tensor of points. The Euclidean norm is taken over ``dim=1`` and
            ``z[:, 0]`` is read as the first coordinate.

    Returns:
        Tensor of potentials with ``dim=1`` of ``z`` reduced away.
    """
    first = z[:, 0]
    # Quadratic penalty pulling points towards the circle of radius 2.
    ring_term = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
    # Log-exponents of the two Gaussian bumps centred at x = +2 and x = -2.
    right_mode = -0.5 * ((first - 2) / 0.6) ** 2
    left_mode = -0.5 * ((first + 2) / 0.6) ** 2
    # logaddexp avoids the log(0) = -inf that the naive
    # log(exp(a) + exp(b)) produces once both exponentials underflow.
    return ring_term - torch.logaddexp(right_mode, left_mode)
# Probe U_1 with a random 3-D batch. Both the norm over dim=1 and the
# z[:, 0] slice drop axis 1, so the result has shape (128, 2).
test = U_1(torch.randn(size=(128, 2, 2)))
test
# torch.randn(2, 2)
import torch
def interpolate_tensor(tensor, coords):
    """Bilinearly interpolate a 2-D tensor at continuous pixel coordinates.

    Args:
        tensor: 2-D tensor of shape ``(height, width)``.
        coords: Tensor whose rows are ``(x, y)`` pairs; ``x`` indexes columns
            and ``y`` indexes rows of ``tensor``.

    Returns:
        1-D tensor with one interpolated value per row of ``coords``.
        Coordinates outside the grid take the value of the nearest edge.
    """
    height, width = tensor.shape
    # Clamp into the valid pixel range up front. Without the lower bound a
    # negative coordinate floors to a negative index, which wraps around to
    # the opposite side of the tensor instead of sticking to the edge.
    x = coords[:, 0].clamp(0, width - 1)
    y = coords[:, 1].clamp(0, height - 1)
    # Indices of the four surrounding grid nodes; the upper neighbours are
    # clamped so points on the last row/column reuse that row/column.
    x1 = x.floor().long()
    y1 = y.floor().long()
    x2 = (x1 + 1).clamp(max=width - 1)
    y2 = (y1 + 1).clamp(max=height - 1)
    # Fractional offsets inside the cell are the interpolation weights.
    weight_x2 = x - x1.float()
    weight_x1 = 1 - weight_x2
    weight_y2 = y - y1.float()
    weight_y1 = 1 - weight_y2
    return (
        tensor[y1, x1] * weight_x1 * weight_y1
        + tensor[y1, x2] * weight_x2 * weight_y1
        + tensor[y2, x1] * weight_x1 * weight_y2
        + tensor[y2, x2] * weight_x2 * weight_y2
    )
# Sample 3 x 3 grid holding the integers 1..9 in row-major order.
tensor = torch.arange(1, 10).reshape(3, 3)
# Query points, one (x, y) pair per row.
coords = torch.tensor([[1.5, 2.2], [0.2, 1.7]])
# Interpolate at the continuous coordinates and show the result.
interpolated_values = interpolate_tensor(tensor, coords)
print(interpolated_values)
# Distinct values of torch_posterior together with how often each occurs.
unique_values, counts = torch.unique(torch_posterior, return_counts=True)
# Report every distinct value next to its number of occurrences.
for value, count in zip(unique_values, counts):
    print(f"Value: {value}, Count: {count}")
# Fresh figure for the plot that follows.
fig, ax = plt.subplots()
# Show torch_posterior as an image with a colour scale.
ax.imshow(torch_posterior.cpu(), cmap="viridis")
plt.colorbar()
plt.show()
# Two batches of 2-D normal draws: a standard normal one, and one with
# mean 150 and standard deviation 30.
batch = torch.zeros(batch_size, 2).normal_(mean=0, std=1)
batch2 = torch.zeros(batch_size, 2).normal_(mean=150, std=30)
def two_moons_density(z):
    """Return a radially oscillating, exponentially damped value per point.

    Args:
        z: Tensor whose columns 0 and 1 are the x and y coordinates.

    Returns:
        ``exp(-0.2 * d) * cos(4 * pi * d)`` where ``d`` is each point's
        distance from the origin.
    """
    x, y = z[:, 0], z[:, 1]
    radius = (x ** 2 + y ** 2).sqrt()
    damping = torch.exp(-0.2 * radius)
    ripple = torch.cos(4 * np.pi * radius)
    return damping * ripple
def ring_density(z):
    """Return the potential of a two-mode ring of radius 4.

    Args:
        z: Tensor of points. The Euclidean norm is taken over ``dim=1`` and
            ``z[:, 0]`` is read as the first coordinate.

    Returns:
        Tensor of potentials, one per point.
    """
    first = z[:, 0]
    radius = torch.norm(z, 2, dim=1)
    # Gaussian bumps centred at x = +2 and x = -2.
    right_mode = torch.exp(-0.5 * ((first - 2) / 0.8) ** 2)
    left_mode = torch.exp(-0.5 * ((first + 2) / 0.8) ** 2)
    # Quadratic penalty pulling points towards the circle of radius 4.
    ring_term = 0.5 * ((radius - 4) / 0.4) ** 2
    # The 1e-6 keeps the logarithm finite when both bumps underflow to zero.
    return ring_term - torch.log(right_mode + left_mode + 1e-6)
# Notebook-style probes of the value ranges of the two densities.
two_moons_density(batch).min()
ring_density(batch).min()
ring_density(batch).max()


def interpolate_tensor(tensor, z, scale=150, shift=150):
    """Bilinearly sample ``tensor`` at points given as scaled normal draws.

    Each row of ``z`` is mapped to pixel coordinates with
    ``z * scale + shift``, clamped to the grid, and ``tensor`` is
    interpolated between the four surrounding nodes.

    Args:
        tensor: Tensor whose first two dimensions are ``(height, width)``;
            any further dimensions are carried through.
        z: Tensor whose columns 0 and 1 are the x and y draws; ``x`` indexes
            columns and ``y`` indexes rows of ``tensor``.
        scale: Multiplier applied to the draws (default 150, the value that
            was previously hard-coded).
        shift: Offset added after scaling (default 150, likewise).

    Returns:
        One interpolated value per row of ``z``.
    """
    height, width = tensor.shape[:2]
    # Keep the coordinates as floats here. Converting to long at this point
    # (as was done before) discards the fractional part, so every weight
    # became 0 or 1 and nothing was interpolated.
    x = (z[:, 0] * scale + shift).clamp(0, width - 1)
    y = (z[:, 1] * scale + shift).clamp(0, height - 1)
    # Indices of the four surrounding grid nodes; the upper neighbours are
    # clamped so points on the last row/column reuse that row/column.
    x1 = x.floor().long()
    y1 = y.floor().long()
    x2 = (x1 + 1).clamp(max=width - 1)
    y2 = (y1 + 1).clamp(max=height - 1)
    # Trailing singleton axes let the per-point weights broadcast over any
    # extra dimensions of `tensor`; for a 2-D tensor this is a no-op.
    weight_shape = (-1,) + (1,) * (tensor.dim() - 2)
    weight_x2 = (x - x1.to(x.dtype)).reshape(weight_shape)
    weight_y2 = (y - y1.to(y.dtype)).reshape(weight_shape)
    weight_x1 = 1 - weight_x2
    weight_y1 = 1 - weight_y2
    return (
        tensor[y1, x1] * weight_x1 * weight_y1
        + tensor[y1, x2] * weight_x2 * weight_y1
        + tensor[y2, x1] * weight_x1 * weight_y2
        + tensor[y2, x2] * weight_x2 * weight_y2
    )
# Draw a fresh standard-normal batch and sample the posterior image at it.
batch = torch.zeros(batch_size, 2).normal_(mean=0, std=1)
interpolate_tensor(torch_posterior, batch)
#ring_density(batch)#ring_density(batch)
#interpolate_tensor(torch_posterior,batch2)
batch2
# Fresh figure for the plot that follows.
fig, ax = plt.subplots()
# Plot the ring potential of the batch with a colour scale.
# NOTE(review): ring_density reduces over dim=1, so for the (batch_size, 2)
# batch above it returns a 1-D tensor, while imshow expects image-shaped
# (2-D or RGB(A)) data. Confirm what this cell was meant to display.
ax.imshow(ring_density(batch).cpu(), cmap="viridis")
plt.colorbar()
plt.show()
# Sample the posterior image at a single hand-picked point.
interpolate_tensor(torch_posterior, torch.tensor([[81, 156.5]]))
# Direct lookup of one posterior entry, for comparison by eye.
torch_posterior[81, 157]
# Positions of all non-zero entries and how many there are.
indices = torch.nonzero(torch_posterior != 0)
len(indices)
# Print the pairs of indices
# for idx in indices:
#     i, j = idx
#     print(f"Pair of indices: ({i}, {j})")
# This assignment had been fused onto the end of the comment line above and
# was therefore never executed. It is spelled through torch.distributions
# because no separate import of MultivariateNormal is visible in this file.
test = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
test.sample()
import os
import imageio
def make_gif_from_train_plots(fname: str, png_dir: str = "") -> None:
    """Assemble the PNG frames found in ``png_dir`` into ``gifs/<fname>``.

    Args:
        fname: File name of the GIF to write inside the ``gifs`` directory.
        png_dir: Directory holding the training-plot PNGs. The default is
            the empty placeholder that was committed in place of the real
            path; ``os.listdir`` fails on it, so pass the actual directory.

    Raises:
        OSError: If ``png_dir`` cannot be listed or a frame cannot be read.
    """
    frames = []
    # Frames are taken in sorted file-name order. The first sorted entry is
    # skipped, as before.
    # NOTE(review): presumably that entry is a non-image file such as a
    # hidden one; confirm, since a real first frame would be dropped too.
    for file_name in sorted(os.listdir(png_dir))[1:]:
        if file_name.endswith(".png"):
            frames.append(imageio.imread(os.path.join(png_dir, file_name)))
    # Make sure the output directory exists before writing into it.
    os.makedirs("gifs", exist_ok=True)
    imageio.mimsave("gifs/" + fname, frames, duration=0.05)


make_gif_from_train_plots("32_layer.gif")