Skip to content

Generate Data Distribution Heatmaps

This notebook provides code for generating data distribution heatmaps (one per run) from the CSV files in the data_distribution folder; the code below reads them from the `outputs` directory.

Imports

from os import listdir
from os.path import isfile, join
from statistics import fmean

import pandas as pd
from matplotlib import cm
from matplotlib.colors import Normalize

pd.options.mode.chained_assignment = None  # default='warn'
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

File loading

# Load every CSV file found directly inside "outputs" into one long DataFrame.
# Each file contributes its rows plus a 'run' column holding the file name
# without the ".csv" extension.
#
# The listing is sorted so the concatenation order does not depend on the
# order in which the OS returns directory entries, and files without a ".csv"
# extension are skipped instead of being handed to read_csv.
file_list = sorted(
    f for f in listdir("outputs")
    if isfile(join("outputs", f)) and f.endswith(".csv")
)
dfs = []
for csv_name in file_list:
    df_temp = pd.read_csv(join("outputs", csv_name))
    # 'client' is moved into the index here and restored as a regular column
    # by reset_index() after the concatenation.
    df_temp.set_index(['client'], inplace=True)
    # Strip only the trailing extension; str.replace(".csv", "") would also
    # remove a ".csv" that appears in the middle of the file name.
    df_temp['run'] = csv_name[:-len(".csv")]
    dfs.append(df_temp)
df = pd.concat(dfs).reset_index()

Dataframe preparation

# Reshape to long format: 'client' and 'run' stay as identifier columns, every
# other column becomes one row per (client, run) with its former column name
# in 'class' and its content in 'value'. Missing entries are replaced by 0.
df = pd.melt(df, id_vars=["client", "run"], var_name="class")
df = df.fillna(0)
# Store the run labels as a categorical column.
df = df.astype({"run": "category"})

Wrapper for drawing the heatmaps

def draw_heatmap(*args, **kwargs):
    """Pivot a long-format frame and draw it with ``sns.heatmap``.

    Positional arguments are column names, in this order: the column used
    for the heatmap rows, the column used for the heatmap columns, and the
    column holding the cell values.

    The keyword argument ``data`` (required) is the long-format DataFrame;
    it is removed from ``kwargs`` and every remaining keyword argument is
    forwarded unchanged to ``sns.heatmap``.
    """
    frame = kwargs.pop('data')
    row_key, col_key, value_key = args[0], args[1], args[2]
    wide = frame.pivot(index=row_key, columns=col_key, values=value_key)
    sns.heatmap(wide, **kwargs)

Plotting

# Use the Charter typeface for all text in the figure.
mpl.rcParams['font.family'] = "serif"
mpl.rcParams['font.serif'] = "Charter"
sns.set_style("ticks")
# NOTE(review): set_theme() applies its own default style ("darkgrid"), so the
# "ticks" style requested just above is presumably overridden — confirm which
# style is intended.
sns.set_theme(font="Charter")

# One facet per run, two facets per row, with shared axes.
# NOTE(review): sns.heatmap sets its own axis limits, so xlim here likely has
# no visible effect — confirm before relying on it.
g = sns.FacetGrid(df, col="run", col_wrap=2, sharey=True, sharex=True, xlim=(0.5,1.0), height=7, aspect=1)
# Dedicated axes for the single colorbar shared by all facets, given in figure
# coordinates as [left, bottom, width, height]. `figure` is used instead of
# `fig`, which seaborn documents as deprecated.
cbar_ax = g.figure.add_axes([1, .15, .03, .7])

# Draw one client-by-class heatmap per run. The fixed vmin/vmax keep the color
# scale identical across facets, so the shared colorbar is valid for all.
g.map_dataframe(draw_heatmap, 'client', 'class', 'value', cbar=True, square = False, vmin=0, vmax=400, cmap='viridis', cbar_ax=cbar_ax)
cbar_ax.set_ylabel("Number of samples", fontproperties={'family': 'Charter'},fontsize='x-large')
g.set_xlabels("Class", fontproperties={'family': 'Charter'},fontsize='x-large')
g.set_ylabels("Clients", fontproperties={'family': 'Charter'},fontsize='x-large')
# Leave room above the facets for the figure title.
g.figure.subplots_adjust(top=0.95)
axs = g.axes_dict
g.figure.suptitle("Supported Data Label Distributions", fontsize='xx-large', fontproperties={'family': 'Charter'})
# FacetGrid's documented API has no show() method; display via pyplot.
plt.show()
/home/jsteimle/anaconda3/envs/Flower/lib/python3.11/site-packages/seaborn/axisgrid.py:118: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  self._figure.tight_layout(*args, **kwargs)

png