Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Getting a legend in a seaborn FacetGrid heatmap plot

How can we get legends for seaborn FacetGrid heatmaps? The .add_legend() method isn't working for me.

Using code from this previous question:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns

print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
    product = list(itertools.product(*itrs))
    return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

methods=['method 1', 'method2', 'method 3', 'method 4']
times = range(0,100,10)
data = pd.DataFrame(expandgrid(methods, times, times))
data.columns = ['method', 'dtsi','rtsi']
data['nw_score'] = np.random.sample(data.shape[0])

def facet(data,color):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    g = sns.heatmap(data, cmap='Blues', cbar=False)

with sns.plotting_context(font_scale=5.5):
    g = sns.FacetGrid(data, col="method", col_wrap=2, height=3, aspect=1)  # 'size' in seaborn < 0.9
    g = g.map_dataframe(facet)
    g.add_legend()
    g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)
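
A quick way to see why add_legend() comes up empty, as a minimal sketch assuming a recent seaborn and matplotlib: in the seaborn versions I'm aware of, FacetGrid collects legend entries from each facet through matplotlib's Axes.get_legend_handles_labels(), which only returns artists that carry a label. Drawing a single heatmap on its own axes shows there is nothing for it to collect. The random 3x3 grid is a hypothetical stand-in for one facet, not the question's data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for one facet: a small random 3x3 grid
rng = np.random.default_rng(0)
grid = pd.DataFrame(rng.random((3, 3)))

fig, ax = plt.subplots()
sns.heatmap(grid, cmap="Blues", cbar=False, ax=ax)

# Legend entries come only from artists with a (non-underscore) label;
# the heatmap's QuadMesh has none, so there is nothing to put in a legend
handles, labels = ax.get_legend_handles_labels()
print(len(handles), labels)  # prints: 0 []
```

The color scale lives in the mesh itself (ax.collections[0]), and that mappable is what a colorbar is built from.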

He who fights dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back…

1 Reply

What you want (in matplotlib lingo) is a colorbar, not a legend. In matplotlib, a colorbar maps a continuous range of values to colors, while a legend labels discrete categories. Colorbar support isn't built into FacetGrid, but it is not hard to extend your example code to add one:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns

methods=['method 1', 'method2', 'method 3', 'method 4']
times = range(0, 100, 10)
data = pd.DataFrame(list(itertools.product(methods, times, times)))
data.columns = ['method', 'dtsi','rtsi']
data['nw_score'] = np.random.sample(data.shape[0])

def facet_heatmap(data, color, **kws):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    sns.heatmap(data, cmap='Blues', **kws)  # <-- Pass kwargs to heatmap

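# Note: font_scale only takes effect when a named context (e.g. "notebook") is also passed
with sns.plotting_context(font_scale=5.5):
    g = sns.FacetGrid(data, col="method", col_wrap=2, height=3, aspect=1)  # 'size' in seaborn < 0.9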

cbar_ax = g.fig.add_axes([.92, .3, .02, .4])  # <-- Create a colorbar axes ([left, bottom, width, height] in figure fractions)

g = g.map_dataframe(facet_heatmap,
                    cbar_ax=cbar_ax,
                    vmin=0, vmax=1)  # <-- Share one colorbar axes and fix the color limits so every facet uses the same scale

g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)
g.fig.subplots_adjust(right=.9)  # <-- Add space so the colorbar doesn't overlap the plot

I've indicated the changes I made and the rationale for them as inline comments.
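
The hard-coded vmin=0, vmax=1 fits this example only because np.random.sample draws from [0, 1). Fixed limits matter because every facet draws into the same cbar_ax: if each heatmap normalized its own data, the one colorbar could not describe all panels at once. A minimal sketch, assuming seaborn >= 0.9 (for height=), that takes the limits from the full data set instead and then checks that every panel ended up with the same normalization:

```python
import itertools

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import numpy as np
import pandas as pd
import seaborn as sns

methods = ['method 1', 'method2', 'method 3', 'method 4']
times = range(0, 100, 10)
data = pd.DataFrame(list(itertools.product(methods, times, times)),
                    columns=['method', 'dtsi', 'rtsi'])
data['nw_score'] = np.random.random(len(data))

# Shared color limits taken from the whole data set, not from any single facet
vmin, vmax = data['nw_score'].min(), data['nw_score'].max()

def facet_heatmap(data, color, **kws):
    # map_dataframe passes `color`; a heatmap has no use for it
    grid = data.pivot(index='dtsi', columns='rtsi', values='nw_score')
    sns.heatmap(grid, cmap='Blues', **kws)

g = sns.FacetGrid(data, col='method', col_wrap=2, height=3, aspect=1)
cbar_ax = g.fig.add_axes([.92, .3, .02, .4])  # [left, bottom, width, height] in figure fractions
g.map_dataframe(facet_heatmap, cbar_ax=cbar_ax, vmin=vmin, vmax=vmax)
g.set_titles(col_template='{col_name}')
g.fig.subplots_adjust(right=.9)

# Each panel's mesh should now report the same normalization limits
limits = {(ax.collections[0].norm.vmin, ax.collections[0].norm.vmax)
          for ax in g.axes.flat}
print(limits)
```

Because the limits come from data['nw_score'], the same code keeps working if the scores are not confined to [0, 1].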



...