"""Render colour maps of a complex 2-D grid read from a text data file.

The first line of the data file holds the grid size ``size``. Each further
line holds ``size * size`` tab-separated ``"(re, im)"`` tuples describing one
complex grid. Every grid is drawn three times -- squared modulus, real part
and imaginary part -- and each figure is saved as a PDF.
"""

import ast
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

sns.set_theme()
params = {
    "font.family": "Serif",
    "font.serif": "Roman",
    "text.usetex": True,
    "axes.titlesize": "large",
    "axes.labelsize": "large",
    "xtick.labelsize": "large",
    "ytick.labelsize": "large",
    "legend.fontsize": "medium",
}
plt.rcParams.update(params)

# Tick positions used on both axes; images are stretched over the unit square.
_TICKS = [0, 0.25, 0.5, 0.75, 1.0]
_EXTENT = [0, 1.0, 0, 1.0]


def _parse_line(line, size):
    """Parse one data row into a ``(size, size)`` complex array.

    Each tab-separated field is a literal ``(re, im)`` tuple. The flat row is
    reshaped in NumPy's default (C) order and then transposed.
    """
    values = [complex(*ast.literal_eval(field)) for field in line.strip().split("\t")]
    return np.asarray(values).reshape(size, size).T


def _save_color_map(data, path, cmap):
    """Draw the real 2-D array *data* as a colour map with a colour bar and save it to *path*."""
    fig, ax = plt.subplots()
    try:
        # NOTE(review): imshow's default origin="upper" puts row 0 at the top
        # (y = 1) of the extent; confirm this matches the data-file layout.
        image = ax.imshow(data, interpolation="nearest", cmap=cmap, extent=_EXTENT)
        fig.colorbar(image, ax=ax)
        ax.grid(False)  # seaborn's theme grid would overlay the image
        ax.set_xticks(_TICKS)
        ax.set_yticks(_TICKS)
        ax.set_xlabel("x-axis")
        ax.set_ylabel("y-axis")
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def plot(data_path="data/color_map.txt", out_dir="latex/images"):
    """Save probability, real-part and imaginary-part colour maps for every grid.

    Args:
        data_path: Input file. Its first line is the grid size; each following
            non-blank line is one grid (see ``_parse_line``).
        out_dir: Directory for the PDFs ``color_map_{i}_prob.pdf``,
            ``color_map_{i}_real.pdf`` and ``color_map_{i}_imag.pdf``, where
            ``i`` is the 0-based index of the data line after the header.

    Raises:
        ValueError: If the data file is empty.
    """
    with open(data_path, encoding="utf-8") as f:
        lines = f.readlines()
    if not lines:
        raise ValueError(f"{data_path} is empty; expected the grid size on the first line")

    size = int(lines[0])
    cmap = sns.color_palette("mako", as_cmap=True)  # shared by every figure
    out = Path(out_dir)

    for i, line in enumerate(lines[1:]):
        if not line.strip():
            # Skip blank lines (e.g. a trailing newline) instead of crashing in
            # literal_eval; indices of the remaining rows are unchanged.
            continue
        arr = _parse_line(line, size)
        # arr * conj(arr) is |arr|^2; .real drops the zero imaginary part.
        _save_color_map(np.multiply(arr, arr.conj()).real, out / f"color_map_{i}_prob.pdf", cmap)
        _save_color_map(arr.real, out / f"color_map_{i}_real.pdf", cmap)
        _save_color_map(arr.imag, out / f"color_map_{i}_imag.pdf", cmap)


if __name__ == "__main__":
    plot()