Source code for spikelib.models

"""Set of linear and nonlinear neuron models.

For example, the Spike-Triggered Average (STA) is an algorithm that computes a
linear filter to estimate the receptive field of a cell.
"""
from functools import partial
from multiprocessing import cpu_count
from multiprocessing import Pool
from multiprocessing import freeze_support

import numpy as np

from spikelib.io import load_stim_multi


# Shared stimulus state for worker processes: ``init_multi_sta`` stores the
# stimulus buffer and its shape here, and ``multi_sta`` reads them back.
GLOBAL_STIM = dict()


def ste(stim, stim_time, spikes, nsamples_before=30, nsamples_after=0):
    """Get all windows of stimuli triggered by a spike.

    This function creates a generator that yields, for every stimulus frame
    containing at least one spike, the stimulus window around that frame
    weighted by the number of spikes in it.

    Parameters
    ----------
    stim : array_like
        A spatiotemporal or temporal stimulus array, where time is the first
        dimension. It must support ``.shape`` and slicing along axis 0.
    stim_time : array_like
        The time array corresponding to the start of each frame in the
        stimulus. It needs one entry per frame of ``stim`` and at least two
        entries, because the end of the last frame is extrapolated from the
        last two.
    spikes : array_like
        A list or ndarray of spike times.
    nsamples_before : int
        Number of samples to include in the STE before the spike, defaults: 30.
        The frame that contains the spike counts as one of these samples.
    nsamples_after : int
        Number of samples to include in the STE after the spike, defaults: 0.

    Yields
    ------
    ndarray
        ``nspikes * stim[kframe + 1 - nsamples_before:kframe + 1 + nsamples_after]``
        cast to float64. One window is yielded for every frame ``kframe``
        holding ``nspikes > 0`` spikes whose window lies entirely inside
        ``stim``.

    Raises
    ------
    ValueError
        If ``stim_time`` does not have one entry per frame of ``stim``, or has
        fewer than two entries. Because this is a generator, the error is
        raised on the first iteration.

    Notes
    -----
    The spike-triggered ensemble (STE) is the set of all stimuli immediately
    surrounding a spike. If the full stimulus distribution is p(s), the STE
    is p(s | spike).
    """
    stim_time = np.asarray(stim_time)
    len_stim = stim.shape[0]
    # The original check used ``stim_time.size[0]``, which always raised
    # TypeError (``size`` is an int), and relied on ``assert``, which -O strips.
    if stim_time.size != len_stim:
        raise ValueError('stim_time.size must be equal to stim.shape[0]')
    if stim_time.size < 2:
        raise ValueError('stim_time must contain at least two frame times')

    nbefore, nafter = nsamples_before, nsamples_after
    # Extrapolate the end of the last frame so every frame gets its own bin.
    bins_stim = np.append(stim_time, [stim_time[-1] * 2 - stim_time[-2]])
    # Number of spikes in each frame of the stimulus.
    nspks_in_frames, _ = np.histogram(spikes, bins=bins_stim)
    spiking_frames = np.flatnonzero(nspks_in_frames)

    # Each frame is its own reference: the window is
    # [kfr + 1 - nbefore, kfr + 1 + nafter). Keep the frame only if both ends
    # fit inside the stimulus (the old lower bound dropped kfr == nbefore - 1).
    fits = ((spiking_frames + 1 >= nbefore)
            & (spiking_frames + 1 + nafter <= len_stim))

    for kfr in spiking_frames[fits]:
        start = kfr + 1 - nbefore
        stop = kfr + 1 + nafter
        # Slice along time only, so 1-D (temporal) stimuli work as well as
        # spatiotemporal ones.
        yield nspks_in_frames[kfr] * stim[start:stop].astype('float64')
def sta(stim, stim_time, spikes, nsamples_before=30, nsamples_after=0):
    """Compute a spike-triggered average.

    Parameters
    ----------
    stim : array_like
        A spatiotemporal or temporal stimulus array, where time is the first
        dimension. It must expose ``.shape``.
    stim_time : array_like
        The time array corresponding to the start of each frame in the
        stimulus.
    spikes : array_like
        A list or array of spike times.
    nsamples_before : int
        Number of samples to include in the STA before the spike, defaults: 30.
    nsamples_after : int
        Number of samples to include in the STA after the spike, defaults: 0.

    Returns
    -------
    sta : ndarray
        The spike-triggered average, with shape
        ``(nsamples_before + nsamples_after,) + stim.shape[1:]``. It is the sum
        of the spike-weighted windows yielded by ``ste``, divided by the total
        number of entries in ``spikes``. That total includes spikes that fall
        outside the stimulus or too close to its edges to contribute a window.
        If the accumulated sum is all zeros, no division is performed.

    References
    ----------
    A simple white noise analysis of neuronal light responses.
    E J Chichilnisky
    """
    # Trailing dimensions follow the stimulus, so temporal (1-D) stimuli
    # produce a 1-D STA instead of failing on a 3-way unpack.
    sta_array = np.zeros((nsamples_before + nsamples_after,) + stim.shape[1:])
    for window in ste(stim, stim_time, spikes, nsamples_before, nsamples_after):
        sta_array += window
    # np.size also works for plain lists, which have no ``.size`` attribute.
    nspikes = np.size(spikes)
    if sta_array.any():
        sta_array /= float(nspikes)
    return sta_array
def multi_sta(spiketimes, stim_time, nsamples_before=30, nsamples_after=0):
    """Compute the Spike Triggered Average for a cell.

    The stimulus is not passed in. It is rebuilt from
    ``GLOBAL_STIM['stim']`` with ``np.frombuffer`` (default dtype float64)
    and reshaped to ``GLOBAL_STIM['stim_shape']``. Both entries are set by
    ``init_multi_sta``.

    Parameters
    ----------
    spiketimes : tuple
        ``(name, spiketimes)`` of the unit to compute the STA for.
    stim_time : array_like
        Start time of each frame of the stimulus. It needs at least two
        entries, because the end of the last frame is extrapolated from the
        last two.
    nsamples_before : int
        Number of samples to include in the STA before the spike. The frame
        that contains the spike counts as one of them.
    nsamples_after : int
        Number of samples to include in the STA after the spike (default: 0).

    Returns
    -------
    unit_name : str
        Name of the unit.
    sta_array : ndarray
        STA with shape ``(nsamples_before + nsamples_after,) + stim_shape[1:]``.
        It is normalized by the number of spikes that actually contributed a
        window, and is all zeros if none did.
    """
    unit_name, spk_time = spiketimes
    stim_matrix = np.frombuffer(
        GLOBAL_STIM['stim']).reshape(GLOBAL_STIM['stim_shape'])
    nframe_stim = stim_matrix.shape[0]

    # Extrapolate the end of the last frame so every frame gets its own bin.
    bins_stim = np.append(stim_time, [stim_time[-1] * 2 - stim_time[-2]])
    nspikes_in_frame, _ = np.histogram(spk_time, bins=bins_stim)
    valid_frames = np.flatnonzero(nspikes_in_frame)
    # Window for frame k is [k + 1 - nsamples_before, k + 1 + nsamples_after);
    # keep it only if both ends fit (the old lower bound was off by one).
    fits = ((valid_frames + 1 >= nsamples_before)
            & (valid_frames + 1 + nsamples_after <= nframe_stim))
    valid_frames = valid_frames[fits]
    spike_in_frames = nspikes_in_frame[valid_frames]

    nframes_sta = nsamples_before + nsamples_after
    sta_array = np.zeros((nframes_sta,) + stim_matrix.shape[1:],
                         dtype=np.float64)
    # Valid frames consider themselves as reference.
    for kframe, nspikes in zip(valid_frames, spike_in_frames):
        start_frame = kframe - nsamples_before + 1
        end_frame = kframe + nsamples_after + 1
        sta_array += nspikes * stim_matrix[start_frame:end_frame]

    nspikes_used = spike_in_frames.sum()
    if nspikes_used:
        sta_array /= nspikes_used
    return (unit_name, sta_array)
def init_multi_sta(stim, stim_shape):
    """Store the stimulus buffer and its shape in ``GLOBAL_STIM``.

    ``run_multi_sta`` uses this as the ``Pool`` initializer, so each worker
    process can rebuild the stimulus inside ``multi_sta``.
    """
    GLOBAL_STIM.update(stim=stim, stim_shape=stim_shape)
def run_multi_sta(stim_path, stim_time, spiketimes, nsamples_before=30,
                  nsamples_after=0, normed_stim=True, channel_stim='g'):
    """Run sta in multiprocessing.

    Parameters
    ----------
    stim_path : str
        File of the stimulus. It is loaded with ``load_stim_multi`` from the
        ``'checkerboard'`` dataset.
    stim_time : array_like
        Start time of each frame of the stimulus.
    spiketimes : dict or iterable of tuple
        Either a mapping ``{unit_name: spike_times}`` or an iterable of
        ``(unit_name, spike_times)`` pairs.
    nsamples_before : int
        Number of samples before the spike, forwarded to ``multi_sta``.
    nsamples_after : int
        Number of samples after the spike, forwarded to ``multi_sta``.
    normed_stim : bool
        Forwarded to ``load_stim_multi`` as ``normed``.
    channel_stim : str
        Forwarded to ``load_stim_multi`` as ``channel``.

    Returns
    -------
    result : list of tuple
        ``(unit_name, sta_array)`` for each unit, in input order.
    """
    freeze_support()
    stim, stim_shape = load_stim_multi(stim_path,
                                       normed=normed_stim,
                                       channel=channel_stim,
                                       dataset='checkerboard',
                                       )
    print(stim_shape)
    # Iterating a dict yields only its keys, but multi_sta unpacks each item
    # as a (name, times) pair, so convert mappings to their items.
    if hasattr(spiketimes, 'items'):
        spiketimes = list(spiketimes.items())
    wrap_sta = partial(multi_sta,
                       stim_time=stim_time,
                       nsamples_before=nsamples_before,
                       nsamples_after=nsamples_after,
                       )
    # The context manager terminates the pool on exit. map() has already
    # returned every result by then, so no worker processes are leaked.
    with Pool(processes=cpu_count(),
              initializer=init_multi_sta,
              initargs=(stim, stim_shape),
              ) as pool:
        result = pool.map(wrap_sta, spiketimes)
    return result