# Project-5/python_scripts/colormap_all.py

import ast
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Apply seaborn's theme first; the rcParams update below runs afterwards and
# therefore overrides any overlapping keys it sets.
sns.set_theme()
# LaTeX-rendered serif text, "large" titles/axis labels/tick labels, "medium" legend.
# NOTE(review): "Roman" does not look like one of the font names matplotlib's
# usetex mode recognizes (e.g. "Computer Modern Roman") — confirm intended font.
params = {
    "font.family": "Serif",
    "font.serif": "Roman",
    "text.usetex": True,
    **{
        rc_key: "large"
        for rc_key in (
            "axes.titlesize",
            "axes.labelsize",
            "xtick.labelsize",
            "ytick.labelsize",
        )
    },
    "legend.fontsize": "medium",
}
plt.rcParams.update(params)
def _parse_grid(line, size):
    """Parse one tab-separated line of "(re, im)" cells into a (size, size) complex array.

    The flat values are reshaped row-major to (size, size) and then transposed.
    """
    # ast.literal_eval only evaluates Python literals, so it is safe on file data
    # (unlike eval). Each cell is unpacked as the positional args of complex().
    values = [complex(*ast.literal_eval(cell)) for cell in line.strip().split("\t")]
    return np.asarray(values).reshape(size, size).T


def plot(path="data/color_map.txt", output="latex/images/color_map_all.pdf"):
    """Plot every complex grid stored in *path* as one row of three color maps.

    File format: the first line is the integer grid size ``size``. Each
    following non-blank line holds ``size * size`` tab-separated cells, each a
    literal such as ``(re, im)``. For every grid, the row shows ``|z|**2``, the
    real part and the imaginary part. Each panel gets its own color bar. The
    combined figure is saved to *output*.

    Args:
        path: Input data file.
        output: Destination of the saved figure.

    Raises:
        ValueError: If the file is empty, the size line is not an integer, or
            no data lines follow it.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    if not lines:
        raise ValueError(f"{path} is empty; expected the grid size on the first line")
    size = int(lines[0])
    # Skip blank lines (e.g. a trailing newline) that literal_eval cannot parse.
    grids = [_parse_grid(line, size) for line in lines[1:] if line.strip()]
    if not grids:
        raise ValueError(f"{path} has no data lines after the grid size")

    cmap = sns.color_palette("mako", as_cmap=True)
    # One row per grid; squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(nrows=len(grids), ncols=3, squeeze=False)
    try:
        for row_axes, arr in zip(axes, grids):
            panels = (np.abs(arr) ** 2, arr.real, arr.imag)
            for ax, data in zip(row_axes, panels):
                image = ax.imshow(data, interpolation="nearest", cmap=cmap)
                # imshow normalizes each panel independently, so a single shared
                # color bar would mislabel all but one of them.
                fig.colorbar(image, ax=ax)
                # Turn off seaborn's grid per Axes; the ndarray itself has no .grid().
                ax.grid(False)
        fig.savefig(output, bbox_inches="tight")
    finally:
        plt.close(fig)
# Script entry point: build and save the combined color-map figure.
if __name__ == "__main__":
    plot()