# Python Plotting Tests

import torch
import numpy as np
import matplotlib.pyplot as plt

# Half-extent of the evaluation grid; the module-level grid code below reads
# only `xlim` (both axes are built from it).
xlim = ylim = 4

def density(z):
  """Return the potential (negative log of the unnormalised density) at ``z``.

  The potential is a ring term centred on radius 2 minus the log of a
  two-component Gaussian mixture over the first coordinate (modes at +2 and
  -2, width 0.6). Callers turn it into a density with ``torch.exp(-u)``.

  Args:
    z: Tensor whose dim 1 holds the coordinates; the norm is taken over
      dim 1 and ``z[:, 0]`` is used as the first coordinate.

  Returns:
    Tensor of potentials with dim 1 of ``z`` reduced away.
  """
  ring = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
  first = z[:, 0]
  log_right = -0.5 * ((first - 2) / 0.6) ** 2
  log_left = -0.5 * ((first + 2) / 0.6) ** 2
  # logaddexp evaluates log(exp(a) + exp(b)) without underflowing to log(0),
  # which previously produced an infinite potential far from both modes.
  return ring - torch.logaddexp(log_right, log_left)

# Build a 300x300 evaluation grid over [-xlim, xlim] on both axes and flatten
# it into an (N, 2) tensor of points.
x = y = np.linspace(-xlim, xlim, 300)
X, Y = np.meshgrid(x, y)
shape = X.shape
X_flatten, Y_flatten = np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))
Z = torch.from_numpy(np.concatenate([X_flatten, Y_flatten], 1))

# Inline copy of the expression inside density(); `u` is not read again in
# this cell.
u = 0.5 * ((torch.norm(Z, 2, dim=1) - 2) / 0.4) ** 2
u = u - torch.log(
      torch.exp(-0.5 * ((Z[:, 0] - 2) / 0.6) ** 2)
      + torch.exp(-0.5 * ((Z[:, 0] + 2) / 0.6) ** 2))

# Unnormalised density on the grid, reshaped back to the meshgrid layout.
U = torch.exp(-density(Z))
U = U.reshape(shape)

# Bare expressions: no effect when run as a script (presumably notebook
# display cells).
U.size()
U

x
def U_1(z):
  """Return the ring-with-two-modes potential evaluated at ``z``.

  Combines a quadratic ring term (radius 2, width 0.4) with the negative log
  of two Gaussian bumps over the first coordinate (centres +2 and -2, width
  0.6).

  Args:
    z: Tensor whose dim 1 holds the coordinates; the norm is taken over
      dim 1 and ``z[:, 0]`` is used as the first coordinate.

  Returns:
    Tensor of potentials with dim 1 of ``z`` reduced away.
  """
  radial = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
  z0 = z[:, 0]
  # Stable log(exp(a) + exp(b)): the naive form underflows to log(0) = -inf
  # when both exponentials vanish, making the potential infinite.
  log_mix = torch.logaddexp(
      -0.5 * ((z0 - 2) / 0.6) ** 2,
      -0.5 * ((z0 + 2) / 0.6) ** 2,
  )
  return radial - log_mix

def plot_density(density, xlim=4, ylim=4, ax=None, cmap="Purples"):
  """Draw ``exp(-density(z))`` as a colour mesh on a 300x300 grid.

  Args:
    density: Callable taking an (N, 2) tensor of grid points and returning
      N potential values; it is exponentiated as ``exp(-value)``.
    xlim: Half-extent of the grid and axis limits along x.
    ylim: Half-extent of the grid and axis limits along y.
    ax: Axes to draw on. When None, a new 7x7 inch figure with a single
      subplot is created.
    cmap: Colormap passed to ``pcolormesh``.

  Returns:
    The axes that were drawn on.
  """
  x = np.linspace(-xlim, xlim, 300)
  # Previously the y grid and y limits were built from xlim, so ylim was
  # silently ignored.
  y = np.linspace(-ylim, ylim, 300)
  X, Y = np.meshgrid(x, y)
  points = np.concatenate([X.reshape(-1, 1), Y.reshape(-1, 1)], 1)
  Z = torch.from_numpy(points)
  U = torch.exp(-density(Z)).reshape(X.shape)

  if ax is None:
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)

  ax.set_xlim(-xlim, xlim)
  ax.set_ylim(-ylim, ylim)
  ax.set_aspect(1)

  # Hand matplotlib a plain array; detach() keeps this working when the
  # density callable returns a tensor that tracks gradients.
  ax.pcolormesh(X, Y, U.detach().cpu().numpy(), cmap=cmap, rasterized=True)
  # Hide every tick mark and tick label.
  ax.tick_params(
      axis="both",
      left=False,
      top=False,
      right=False,
      bottom=False,
      labelleft=False,
      labeltop=False,
      labelright=False,
      labelbottom=False,
  )
  return ax

# Render U_1 with the default limits, figure and colormap.
plot_density(U_1)
def U_1(z):
  """Return the ring-with-two-modes potential evaluated at ``z``.

  This redefinition rebinds the name ``U_1`` declared earlier in the file
  and computes the same expression.

  Args:
    z: Tensor whose dim 1 holds the coordinates; the norm is taken over
      dim 1 and ``z[:, 0]`` is used as the first coordinate. Extra trailing
      dims are carried through unchanged.

  Returns:
    Tensor of potentials with dim 1 of ``z`` reduced away.
  """
  u = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
  lead = z[:, 0]
  # logaddexp avoids log(0) when both Gaussian bumps underflow to zero.
  u = u - torch.logaddexp(
      -0.5 * ((lead - 2) / 0.6) ** 2,
      -0.5 * ((lead + 2) / 0.6) ** 2,
  )
  return u
              
# Evaluate U_1 on a random 3-D input of shape (128, 2, 2).
test = U_1(torch.randn(128,2, 2))

# Bare expression: no effect when run as a script.
test

#torch.randn(2, 2)
import torch

def interpolate_tensor(tensor, coords):
    """Bilinearly interpolate a 2-D tensor at continuous (x, y) positions.

    Coordinates outside the grid are clamped to the nearest edge, so the
    border value is returned there. Previously only the upper bound was
    clamped, and a negative coordinate produced a negative index that wrapped
    around to the opposite edge.

    Args:
        tensor: 2-D tensor indexed as ``tensor[row, column]``.
        coords: Floating-point tensor of shape (N, 2); column 0 is x (column
            index), column 1 is y (row index).

    Returns:
        Tensor of N interpolated values.
    """
    height, width = tensor.shape

    # Keep every sample point inside the grid before splitting it into an
    # integer cell and a fractional offset.
    x = coords[:, 0].clamp(0, width - 1)
    y = coords[:, 1].clamp(0, height - 1)

    # Top-left corner of the enclosing cell, and its right/bottom neighbours
    # (kept in range for points lying exactly on the last row or column).
    x1 = x.floor().long()
    y1 = y.floor().long()
    x2 = (x1 + 1).clamp(max=width - 1)
    y2 = (y1 + 1).clamp(max=height - 1)

    # Fractional offsets inside the cell act as the interpolation weights.
    weight_x2 = x - x1.to(x.dtype)
    weight_x1 = 1 - weight_x2
    weight_y2 = y - y1.to(y.dtype)
    weight_y1 = 1 - weight_y2

    return (
        tensor[y1, x1] * weight_x1 * weight_y1
        + tensor[y1, x2] * weight_x2 * weight_y1
        + tensor[y2, x1] * weight_x1 * weight_y2
        + tensor[y2, x2] * weight_x2 * weight_y2
    )

# Create a sample 3x3 integer tensor
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# Two sample points, one (x, y) pair per row
coords = torch.tensor([[1.5, 2.2],
                       [0.2, 1.7]])

# Interpolate at continuous coordinates
interpolated_values = interpolate_tensor(tensor, coords)

print(interpolated_values)
# NOTE(review): `torch_posterior` is not defined anywhere in the visible
# source; presumably created in another cell. Confirm before running.
unique_values, counts = torch.unique(torch_posterior, return_counts=True)

# Print the unique values and their counts
for value, count in zip(unique_values, counts):
    print(f"Value: {value}, Count: {count}")
fig, ax = plt.subplots()

# Plot the tensor values
ax.imshow(torch_posterior.cpu(), cmap='viridis')

# Show the colorbar
plt.colorbar()

# Show the plot
plt.show()
# NOTE(review): `batch_size` is not defined anywhere in the visible source.
# `batch` is drawn from N(0, 1) and `batch2` from N(150, 30), two columns each.
batch = torch.zeros(size=(batch_size, 2)).normal_(mean=0, std=1)
batch2 = torch.zeros(size=(batch_size, 2)).normal_(mean=150, std=30)
def two_moons_density(z):
  """Return a damped radial cosine evaluated at each row of ``z``.

  The value is ``exp(-0.2 * r) * cos(4 * pi * r)`` where ``r`` is the
  Euclidean distance of ``(z[:, 0], z[:, 1])`` from the origin. Despite the
  name, the result depends only on that radius and can be negative.

  Args:
    z: Tensor with at least two columns; columns 0 and 1 are read.

  Returns:
    Tensor with one value per row of ``z``.
  """
  radius = (z[:, 0] ** 2 + z[:, 1] ** 2).sqrt()
  envelope = (-0.2 * radius).exp()
  ripple = (4 * np.pi * radius).cos()
  return envelope * ripple

def ring_density(z):
    """Return the potential of a radius-4 ring with two modes on the first axis.

    A quadratic ring term (radius 4, width 0.4) is reduced by the log of two
    Gaussian bumps over ``z[:, 0]`` (centres +2 and -2, width 0.8). A 1e-6
    offset inside the log keeps it finite when both bumps underflow.

    Args:
        z: Tensor whose dim 1 holds the coordinates.

    Returns:
        Tensor of potentials with dim 1 of ``z`` reduced away.
    """
    first = z[:, 0]
    radial = 0.5 * ((torch.norm(z, 2, dim=1) - 4) / 0.4) ** 2
    right_bump = torch.exp(-0.5 * ((first - 2) / 0.8) ** 2)
    left_bump = torch.exp(-0.5 * ((first + 2) / 0.8) ** 2)
    return radial - torch.log(right_bump + left_bump + 1e-6)

# Inspect the value range of both functions on `batch`; the results are
# discarded when run as a script.
two_moons_density(batch).min()
ring_density(batch).min()
ring_density(batch).max()
def interpolate_tensor(tensor, z, scale=150.0, offset=150.0):
    """Bilinearly sample ``tensor`` at positions given in latent coordinates.

    Each row of ``z`` is mapped to pixel coordinates with
    ``z * scale + offset`` and clamped to the grid. The previous version cast
    those coordinates to integers before taking the fractional part, so every
    weight was 0 or 1 and the result was a plain pixel lookup rather than an
    interpolation.

    Args:
        tensor: Tensor indexed as ``tensor[row, column, ...]``; only its first
            two dims are treated as the spatial grid.
        z: Tensor of shape (N, 2); column 0 maps to x (column index) and
            column 1 to y (row index).
        scale: Multiplier applied to ``z``. The default matches the value
            that was previously hard-coded.
        offset: Shift added after scaling. The default matches the value that
            was previously hard-coded.

    Returns:
        Tensor of N interpolated values, with any trailing dims of ``tensor``
        preserved.
    """
    height, width = tensor.shape[:2]

    # Map to pixel coordinates but keep them floating point so the fractional
    # part survives for the weights below.
    x = (z[:, 0] * scale + offset).clamp(0, width - 1)
    y = (z[:, 1] * scale + offset).clamp(0, height - 1)

    # Integer corners of the enclosing cell, kept inside the grid.
    x1 = x.floor().long()
    y1 = y.floor().long()
    x2 = (x1 + 1).clamp(max=width - 1)
    y2 = (y1 + 1).clamp(max=height - 1)

    # Give the weights one singleton dim per trailing tensor dim so they
    # broadcast against gathered values of shape (N, ...).
    weight_shape = (x.shape[0],) + (1,) * (tensor.dim() - 2)
    weight_x2 = (x - x1.to(x.dtype)).reshape(weight_shape)
    weight_y2 = (y - y1.to(y.dtype)).reshape(weight_shape)
    weight_x1 = 1 - weight_x2
    weight_y1 = 1 - weight_y2

    return (
        tensor[y1, x1] * weight_x1 * weight_y1
        + tensor[y1, x2] * weight_x2 * weight_y1
        + tensor[y2, x1] * weight_x1 * weight_y2
        + tensor[y2, x2] * weight_x2 * weight_y2
    )
  
# NOTE(review): `batch_size` and `torch_posterior` are not defined anywhere in
# the visible source; presumably created in other cells. Confirm before running.
batch = torch.zeros(size=(batch_size, 2)).normal_(mean=0, std=1)
interpolate_tensor(torch_posterior, batch)

#ring_density(batch)
#ring_density(batch)
#interpolate_tensor(torch_posterior,batch2)
batch2
fig, ax = plt.subplots()

# Plot the tensor values
# NOTE(review): ring_density reduces dim 1, so for the (batch_size, 2) `batch`
# this is a 1-D tensor; imshow presumably needs 2-D image data. Verify.
ax.imshow(ring_density(batch).cpu(), cmap='viridis')

# Show the colorbar
plt.colorbar()

# Show the plot
plt.show()
# Spot-check a single interpolated point against a direct index lookup.
interpolate_tensor(torch_posterior, torch.tensor([[81,156.5]]))

torch_posterior[81,157]

# Positions of all non-zero entries, and how many there are.
indices = torch.nonzero(torch_posterior != 0)

len(indices)

# Print the pairs of indices
#for idx in indices:
#    i, j = idx
#    print(f"Pair of indices: ({i}, {j})")
# NOTE(review): `MultivariateNormal` is not imported in the visible source;
# presumably torch.distributions.MultivariateNormal. Confirm.
test = MultivariateNormal(torch.zeros(2), torch.eye(2))

test.sample()
import os
import imageio


def make_gif_from_train_plots(fname: str, png_dir: str = "") -> None:
    """Assemble the PNG frames found in ``png_dir`` into ``gifs/<fname>``.

    Directory entries are sorted by name, the first sorted entry is skipped,
    and every remaining entry ending in ``.png`` becomes one frame.

    Args:
        fname: File name of the GIF, appended to the ``gifs/`` prefix.
        png_dir: Directory holding the training plots. The default is left
            empty so the real path is not committed; pass the actual
            directory when calling.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``png_dir`` cannot be
            listed (the empty default is such a case).
    """
    # NOTE(review): the first sorted entry is dropped exactly as before;
    # confirm whether this was meant to skip a hidden file or the first frame.
    candidates = sorted(os.listdir(png_dir))[1:]
    images = [
        imageio.imread(os.path.join(png_dir, file_name))
        for file_name in candidates
        if file_name.endswith(".png")
    ]

    out_path = "gifs/" + fname
    # Create the output directory up front so saving does not fail on a
    # fresh checkout.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    imageio.mimsave(out_path, images, duration=0.05)
    
# Build the GIF for the 32-layer run; executes on import/run of this file.
make_gif_from_train_plots("32_layer.gif")