pandas: plot a multicolored line based on a condition in Python

Note: this page is a translation of a popular Stack Overflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must follow the same license, cite the original URL, and attribute it to the original authors (not me). Original question: http://stackoverflow.com/questions/31590184/

Time: 2020-09-13 23:40:14  Source: igfitidea

Plot Multicolored line based on conditional in python

Tags: python, pandas, matplotlib, plot

Asked by Mishiko

I have a pandas DataFrame with three columns and a datetime index:

date        px_last  200dma     50dma           
2014-12-24  2081.88 1953.16760  2019.2726
2014-12-26  2088.77 1954.37975  2023.7982
2014-12-29  2090.57 1955.62695  2028.3544
2014-12-30  2080.35 1956.73455  2032.2262
2014-12-31  2058.90 1957.66780  2035.3240

I would like to make a time series plot of the 'px_last' column that is colored green on days when the 50dma is above the 200dma and red on days when the 50dma is below the 200dma. I have seen this example, but can't seem to make it work for my case: http://matplotlib.org/examples/pylab_examples/multicolored_line.html
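For reference, the linked example draws its line with a matplotlib.collections.LineCollection. A minimal sketch of how that approach might be adapted to data shaped like the table above is given here. The simulated prices, the choice to color each segment by the day on which it ends, and all variable names are assumptions for illustration, not part of the original question.

```python
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection

# Simulated stand-in for the real data (assumption: same column names as above)
rng = np.random.default_rng(1234)
df = pd.DataFrame(
    {"px_last": 100 + rng.standard_normal(1000).cumsum()},
    index=pd.date_range("2010-01-01", periods=1000, freq="B"),
)
df["50dma"] = df["px_last"].rolling(window=50).mean()
df["200dma"] = df["px_last"].rolling(window=200).mean()
df = df.dropna()

# One short segment per pair of consecutive days: shape (n - 1, 2, 2)
x = mdates.date2num(df.index.to_pydatetime())
y = df["px_last"].to_numpy()
points = np.column_stack([x, y]).reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

# Color each segment by the moving averages on the day where it ends
above = (df["50dma"] > df["200dma"]).to_numpy()[1:]
colors = ["green" if flag else "red" for flag in above]

fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=colors, linewidths=2))
ax.xaxis_date()      # treat the numeric x values as dates
ax.autoscale_view()  # rescale the axes to fit the collection
```

Calling plt.show() afterwards displays the figure. Because every pair of neighboring points is its own segment, the colored pieces join without gaps.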


Accepted answer by Jianxun Li

Here is a way to do it without matplotlib.collections.LineCollection. The idea is to first identify the crossover points and then call a plot function on each segment between them via groupby.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# simulate data
# =============================
np.random.seed(1234)
df = pd.DataFrame({'px_last': 100 + np.random.randn(1000).cumsum()}, index=pd.date_range('2010-01-01', periods=1000, freq='B'))
# 50-day and 200-day simple moving averages
df['50dma'] = df['px_last'].rolling(window=50).mean()
df['200dma'] = df['px_last'].rolling(window=200).mean()
df['label'] = np.where(df['50dma'] > df['200dma'], 1, -1)


# plot
# =============================
df = df.dropna(axis=0, how='any')

fig, ax = plt.subplots()

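def plot_func(group):
    # each group is one run of consecutive rows that share the same label
    color = 'r' if (group['label'] < 0).all() else 'g'
    ax.plot(group.index, group['px_last'], c=color, linewidth=2.0)

# a new group starts wherever the label changes sign (a crossover)
crossover_id = (df['label'].shift() * df['label'] < 0).cumsum()
for _, group in df.groupby(crossover_id):
    plot_func(group)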

# add ma lines
ax.plot(df.index, df['50dma'], 'k--', label='MA-50')
ax.plot(df.index, df['200dma'], 'b--', label='MA-200')
ax.legend(loc='best')
plt.show()
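The grouping key is the least obvious part of the code above, so here it is in isolation. This is a small sketch on a hypothetical toy label series (the values are made up for illustration): multiplying each label by its predecessor is negative only where the sign flips, and the running count of flips gives each unbroken run its own integer id.

```python
import pandas as pd

# Hypothetical toy labels: +1 where 50dma > 200dma, -1 otherwise
label = pd.Series([1, 1, -1, -1, -1, 1, 1])

# Negative product means the label differs in sign from the previous row
flipped = label.shift() * label < 0

# Cumulative count of flips: constant within a run, +1 at every crossover
run_id = flipped.cumsum()
print(run_id.tolist())  # [0, 0, 1, 1, 1, 2, 2]

# groupby on that id yields one group per run, in order
for key, run in label.groupby(run_id):
    print(key, run.tolist())
```

Each group contains only its own rows, so lines drawn per group do not connect across a crossover; a one-day gap between neighboring colored pieces is expected with this approach.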


Answer by hhquark

Building on @Jianxun Li's answer, here's a version that's more easily extendible to 3+ colors:


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Simulate data
np.random.seed(1234)
df = pd.DataFrame(
    {'px_last': 100 + np.random.randn(1000).cumsum()},
    index=pd.date_range('2010-01-01', periods=1000, freq='B'),
)
df['50dma'] = df['px_last'].rolling(window=50, center=False).mean()
df['200dma'] = df['px_last'].rolling(window=200, center=False).mean()

## Apply labels
df['label'] = 'out of bounds'
df.loc[abs(df['50dma'] - df['200dma']) >= 7, 'label'] = '|50dma - 200dma| >= 7'
df.loc[abs(df['50dma'] - df['200dma']) < 7, 'label'] = '|50dma - 200dma| < 7'
df.loc[abs(df['50dma'] - df['200dma']) < 5, 'label'] = '|50dma - 200dma| < 5'
df.loc[abs(df['50dma'] - df['200dma']) < 3, 'label'] = '|50dma - 200dma| < 3'
df = df[df['label'] != 'out of bounds']

## Convert labels to colors
label2color = {
    '|50dma - 200dma| < 3': 'green',
    '|50dma - 200dma| < 5': 'yellow',
    '|50dma - 200dma| < 7': 'orange',
    '|50dma - 200dma| >= 7': 'red',
}
df['color'] = df['label'].apply(lambda label: label2color[label])

# Create plot
fig, ax = plt.subplots()

def gen_repeating(s):
    """Generator: groups repeated elements in an iterable
    E.g.
        'abbccc' -> [('a', 0, 0), ('b', 1, 2), ('c', 3, 5)]
    """
    i = 0
    while i < len(s):
        j = i
        while j < len(s) and s[j] == s[i]:
            j += 1
        yield (s[i], i, j-1)
        i = j

## Add px_last lines
# pass a plain array so that s[i] inside gen_repeating is positional
for color, start, end in gen_repeating(df['color'].to_numpy()):
    if start > 0: # make sure lines connect
        start -= 1
    idx = df.index[start:end+1]
    df.loc[idx, 'px_last'].plot(ax=ax, color=color, label='')

## Add 50dma and 200dma lines
df['50dma'].plot(ax=ax, color='k', ls='--', label='MA$_{50}$')
df['200dma'].plot(ax=ax, color='b', ls='--', label='MA$_{200}$')

## Get artists and labels for legend and choose which ones to display
handles, labels = ax.get_legend_handles_labels()

## Create custom artists
g_line = plt.Line2D((0,1),(0,0), color='green')
y_line = plt.Line2D((0,1),(0,0), color='yellow')
o_line = plt.Line2D((0,1),(0,0), color='orange')
r_line = plt.Line2D((0,1),(0,0), color='red')

## Create legend from custom artist/label lists
ax.legend(
    handles + [g_line, y_line, o_line, r_line],
    labels + [
        '|MA$_{50} - $MA$_{200}| < 3$',
        '|MA$_{50} - $MA$_{200}| < 5$',
        '|MA$_{50} - $MA$_{200}| < 7$',
        r'|MA$_{50} - $MA$_{200}| \geq 7$',
    ],
    loc='best',
)

# Display plot
plt.show()

I've also added a fancy-ish legend.
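The run detection done by gen_repeating can also be expressed with itertools.groupby from the standard library. The sketch below is an alternative, not the code from the answer: the helper name runs and the sample color list are hypothetical. It also shows the effect of stepping each start back by one row, which is what makes neighboring segments share an endpoint.

```python
from itertools import groupby


def runs(values):
    """Yield (value, start, end) for each run of equal neighbors; end is inclusive."""
    start = 0
    for value, chunk in groupby(values):
        length = sum(1 for _ in chunk)
        yield value, start, start + length - 1
        start += length


print(list(runs('abbccc')))  # [('a', 0, 0), ('b', 1, 2), ('c', 3, 5)]

# Hypothetical per-row colors, standing in for df['color']
colors = ['green', 'green', 'red', 'red', 'red', 'green']

segments = []
for color, start, end in runs(colors):
    # begin at the previous run's last row so the drawn lines connect
    first = max(start - 1, 0)
    segments.append((color, first, end))

print(segments)  # [('green', 0, 1), ('red', 1, 4), ('green', 4, 5)]
```

The (first, end) pairs are positional, so they would be used with df.index[first:end + 1] exactly as in the plotting loop above.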
