stMind

about Tech, Computer vision and Machine learning

mpld3の新しいサンプル

mpld3で新しいサンプルが紹介されていたので、実行用のmain部分を付け足して動かしてみた。
散布図の点にマウスオーバーすると、その点に対応する線グラフに切り替わるようになっている。

from mpld3.plugins import PluginBase
import jinja2
import json
from mpld3 import show_d3
import matplotlib.pyplot as plt
import numpy as np


class LinkedView(PluginBase):
    """Plugin linking a scatter collection in one axes to a line in another.

    The JavaScript in ``FIG_JS`` attaches a ``mouseover`` handler to the
    elements selected by ``.paths<collid>``.  When the i-th element is
    hovered, the line's data is replaced by ``linedata[i]``, its path is
    transitioned to the new data, and its stroke is set to the hovered
    element's fill colour.

    Parameters
    ----------
    points :
        The plotted collection whose elements trigger the update.
    line :
        The plotted line that is redrawn on hover.
    linedata :
        JSON-serialisable sequence indexed by element; entry ``i`` is the
        data handed to the line when element ``i`` is hovered.
    """

    # NOTE: this template is emitted verbatim into the generated page.
    FIG_JS = jinja2.Template("""
    var linedata{{ id }} = {{ linedata }};

    ax{{ axid }}.axes.selectAll(".paths{{ collid }}")
    .on("mouseover", function(d, i){
    line{{ elid }}.data = linedata{{ id }}[i];
    line{{ elid }}.lineobj.transition()
    .attr("d", line{{ elid }}.line(line{{ elid }}.data))
    .style("stroke", this.style.fill);})
    """)

    def __init__(self, points, line, linedata):
        """Store the linked artists and their data, then take a unique id."""
        self.points, self.line, self.linedata = points, line, linedata
        # The id is used to give the JS ``linedata`` variable a unique name.
        self.id = self.generate_unique_id()

    def _fig_js_args(self):
        """Return the values substituted into the ``FIG_JS`` template.

        Presumably called by the ``PluginBase`` machinery when rendering
        the figure -- confirm against the mpld3 version in use.
        """
        scatter_obj = self._get_d3obj(self.points)
        line_obj = self._get_d3obj(self.line)
        return {
            "id": self.id,
            "axid": scatter_obj.axid,
            "collid": scatter_obj.collid,
            "elid": line_obj.elid,
            # ``lineaxid`` and ``lineid`` are supplied although FIG_JS
            # does not currently reference them.
            "lineaxid": line_obj.axid,
            "lineid": line_obj.lineid,
            "linedata": json.dumps(self.linedata),
        }


if __name__ == '__main__':
    # Two stacked axes: the upper one holds the line, the lower one the
    # scatter plot whose points drive the line on hover.
    figure, (line_ax, scatter_ax) = plt.subplots(2)

    # Random periods and amplitudes (fixed seed for reproducibility).
    np.random.seed(0)
    periods = np.random.random(size=10)
    amplitudes = np.random.random(size=10)
    xs = np.linspace(0, 10, 100)

    # One [x, y] pair of rows per scatter point, with y = A * sin(x / P).
    curves = []
    for amplitude, period in zip(amplitudes, periods):
        curves.append([xs, amplitude * np.sin(xs / period)])
    curves = np.array(curves)

    scatter_pts = scatter_ax.scatter(periods, amplitudes,
                                     c=periods + amplitudes,
                                     s=200, alpha=0.5)
    scatter_ax.set_xlabel('Period')
    scatter_ax.set_ylabel('Amplitude')

    # Placeholder line, drawn flat in white; the plugin's hover handler
    # swaps in the real data and sets the stroke colour.
    (linked_line,) = line_ax.plot(xs, 0 * xs, '-w', lw=3, alpha=0.5)
    line_ax.set_ylim(-1, 1)

    # Swap the last two axes so each curve becomes a list of [x, y]
    # pairs, then register the plugin on the figure.
    per_point_xy = np.transpose(curves, (0, 2, 1)).tolist()
    figure.plugins = [LinkedView(scatter_pts, linked_line, per_point_xy)]

    # Render the figure in a browser.
    show_d3(figure)