"""Plotting and notebook-display helpers built on matplotlib, torch and IPython."""

from matplotlib import pyplot as plt
import matplotlib.cm as cm
import numpy as np
import torch
from IPython.display import HTML, display


def set_default(figsize=(10, 10), dpi=100):
    """Apply the dark plotting style and default figure size/resolution.

    Args:
        figsize: Default figure size passed to ``plt.rc('figure', ...)``.
        dpi: Default figure resolution passed to ``plt.rc('figure', ...)``.
    """
    plt.style.use(['dark_background', 'bmh'])
    plt.rc('axes', facecolor='k')
    plt.rc('figure', facecolor='k')
    plt.rc('figure', figsize=figsize, dpi=dpi)


def plot_data(X, y, d=0, auto=False, zoom=1):
    """Scatter the first two columns of ``X`` coloured by ``y`` on hidden axes.

    Args:
        X: Tensor of points; only columns 0 and 1 are plotted.
        y: Tensor of per-point values mapped through the Spectral colormap.
        d: Unused; kept so existing callers keep working.
        auto: When exactly ``True``, switch to ``'equal'`` axis scaling after
            the fixed limits have been applied.
        zoom: Multiplier applied to the default (-1.1, 1.1) axis limits.
    """
    points = X.cpu().numpy()
    y = y.cpu()
    plt.scatter(points[:, 0], points[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
    plt.axis('square')
    plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
    if auto is True:
        plt.axis('equal')
    plt.axis('off')

    # Thin dark-grey lines through the origin, drawn behind the data.
    _m, _c = 0, '.15'
    plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
    plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)


def plot_model(X, y, model):
    """Draw the model's arg-max decision regions and overlay the data.

    The model is moved to the CPU (``model.cpu()``) and evaluated on a
    regular grid covering [-1.1, 1.1) in both directions with step 0.01.

    Args:
        X: Tensor of points forwarded to :func:`plot_data`.
        y: Tensor of per-point values forwarded to :func:`plot_data`.
        model: Callable taking a float tensor of grid points (one point per
            row) and returning one row of scores per point.
    """
    model.cpu()
    mesh = np.arange(-1.1, 1.1, 0.01)
    xx, yy = np.meshgrid(mesh, mesh)
    grid = np.vstack((xx.reshape(-1), yy.reshape(-1))).T
    with torch.no_grad():
        data = torch.from_numpy(grid).float()
        # Convert to NumPy explicitly instead of handing a tensor to np.argmax.
        scores = model(data).detach().numpy()
    Z = np.argmax(scores, axis=1).reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y)


def show_scatterplot(X, colors, title='', axis=True):
    """Scatter the first two columns of ``X`` with a rainbow colour gradient.

    Args:
        X: Tensor of points; only columns 0 and 1 are plotted.
        colors: Accepted for interface compatibility; the value is replaced by
            a rainbow gradient with one colour per point.
        title: Plot title.
        axis: When true, draw thin lines through the origin.
    """
    points = X.detach().cpu().numpy()
    # One colour per plotted point. The previous code sized the gradient with
    # an undefined name (``ys``), which raised NameError.
    colors = cm.rainbow(np.linspace(0, 1, len(points)))
    plt.axis('equal')
    plt.scatter(points[:, 0], points[:, 1], c=colors, s=30)
    plt.title(title)
    plt.axis('off')
    _m, _c = 0, '.15'
    if axis:
        plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
        plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)


def plot_bases(bases, plotting=True, width=0.04):
    """Convert arrow end points to offsets in place and optionally draw them.

    NOTE: ``bases`` is modified in place (``bases[2:] -= bases[:2]``), also
    when ``plotting`` is false.

    Args:
        bases: Indexable with at least four rows; rows 0 and 1 are the arrow
            origins, rows 2 and 3 are turned into the arrow offsets.
        plotting: When true, draw a red and a green arrow and store the
            artists on ``plot_bases.a`` and ``plot_bases.b``.
        width: Arrow shaft width passed to ``plt.arrow``.
    """
    bases[2:] -= bases[:2]
    if plotting:
        plot_bases.a = plt.arrow(
            *bases[0], *bases[2], width=width, color='r', zorder=10,
            alpha=1., length_includes_head=True,
        )
        plot_bases.b = plt.arrow(
            *bases[1], *bases[3], width=width, color='g', zorder=10,
            alpha=1., length_includes_head=True,
        )


# Handles to the most recently drawn arrows; None until plot_bases draws.
plot_bases.a = None
plot_bases.b = None


def show_mat(mat, vect, prod, threshold=-1):
    """Show a matrix, a vector and their product side by side with colourbars.

    Args:
        mat: 2-D tensor shown in the wide left panel.
        vect: Tensor shown in the middle panel.
        prod: Tensor shown in the right panel.
        threshold: Lower colour limit used for the ``prod`` panel.
    """
    # Subplot grid definition
    fig, (ax1, ax2, ax3) = plt.subplots(
        1, 3, sharex=False, sharey=True,
        gridspec_kw={'width_ratios': [5, 1, 1]},
    )
    # Plot matrices
    cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
    ax2.matshow(vect.numpy(), clim=(-1, 1))
    cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))

    # Set titles
    ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
    ax2.set_title(f'a^(i): {vect.numel()}')
    ax3.set_title(f'p: {prod.numel()}')

    # Remove xticks for vectors
    ax2.set_xticks(tuple())
    ax3.set_xticks(tuple())

    # Plot colourbars
    fig.colorbar(cax1, ax=ax2)
    fig.colorbar(cax3, ax=ax3)

    # Fix y-axis limits
    ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)


colors = dict(
    aqua='#8dd3c7',
    yellow='#ffffb3',
    lavender='#bebada',
    red='#fb8072',
    blue='#80b1d3',
    orange='#fdb462',
    green='#b3de69',
    pink='#fccde5',
    grey='#d9d9d9',
    violet='#bc80bd',
    unk1='#ccebc5',
    unk2='#ffed6f',
)


def _cstr(s, color='black'):
    """Return an HTML ``<text>`` snippet showing ``s`` on a coloured background."""
    if s == ' ':
        return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
    return f'<text style=color:#000;background-color:{color}>{s} </text>'


# print html
def _print_color(t):
    """Display an iterable of ``(text, colour)`` pairs as one HTML line."""
    display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))


# get appropriate color for value
def _get_clr(value):
    """Map ``value`` (expected in [0, 1]) to a blue-to-red hex colour.

    The value is bucketed in steps of 0.05. Inputs outside [0, 1] are clamped
    to the first/last colour; previously negative inputs wrapped around to the
    red end of the palette and inputs of 1.05 or more raised IndexError.
    """
    palette = (
        '#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
        '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
        '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
        '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e',
    )
    index = int((value * 100) / 5)
    # value == 1.0 yields len(palette); clamp it (and any stray input).
    index = min(max(index, 0), len(palette) - 1)
    return palette[index]


def _visualise_values(output_values, result_list):
    """Display each entry of ``result_list`` coloured by its matching value.

    Args:
        output_values: Sequence of values passed to :func:`_get_clr`.
        result_list: Indexable of display items, at least as long as
            ``output_values``.
    """
    text_colours = [
        (result_list[i], _get_clr(value))
        for i, value in enumerate(output_values)
    ]
    _print_color(text_colours)


def print_colourbar():
    """Display a 20-step legend labelled from -2.50 to 2.50."""
    color_range = torch.linspace(-2.5, 2.5, 20)
    # Rescale [-2.5, 2.5] to the [0, 1] range expected by _get_clr.
    to_print = [(f'{x:.2f}', _get_clr((x + 2.5) / 5)) for x in color_range]
    _print_color(to_print)


# Let's only focus on the last time step for now
# First, the cell state (Long term memory)
def plot_state(data, state, b, decoder, idx_to_symbol):
    """Display the decoded sequence of sample ``b`` once per state unit.

    For every index ``s`` along ``state``'s third axis, the sequence is shown
    coloured by ``sigmoid(state[:, b, s])``.

    Args:
        data: Tensor indexed as ``data[b, :, :]`` to select one sample.
        state: Tensor indexed as ``state[:, b, s]`` -- presumably
            (time, batch, hidden); TODO confirm against the caller.
        b: Index of the sample to display.
        decoder: Callable ``decoder(array, idx_to_symbol)`` returning the
            decoded sequence.
        idx_to_symbol: Passed through to ``decoder`` unchanged.
    """
    actual_data = decoder(data[b, :, :].numpy(), idx_to_symbol)
    symbols = list(actual_data)
    seq_len = len(actual_data)
    seq_len_w_pad = len(state)
    # Skip the leading steps so the values line up with the decoded symbols
    # (assumes any padding precedes the sequence -- TODO confirm).
    start = seq_len_w_pad - seq_len
    for s in range(state.size(2)):
        states = torch.sigmoid(state[:, b, s])
        _visualise_values(states[start:], symbols)