# Source code for cnmodel.an_model.cache

"""
Utilities for generating and caching spike trains from AN model.
"""

import logging
import os, sys, pickle
import numpy as np
from .wrapper import get_matlab, model_ihc, model_synapse, seed_rng
from ..util.filelock import FileLock

# The Rudnicki `cochlea` package is an optional periphery simulator; when it
# is missing, generate_spiketrain() cannot use the 'cochlea' path.
try:
    import cochlea
except ImportError:
    HAVE_COCHLEA = False
else:
    HAVE_COCHLEA = True

# Cache files live in a `cache` directory next to this module.
# NOTE(review): _cache_version, _index_file and _index are not referenced in
# this section; presumably used elsewhere in the module — confirm.
_cache_version = 2
_cache_path = os.path.join(os.path.dirname(__file__), 'cache')
_index_file = os.path.join(_cache_path, 'index.pk')
_index = None


def get_spiketrain(cf, sr, stim, seed, **kwds):
    """
    Return an array of spike times in response to the given stimulus.

    Arrays are automatically cached and may be returned from disk if
    available. See generate_spiketrain() for a description of arguments.

    If the flag --ignore-an-cache was given on the command line, then spike
    times will be regenerated and cached, regardless of the current cache
    state.

    If the flag --no-an-cache was given on the command line, then the cache
    will not be read or written. This can improve overall performance if
    there is little chance the cache would be re-used.

    A cache file that exists but cannot be read is logged (with traceback)
    and regenerated.
    """
    filename = get_cache_filename(cf=cf, sr=sr, seed=seed, stim=stim, **kwds)
    subdir = os.path.dirname(filename)
    if not os.path.isdir(subdir):
        try:
            # makedirs rather than mkdir so that the top-level cache
            # directory is also created on first use.
            os.makedirs(subdir)
        except OSError:
            # Probably another process created the directory since we last
            # checked; only complain if it really is still missing.
            if not os.path.isdir(subdir):
                logging.warning("Could not create AN cache directory: %s",
                                subdir)

    no_cache = '--no-an-cache' in sys.argv
    force_regen = '--ignore-an-cache' in sys.argv

    with FileLock(filename):
        create = no_cache or force_regen or not os.path.exists(filename)
        if not create:
            # try loading cached data
            try:
                # NpzFile as a context manager closes the underlying file
                # handle once the array has been read into memory.
                with np.load(filename) as npz:
                    data = npz['data']
            except Exception:
                # Unreadable/corrupt cache entry: log and regenerate below.
                logging.exception("Error reading AN spike train cache file; "
                                  "will re-generate. File: %s", filename)
                create = True
            else:
                logging.info("Loaded AN spike train from cache: %s", filename)

        if create:
            logging.info("Generate new AN spike train: %s", filename)
            data = generate_spiketrain(cf, sr, stim, seed, **kwds)
            if not no_cache:
                np.savez_compressed(filename, data=data)
    return data
def make_key(**kwds):
    """
    Build a deterministic string key used to name cached spike-time arrays.

    Dict-valued arguments are flattened one level into ``name.subname``
    entries; all entries are then sorted and rendered as ``name=value``
    pairs joined by underscores.
    """
    # flatten any nested dicts (iterate over a snapshot of the names because
    # the dict is modified inside the loop)
    for name in list(kwds):
        value = kwds[name]
        if not isinstance(value, dict):
            continue
        del kwds[name]
        for sub, subval in value.items():
            kwds[name + '.' + sub] = subval

    # sort and convert to string
    pairs = sorted(kwds.items())
    return '_'.join('%s=%s' % pair for pair in pairs)
def get_cache_filename(cf, sr, seed, stim, **kwds):
    """
    Return the cache file path for one spike train.

    The path is ``<_cache_path>/<stimulus key>/<fiber/seed key>.npz``, where
    the stimulus key is built from ``stim.key()`` and the file name from the
    fiber parameters, seed, and any extra keyword arguments (see make_key()).
    """
    stim_dir = os.path.join(_cache_path, make_key(**stim.key()))
    base_name = make_key(cf=cf, sr=sr, seed=seed, **kwds)
    return os.path.join(stim_dir, base_name) + '.npz'
def generate_spiketrain(cf, sr, stim, seed, simulator=None, **kwds):
    """
    Generate a new spike train from the auditory nerve model. Returns an
    array of spike times in seconds.

    Parameters
    ----------
    cf : float
        Center frequency of the fiber to simulate
    sr : int
        Spontaneous rate group of the fiber: 1=low, 2=mid, 3=high. The value
        is passed unchanged as ``fiberType`` to the MATLAB synapse model and
        mapped onto the (HSR, MSR, LSR) fiber counts for cochlea.
    stim : Sound instance
        Stimulus sound to be presented on each repetition
    seed : int >= 0
        Random seed
    simulator : 'cochlea' | 'matlab' | None
        Specifies the auditory periphery simulator to use. If None, then a
        simulator will be automatically chosen based on availability.
    **kwds :
        All other keyword arguments are given to model_ihc() and
        model_synapse() based on their names. These include 'species',
        'nrep', 'reptime', 'cohc', 'cihc', and 'implnt'. Names accepted by
        both functions (e.g. 'nrep') are passed to both. These arguments are
        only used by the 'matlab' simulator; the 'cochlea' path ignores them.

    Raises
    ------
    TypeError
        If a reserved argument ('pin', 'CF', 'fiberType', 'noiseType') or an
        unrecognized keyword argument is given.
    ValueError
        If the simulator is unknown or not available, or if the 'cochlea'
        simulator is used with sr not in (1, 2, 3).
    """
    for k in ['pin', 'CF', 'fiberType', 'noiseType']:
        if k in kwds:
            raise TypeError("Argument '%s' is not allowed here." % k)

    ihc_kwds = dict(pin=stim.sound, CF=cf, nrep=1, tdres=stim.dt,
                    reptime=stim.duration*2, cohc=1, cihc=1, species=1)
    syn_kwds = dict(CF=cf, nrep=1, tdres=stim.dt, fiberType=sr,
                    noiseType=1, implnt=0)

    # Copy each keyword argument to every model function that accepts it.
    # Iterate over a snapshot of the keys (popping from the dict being
    # iterated raises RuntimeError) and pop each key exactly once, since
    # names such as 'nrep' and 'tdres' belong to both functions.
    for kwd in list(kwds):
        if kwd in ihc_kwds or kwd in syn_kwds:
            val = kwds.pop(kwd)
            if kwd in ihc_kwds:
                ihc_kwds[kwd] = val
            if kwd in syn_kwds:
                syn_kwds[kwd] = val

    # Reject leftovers before detect_simulator(), which may start MATLAB.
    if len(kwds) > 0:
        raise TypeError("Invalid keyword arguments: %s" % list(kwds.keys()))

    if simulator is None:
        simulator = detect_simulator()

    if simulator == 'matlab':
        seed_rng(seed)
        vihc = model_ihc(_transfer=False, **ihc_kwds)
        _, _, psth = model_synapse(vihc, _transfer=False, **syn_kwds)
        psth = psth.get().ravel()
        # Indices of non-zero PSTH bins, converted to seconds.
        times = np.argwhere(psth).ravel()
        return times * stim.dt
    elif simulator == 'cochlea' and HAVE_COCHLEA:
        fs = int(0.5 + 1. / stim.dt)  # round to nearest int to avoid roundoff error
        if sr not in (1, 2, 3):
            raise ValueError("anmodel/cache.py: sr must be 1 (low), 2 (mid) "
                             "or 3 (high); found %r" % (sr,))
        # cochlea expects fiber counts ordered (HSR, MSR, LSR) while sr is
        # 1=low, 2=mid, 3=high: high -> index 0, mid -> 1, low -> 2.
        srgrp = [0, 0, 0]
        srgrp[3 - sr] = 1
        sp = cochlea.run_zilany2014(
            stim.sound,
            fs=fs,
            anf_num=srgrp,
            cf=cf,
            seed=seed,
            species='cat')
        return np.array(sp.spikes.values[0])
    elif simulator == 'cochlea':
        raise ValueError("anmodel/cache.py: simulator 'cochlea' was requested "
                         "but the cochlea package could not be imported")
    else:
        # it remains possible to have a typo....
        raise ValueError("anmodel/cache.py: Simulator must be specified as either MATLAB or cochlea; found %s" % simulator)
def detect_simulator():
    """Return the name of an available auditory periphery model.

    Returns 'cochlea' when the Rudnicki cochlea package can be imported.
    Otherwise calls get_matlab() and returns 'matlab'; get_matlab() is
    expected to raise if the Zilany model cannot be reached through MATLAB.
    """
    try:
        import cochlea  # noqa: F401 -- imported only to probe availability
    except ImportError:
        get_matlab()
        return 'matlab'
    return 'cochlea'