Python 如何并排绘制 2 个 seaborn lmplots?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/33049884/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
How to plot 2 seaborn lmplots side-by-side?
提问by samthebrand
Plotting 2 distplots or scatterplots in a subplot works great:
在子图中绘制 2 个 distplots 或散点图效果很好:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
%matplotlib inline
# create df
x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)})
# Two subplots
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.plot(df.x, df.y)
ax1.set_title('Sharing Y axis')
ax2.scatter(df.x, df.y)
plt.show()
But when I do the same with an lmplot
instead of either of the other types of charts I get an error:
但是,当我使用 anlmplot
而不是其他类型的图表执行相同操作时,出现错误:
AttributeError: 'AxesSubplot' object has no attribute 'lmplot'
AttributeError: 'AxesSubplot' 对象没有属性 'lmplot'
Is there any way to plot these chart types side by side?
有没有办法并排绘制这些图表类型?
采纳答案by Paul H
You get that error because matplotlib and its objects are completely unaware of seaborn functions.
您会收到该错误,因为 matplotlib 及其对象完全不知道 seaborn 函数。
Pass your axes objects (i.e., ax1
and ax2
) to seaborn.regplot
or you can skip defining those and use the col
kwarg of seaborn.lmplot
将您的轴对象(即,ax1
和ax2
)传递给,seaborn.regplot
或者您可以跳过定义这些对象并使用col
kwargseaborn.lmplot
With your same imports, pre-defining your axes and using regplot
looks like this:
使用相同的导入,预定义轴并使用regplot
如下所示:
# create df
x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)})
df.index.names = ['obs']
df.columns.names = ['vars']
idx = np.array(df.index.tolist(), dtype='float') # make an array of x-values
# call regplot on each axes
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True)
sns.regplot(x=idx, y=df['x'], ax=ax1)
sns.regplot(x=idx, y=df['y'], ax=ax2)
Using lmplot requires your dataframe to be tidy. Continuing from the code above:
tidy = (
df.stack() # pull the columns into row variables
.to_frame() # convert the resulting Series to a DataFrame
.reset_index() # pull the resulting MultiIndex into the columns
.rename(columns={0: 'val'}) # rename the unnamed column
)
sns.lmplot(x='obs', y='val', col='vars', hue='vars', data=tidy)