Python 如何并排绘制 2 个 seaborn lmplots?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/33049884/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
How to plot 2 seaborn lmplots side-by-side?
提问by samthebrand
Plotting 2 distplots or scatterplots in a subplot works great:
在子图中绘制 2 个 distplots 或散点图效果很好:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
%matplotlib inline
# create df
x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)})
# Two subplots
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.plot(df.x, df.y)
ax1.set_title('Sharing Y axis')
ax2.scatter(df.x, df.y)
plt.show()
But when I do the same with an lmplotinstead of either of the other types of charts I get an error:
但是,当我使用 anlmplot而不是其他类型的图表执行相同操作时,出现错误:
AttributeError: 'AxesSubplot' object has no attribute 'lmplot'
AttributeError: 'AxesSubplot' 对象没有属性 'lmplot'
Is there any way to plot these chart types side by side?
有没有办法并排绘制这些图表类型?
采纳答案by Paul H
You get that error because matplotlib and its objects are completely unaware of seaborn functions.
您会收到该错误,因为 matplotlib 及其对象完全不知道 seaborn 函数。
Pass your axes objects (i.e., ax1and ax2) to seaborn.regplotor you can skip defining those and use the colkwarg of seaborn.lmplot
将您的轴对象(即,ax1和ax2)传递给,seaborn.regplot或者您可以跳过定义这些对象并使用colkwargseaborn.lmplot
With your same imports, pre-defining your axes and using regplotlooks like this:
使用相同的导入,预定义轴并使用regplot如下所示:
# create df
x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)})
df.index.names = ['obs']
df.columns.names = ['vars']
idx = np.array(df.index.tolist(), dtype='float') # make an array of x-values
# call regplot on each axes
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True)
sns.regplot(x=idx, y=df['x'], ax=ax1)
sns.regplot(x=idx, y=df['y'], ax=ax2)
Using lmplot requires your dataframe to be tidy. Continuing from the code above:
tidy = (
df.stack() # pull the columns into row variables
.to_frame() # convert the resulting Series to a DataFrame
.reset_index() # pull the resulting MultiIndex into the columns
.rename(columns={0: 'val'}) # rename the unnamed column
)
sns.lmplot(x='obs', y='val', col='vars', hue='vars', data=tidy)


