Source code for fsl.scripts.atlasq

#!/usr/bin/env python
#
# main.py - The atlasq program.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module contains the FSL ``atlasq`` program, the successor to
``atlasquery``.
"""


from __future__ import print_function

import itertools as it
import              sys
import              argparse
import              textwrap
import              warnings
import              logging
import numpy     as np

import fsl.data.image   as fslimage
import fsl.data.atlases as fslatlases
import fsl.version      as fslversion


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# Delimiter placed between the fields of every line
# printed by queryShortOutput (the "short" query format).
SHORT_QUERY_DELIM = '\t'


class IdentifyError(Exception):
    """Error raised by ``identifyAtlas`` when the requested ID/name
    matches no atlas, or matches more than one atlas.
    """
class HelpFormatter(argparse.RawDescriptionHelpFormatter):
    """An ``argparse.RawDescriptionHelpFormatter`` sub-class which works
    around a few irritating ``argparse`` defaults.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        """Formats the usage text with an empty prefix. The ``prefix``
        which is passed in is ignored, so the default ``'usage: '`` text
        is never prepended.
        """
        parent = super(HelpFormatter, self)
        return parent._format_usage(usage, actions, groups, '')
def listAtlases(namespace):
    """Prints a listing of every available atlas.

    :arg namespace: ``argparse.Namespace`` object. If its ``extended``
                    attribute is true, a multi-line description is printed
                    for each atlas; otherwise a two-column table of atlas
                    IDs and names is printed.
    """

    available = fslatlases.listAtlases()

    # Brief listing - just a table of IDs and names
    if not namespace.extended:
        atlasIDs   = [desc.atlasID for desc in available]
        atlasNames = [desc.name    for desc in available]
        printColumns((atlasIDs, atlasNames), ('ID', 'Full name'))
        return

    # Extended listing - one paragraph per
    # atlas, followed by a blank line
    for desc in available:
        print('{} [{}]'.format(desc.name, desc.atlasID))
        print(' Spec: {}'.format(desc.specPath))
        print(' Type: {}'.format(desc.atlasType))
        print(' Labels: {}'.format(len(desc.labels)))

        for image in desc.images:
            print(' Image: {}'.format(image))
        for image in desc.summaryImages:
            print(' Summary image: {}'.format(image))
        print()
def summariseAtlas(namespace):
    """Prints a description of a single atlas, followed by a table of
    all of its labels.

    :arg namespace: ``argparse.Namespace`` object; its ``atlas`` attribute
                    is passed to ``identifyAtlas`` to find the atlas.
    """

    desc = identifyAtlas(namespace.atlas)

    print('{} [{}]'.format(desc.name, desc.atlasID))
    print(' Spec: {}'.format(desc.specPath))
    print(' Type: {}'.format(desc.atlasType))
    print(' Labels: {}'.format(len(desc.labels)))

    for image in desc.images:
        print(' Image: {}'.format(image))
    for image in desc.summaryImages:
        print(' Summary image: {}'.format(image))

    # One table column per label attribute
    attrs   = ('index', 'name', 'x', 'y', 'z')
    columns = [[getattr(lbl, attr) for lbl in desc.labels] for attr in attrs]

    printColumns(columns, ('Index', 'Label', 'X', 'Y', 'Z'))
def queryAtlas(namespace):
    """Runs all of the mask, world coordinate and voxel coordinate queries
    requested on the command line against one atlas, then prints the
    results in the order in which the queries were specified.

    :arg namespace: ``argparse.Namespace`` object, as returned by
                    ``parseArgs`` for the ``query`` command.
    """

    desc        = identifyAtlas(namespace.atlas)
    worldCoords = namespace.coord
    voxelCoords = namespace.voxel
    maskImages  = [fslimage.Image(m) for m in namespace.mask]

    # Position of every query on the command line,
    # grouped in the same mask/world/voxel sequence
    # as the results gathered below
    order = list(it.chain(namespace.mask_order,
                          namespace.coord_order,
                          namespace.voxel_order))

    atlas = fslatlases.loadAtlas(desc.atlasID,
                                 loadSummary=namespace.label,
                                 resolution=namespace.resolution)

    maskLabels,  maskProps  = maskQuery( atlas, maskImages)
    worldLabels, worldProps = coordQuery(atlas, worldCoords, False)
    voxelLabels, voxelProps = coordQuery(atlas, voxelCoords, True)

    def inQueryOrder(*groups):
        """Concatenates the mask/world/voxel groups, and returns the
        items sorted into command line order.
        """
        items = list(it.chain(*groups))
        return [item for (_, item) in sorted(zip(order, items))]

    labels  = inQueryOrder(maskLabels, worldLabels, voxelLabels)
    props   = inQueryOrder(maskProps,  worldProps,  voxelProps)
    sources = inQueryOrder(maskImages, worldCoords, voxelCoords)
    types   = inQueryOrder(['mask']       * len(maskImages),
                           ['coordinate'] * len(worldCoords),
                           ['voxel']      * len(voxelCoords))

    if namespace.short:
        output = queryShortOutput
    else:
        output = queryLongOutput

    output(atlas, sources, types, labels, props)
def queryShortOutput(atlas, sources, types, allLabels, allProps):
    """Prints query results in the short format - one line per query,
    with fields separated by ``SHORT_QUERY_DELIM``. Called by
    ``queryAtlas``.

    :arg atlas:     The atlas which was queried.
    :arg sources:   Query sources - a coordinate triplet or a mask image
                    for each query.
    :arg types:     ``'coordinate'``, ``'voxel'`` or ``'mask'`` for each
                    query.
    :arg allLabels: Sequence of labels for each query.
    :arg allProps:  Sequence of proportions for each query.
    """

    coordFormats = {
        'coordinate' : '{:0.2f} {:0.2f} {:0.2f}',
        'voxel'      : '{:0.0f} {:0.0f} {:0.0f}',
    }

    queries = zip(sources, types, allLabels, allProps)

    for source, stype, labels, props in queries:

        isCoord = stype in coordFormats

        if isCoord:
            source = coordFormats[stype].format(*source)
        elif stype == 'mask':
            source = source.dataSource

        names = labelNames(atlas, labels)

        # A coordinate lookup on a label
        # atlas yields exactly one label
        if isCoord and isinstance(atlas, fslatlases.LabelAtlas):
            results = ['{}'.format(names[0])]

        # Everything else yields labels and
        # proportions, which are printed
        # from highest to lowest proportion
        else:
            ranked = sorted(zip(props, names))
            ranked.reverse()
            results = ['{} {:0.4f}'.format(name, prop)
                       for prop, name in ranked]

        print(SHORT_QUERY_DELIM.join([stype, source] + results))
def queryLongOutput(atlas, sources, types, allLabels, allProps):
    """Prints query results in the long format - a boxed title for each
    query, followed by a table of results. Called by ``queryAtlas``.

    :arg atlas:     The atlas which was queried.
    :arg sources:   Query sources - a coordinate triplet or a mask image
                    for each query.
    :arg types:     ``'coordinate'``, ``'voxel'`` or ``'mask'`` for each
                    query.
    :arg allLabels: Sequence of labels for each query.
    :arg allProps:  Sequence of proportions for each query.
    """

    isProbabilistic = atlas.desc.atlasType == 'probabilistic'

    def printSingleLabel(labels, names):
        """Prints the one label returned by a coordinate lookup on a
        label atlas.
        """
        index = labels[0]

        # No label at this location
        if index is None:
            index = np.nan
        else:
            index = int(index)

        if isProbabilistic:
            fields = ['name', 'index', 'summary value']
            values = [names[0], index, index + 1]
        else:
            fields = ['name', 'label']
            values = [names[0], index]

        printColumns((fields, values))

    def printProportions(labels, props, names):
        """Prints a table of labels and proportions, from the highest
        proportion to the lowest.
        """
        if len(labels) == 0:
            print('No results')
            return

        ranked = sorted(zip(props, labels, names))
        ranked.reverse()
        props, labels, names = zip(*ranked)

        props = ['{:0.4f}'.format(p) for p in props]

        if isProbabilistic:
            titles  = ['name', 'index', 'summary value', 'proportion']
            columns = [names, labels, [l + 1 for l in labels], props]
        else:
            titles  = ['name', 'label', 'proportion']
            columns = [names, labels, props]

        printColumns(columns, titles)

    queries = zip(sources, types, allLabels, allProps)

    for source, stype, labels, props in queries:

        if stype == 'coordinate':
            sourcestr = '{:0.2f} {:0.2f} {:0.2f}'.format(*source)
        elif stype == 'voxel':
            sourcestr = '{:0.0f} {:0.0f} {:0.0f}'.format(*source)
        elif stype == 'mask':
            sourcestr = '{}'.format(source.dataSource)

        # Boxed title
        title  = '{} {}'.format(stype, sourcestr)
        border = '-' * (4 + len(title))

        print(border)
        print('| {} |'.format(title))
        print(border)
        print()

        names       = labelNames(atlas, labels)
        coordLookup = stype in ('coordinate', 'voxel')

        if coordLookup and isinstance(atlas, fslatlases.LabelAtlas):
            printSingleLabel(labels, names)
        else:
            printProportions(labels, props, names)

        print()
def ohi(namespace):
    """Emulates the FSL ``atlasquery`` tool.

    Depending on the contents of ``namespace``, this function either
    prints the names of all available atlases, or queries one atlas with
    a mask image or with a single world coordinate, and prints the result
    in the ``atlasquery`` output format.

    :arg namespace: ``argparse.Namespace`` object with ``dumpatlases``,
                    ``atlas``, ``ohiMask`` and ``coord`` attributes.
    """

    def dumpatlases():
        """Prints the name of every available atlas, sorted. """
        atlases = [a.name for a in fslatlases.listAtlases()]
        print('\n'.join(sorted(atlases)))

    if namespace.dumpatlases:
        dumpatlases()
        return

    # The atlas must be referred to by its exact name
    atlasDesc = None
    for a in fslatlases.listAtlases():
        if a.name == namespace.atlas:
            atlasDesc = a
            break

    if atlasDesc is None:
        print('Invalid atlas name. Try one of:')
        dumpatlases()
        return

    # Neither a mask nor a coordinate was given - there is
    # nothing to query. Without this check, the coordinate
    # branch below would fail on namespace.coord being None.
    if namespace.ohiMask is None and namespace.coord is None:
        print('No mask or coordinate specified')
        return

    # atlasquery always uses 2mm atlas
    # versions when a 2mm is available
    reses = [p[0] for p in atlasDesc.pixdims]
    if 2 in reses:
        res = 2
    else:
        res = max(reses)

    # Mask query.
    if namespace.ohiMask is not None:

        mask          = fslimage.Image(namespace.ohiMask)
        labels, props = maskQuery(atlasDesc, [mask], resolution=res)
        labels        = labels[0]
        props         = props[0]

        for lbl, prop in zip(labels, props):
            if atlasDesc.atlasType == 'probabilistic':
                lbl = atlasDesc.labels[int(lbl)].name
            elif atlasDesc.atlasType == 'label':
                lbl = atlasDesc.find(value=int(lbl)).name
            print('{}:{:0.4f}'.format(lbl, prop))

    # Coordinate query
    else:

        # parseArgs may have wrapped the coordinates
        # in double quotes - remove them
        coord = namespace.coord.strip('"')
        coord = [float(c) for c in coord.split(',')]

        labels, props = coordQuery(atlasDesc, [coord], False, resolution=res)
        labels        = labels[0]
        props         = props[0]

        if atlasDesc.atlasType == 'label':

            labels = labels[0]

            if labels is None:
                label = 'Unclassified'
            else:
                label = atlasDesc.find(value=int(labels)).name
            print('<b>{}</b><br>{}'.format(atlasDesc.name, label))

        elif atlasDesc.atlasType == 'probabilistic':

            labelStrs = []

            # Order from highest to lowest proportion
            if len(labels) > 0:
                props, labels = zip(*reversed(sorted(zip(props, labels))))

            for label, prop in zip(labels, props):
                label = atlasDesc.labels[int(label)].name
                labelStrs.append('{:d}% {}'.format(int(round(prop)), label))

            if len(labelStrs) == 0:
                labels = 'No label found!'
            else:
                labels = ', '.join(labelStrs)

            print('<b>{}</b><br>{}'.format(atlasDesc.name, labels))
def atlasOrDesc(aord, *args, **kwargs):
    """Makes sure that ``aord`` is an ``Atlas``.

    An ``Atlas`` instance is returned unchanged. Anything else is taken to
    be an ``AtlasDescription`` - the atlas it describes is loaded via
    ``fslatlases.loadAtlas`` (all other arguments are passed through), and
    returned.
    """
    if not isinstance(aord, fslatlases.Atlas):
        aord = fslatlases.loadAtlas(aord.atlasID, *args, **kwargs)
    return aord
def labelNames(atlas, labels):
    """Returns a list containing the region name for each of the given
    ``labels``.

    Each label is used as an index into ``atlas.desc.labels``. A label
    of ``None`` is given the name ``'No label'``.
    """
    return ['No label' if idx is None
            else atlas.desc.labels[int(idx)].name
            for idx in labels]
def maskQuery(atlas, masks, *args, **kwargs):
    """Queries the ``atlas`` with each of the given ``masks``.

    :arg atlas: An ``Atlas``, or an ``AtlasDescription`` - see
                ``atlasOrDesc``, to which all other arguments are passed.
    :arg masks: Sequence of mask images.

    :returns:   A tuple containing a list of label sequences, and a list
                of proportion sequences, with one entry per mask.
    """

    atlas     = atlasOrDesc(atlas, *args, **kwargs)
    allLabels = []
    allProps  = []

    for mask in masks:

        if isinstance(atlas, fslatlases.LabelAtlas):

            labels, props = atlas.maskLabel(mask)

            # For probabilistic atlases, summary
            # image values are offset by one from
            # the label indices.
            if atlas.desc.atlasType == 'probabilistic':
                labels = [value - 1 for value in labels]

        elif isinstance(atlas, fslatlases.ProbabilisticAtlas):

            # Only keep regions which have
            # a value greater than zero
            hits   = [(atlas.desc.labels[i].index, prop)
                      for i, prop in enumerate(atlas.maskValues(mask))
                      if prop > 0]
            labels = [hit[0] for hit in hits]
            props  = [hit[1] for hit in hits]

        allLabels.append(labels)
        allProps .append(props)

    return allLabels, allProps
def coordQuery(atlas, coords, voxel, *args, **kwargs):
    """Queries the ``atlas`` at each of the given ``coords``.

    :arg atlas:  An ``Atlas``, or an ``AtlasDescription`` - see
                 ``atlasOrDesc``, to which all other arguments are passed.
    :arg coords: Sequence of coordinates.
    :arg voxel:  Passed as the ``voxel`` argument to the atlas lookup
                 methods.

    :returns:    A tuple containing a list of label sequences, and a list
                 of proportion sequences, with one entry per coordinate.
                 For a ``LabelAtlas``, each entry holds a single label
                 and a proportion of ``None``.
    """

    atlas     = atlasOrDesc(atlas, *args, **kwargs)
    allLabels = []
    allProps  = []

    for coord in coords:

        if isinstance(atlas, fslatlases.ProbabilisticAtlas):

            # Only keep regions with a non-zero value
            hits = [(atlas.desc.labels[i].index, prop)
                    for i, prop in enumerate(atlas.values(coord, voxel=voxel))
                    if prop != 0]

            allLabels.append([hit[0] for hit in hits])
            allProps .append([hit[1] for hit in hits])

        elif isinstance(atlas, fslatlases.LabelAtlas):

            label = atlas.label(coord, voxel=voxel)

            # In probabilistic summary images, 0 is
            # background, and every other value is
            # offset by one from the label index.
            if atlas.desc.atlasType == 'probabilistic':
                if label == 0:
                    label = None
                elif label is not None:
                    label = label - 1

            allLabels.append([label])
            allProps .append([None])

    return allLabels, allProps
def identifyAtlas(idOrName):
    """Given a partial atlas ID or name, tries to find an atlas which
    uniquely matches it.

    Matching is case-insensitive. An exact match against an atlas name or
    ID is preferred; otherwise ``idOrName`` is tested as a sub-string of
    every atlas name and ID.

    :arg idOrName: Full or partial atlas ID or name.

    :returns:      The matching atlas, taken from
                   ``fslatlases.listAtlases()``.

    :raises:       :exc:`IdentifyError` if no atlas matches, or if more
                   than one atlas matches.
    """

    # TODO Use difflib or some fuzzy matching library?

    idOrName = idOrName.lower().strip()
    atlases  = fslatlases.listAtlases()
    allNames = [a.name   .lower() for a in atlases]
    allIDs   = [a.atlasID.lower() for a in atlases]

    # First test for an exact match. The name and ID
    # of one atlas may both match - this still counts
    # as a single match, so atlases are counted, not
    # the number of matching names/IDs.
    nameMatches = [i for i, n   in enumerate(allNames) if idOrName == n]
    idMatches   = [i for i, aid in enumerate(allIDs)   if idOrName == aid]
    matched     = set(nameMatches) | set(idMatches)

    if len(matched) == 1:
        return atlases[matched.pop()]

    # If no exact match, test for a partial match
    nameMatches = [i for i, n   in enumerate(allNames) if idOrName in n]
    idMatches   = [i for i, aid in enumerate(allIDs)   if idOrName in aid]
    matched     = set(nameMatches) | set(idMatches)

    # No matches
    if len(matched) == 0:
        raise IdentifyError('Could not find any atlas '
                            'matching {}'.format(idOrName))

    # More than one atlas matched
    if len(matched) > 1:
        possible = [allNames[m] for m in nameMatches] + \
                   [allIDs[  m] for m in idMatches]

        raise IdentifyError('{} matched multiple atlases! Could match one '
                            'of: {}'.format(idOrName, ', '.join(possible)))

    # The name and/or ID of exactly one atlas matched
    return atlases[matched.pop()]
def printColumns(columns, titles=None, delim=' | ', sep=True, strip=False):
    """Convenience function which pretty-prints a collection of columns in
    a tabular format. Every cell is left-aligned, and padded to the width
    of the longest value (or title) in its column.

    :arg columns: A sequence of columns, where each column is a sequence
                  of values. Values are converted to strings with ``str``.
                  All columns are expected to be of the same length - the
                  number of rows is taken from the first column.

    :arg titles:  A sequence of titles (strings), one for each column.

    :arg delim:   String used to separate columns.

    :arg sep:     If ``True`` (the default), and ``titles`` are provided,
                  a row of dashes is printed underneath the titles. If
                  ``False``, the titles are printed as an ordinary first
                  row.

    :arg strip:   If ``True``, leading and trailing white space is
                  stripped from each data row (but not from the title and
                  separator rows).
    """

    if len(columns) == 0:
        return

    # Convert every value to a string, and prepend the
    # titles so that they contribute to the column widths
    columns = [list(map(str, col)) for col in columns]

    if titles is not None:
        columns = [[titles[i]] + col for i, col in enumerate(columns)]

    # An empty column is given a width of 0 -
    # max() raises an error on an empty sequence
    colLens = [max([len(r) for r in col] or [0]) for col in columns]
    fmtStr  = delim.join(['{{:<{}s}}'.format(l) for l in colLens])

    if titles is not None and sep:
        titles    = [col[0]  for col in columns]
        columns   = [col[1:] for col in columns]
        separator = ['-' * l for l in colLens]

        print(fmtStr.format(*titles))
        print(fmtStr.format(*separator))

    nrows = len(columns[0])
    for i in range(nrows):

        row = fmtStr.format(*[col[i] for col in columns])

        if strip:
            row = row.strip()

        print(row)
def parseArgs(args):
    """Parses command line arguments, returning an ``argparse.Namespace``
    object.

    :arg args: Sequence of command line arguments (without the program
               name). The sequence is copied, so it is never modified.

    :returns:  The ``argparse.Namespace``. Its ``command`` attribute
               contains the name of the selected sub-command. For the
               ``query`` command, the ``mask``, ``coord`` and ``voxel``
               attributes, and the corresponding ``mask_order``,
               ``coord_order`` and ``voxel_order`` attributes, are always
               lists (which may be empty).
    """

    # Work on a copy, as the arguments may be modified below
    args = list(args)

    # Show help if no args are provided
    if len(args) == 0 or \
       (len(args) == 1 and args[0] in ('ohi', 'summary', 'query')):
        args = args + ['-h']

    # Hack to make argparse accept
    # coordinates with a negative sign
    # (ohi/atlasquery interface only)
    if args[0] == 'ohi':
        try:
            cidx           = args.index('-c')
            coord          = args[cidx + 1]
            args[cidx + 1] = '"{}"'.format(coord)

        # No -c option, or -c is the last argument
        except (ValueError, IndexError):
            pass

    prolog = 'FSL atlasq {}'.format(fslversion.__version__)

    helps = {
        'ohi'     : 'Emulate the FSL atlasquery tool',
        'list'    : 'List available atlases',
        'summary' : 'Print a summary of one atlas',
        'query'   : 'Query an atlas at specific coordinates'
    }

    usages = {
        'main'    : 'usage: atlasq [-h] command [options]',
        'ohi'     : textwrap.dedent("""
                    usage: atlasquery -h
                           atlasquery --dumpatlases
                           atlasquery -a atlas -c X,Y,Z
                           atlasquery -a atlas -m mask
                    """).strip(),
        'list'    : 'usage: atlasq list [-e]',
        'summary' : 'usage: atlasq summary atlas',
        'query'   : textwrap.dedent("""
                    usage: atlasq query atlas [options] -m mask [-m mask ...]
                    usage: atlasq query atlas [options] -c X Y Z [-c X Y Z...]
                    usage: atlasq query atlas [options] -v X Y Z [-v X Y Z...]
                    usage: atlasq query atlas [options] -v X Y Z \\
                                                        [-c X Y Z [-m mask]...]
                    """).strip(),
    }

    # Every usage message starts with the version string
    for k in usages:
        usages[k] = '{}\n\n{}'.format(prolog, usages[k])

    parser = argparse.ArgumentParser(
        prog='atlasq',
        usage=usages['main'],
        formatter_class=HelpFormatter)

    subParsers = parser.add_subparsers(title='Commands', dest='command')

    ohiParser = subParsers.add_parser(
        'ohi',
        help=helps['ohi'],
        usage=usages['ohi'],
        formatter_class=HelpFormatter)
    listParser = subParsers.add_parser(
        'list',
        help=helps['list'],
        usage=usages['list'],
        formatter_class=HelpFormatter)
    sumParser = subParsers.add_parser(
        'summary',
        help=helps['summary'],
        usage=usages['summary'],
        formatter_class=HelpFormatter)
    queryParser = subParsers.add_parser(
        'query',
        help=helps['query'],
        usage=usages['query'],
        formatter_class=HelpFormatter)

    # This is a custom argparse.Action used by the
    # query command parser which keeps track of the
    # order in which query arguments are passed. The
    # three query types (mask, voxel, coord) are
    # each added to separate lists. For each type, a
    # second list (called mask_order, voxel_order,
    # and coord_order) is maintained which stores
    # the index of each query across all types.
    class QueryAction(argparse.Action):

        # Number of queries seen so far, across all
        # query types. The class is defined anew on
        # every parseArgs call, so this starts at 0.
        queryCount = 0

        def __call__(self, parser, namespace, values, option_string=None):

            orderDest = '{}_order'.format(self.dest)
            dest      = getattr(namespace, self.dest, None)
            order     = getattr(namespace, orderDest, None)

            if dest  is None: dest  = []
            if order is None: order = []

            dest .append(values)
            order.append(QueryAction.queryCount)

            setattr(namespace, self.dest, dest)
            setattr(namespace, orderDest, order)

            QueryAction.queryCount += 1

    # OldHorribleInterface parser
    ohiParser.add_argument(
        '-a', '--atlas',
        help='Name of atlas to use')
    ohiParser.add_argument(
        '-V', '--verbose',
        action='store_true',
        help='Switch on diagnostic messages')

    ohiSubParser = ohiParser.add_mutually_exclusive_group()
    ohiSubParser.add_argument(
        '-m', '--mask',
        dest='ohiMask',
        metavar='MASK',
        help='A mask image to use during structural lookups')
    ohiSubParser.add_argument(
        '-c', '--coord',
        help='Coordinate to query')
    ohiParser.add_argument(
        '--dumpatlases',
        action='store_true',
        help='Dump a list of the available atlases')

    # List parser
    listParser.add_argument(
        '-e', '--extended',
        action='store_true',
        help='Print more information about each atlas')

    # Summary parser
    sumParser.add_argument(
        'atlas',
        help='Name or ID of atlas to summarise')

    # Query parser
    queryParser.add_argument(
        'atlas',
        help='Name or ID of atlas to query.')
    queryParser.add_argument(
        '-r', '--resolution',
        type=float,
        help='Desired atlas resolution (mm). Default is highest available '
             'resolution.')
    queryParser.add_argument(
        '-s', '--short',
        action='store_true',
        help='Output in short (machine-friendly) format.')
    queryParser.add_argument(
        '-l', '--label',
        action='store_true',
        help='Query label/maxprob version of atlas (for probabilistic '
             'atlases).')
    queryParser.add_argument(
        '-m', '--mask',
        action=QueryAction,
        help='Mask to query with.')
    queryParser.add_argument(
        '-c', '--coord',
        nargs=3,
        type=float,
        metavar=('X', 'Y', 'Z'),
        action=QueryAction,
        help='World coordinates to look up.')
    queryParser.add_argument(
        '-v', '--voxel',
        nargs=3,
        type=float,
        metavar=('X', 'Y', 'Z'),
        action=QueryAction,
        help='Voxel coordinates to look up. Must be in terms of the atlas '
             'at the specified (or default) --resolution.')

    namespace = parser.parse_args(args)

    if namespace.command != 'query':
        return namespace

    # Make life easier for the queryAtlas code
    if namespace.mask  is None: namespace.mask  = []
    if namespace.coord is None: namespace.coord = []
    if namespace.voxel is None: namespace.voxel = []

    # The *_order attributes only exist if
    # QueryAction was invoked for that type
    if not hasattr(namespace, 'mask_order'):  namespace.mask_order  = []
    if not hasattr(namespace, 'coord_order'): namespace.coord_order = []
    if not hasattr(namespace, 'voxel_order'): namespace.voxel_order = []

    return namespace
def main(args=None):
    """Entry point for ``atlasq``. Parses the arguments, re-scans the
    available atlases, and runs the requested command.

    :arg args: Command line arguments - ``sys.argv[1:]`` is used if not
               provided.

    :returns:  ``1`` if the command failed with an :exc:`IdentifyError` or
               a ``MaskError`` (the error message is printed), ``0``
               otherwise.
    """

    if args is None:
        args = sys.argv[1:]

    # Parse command line arguments
    namespace = parseArgs(args)

    # Initialise the atlas library
    fslatlases.rescanAtlases()

    # Function which implements each command
    commands = {
        'list'    : listAtlases,
        'query'   : queryAtlas,
        'summary' : summariseAtlas,
        'ohi'     : ohi,
    }
    command = commands.get(namespace.command)

    # Run the command (nothing is
    # run for an unknown command)
    try:
        if command is not None:
            command(namespace)

    except (IdentifyError, fslatlases.MaskError) as e:
        print(str(e))
        return 1

    return 0
def atlasquery_emulation(args=None):
    """Entry point for ``atlasquery``. Calls :func:`main` with ``'ohi'``
    inserted before the arguments, and returns its result.

    :arg args: List of command line arguments - ``sys.argv[1:]`` is used
               if not provided.
    """
    cmdArgs = sys.argv[1:] if args is None else args
    return main(['ohi'] + cmdArgs)
# When run as a script, exit with the status code returned by main().
if __name__ == '__main__':
    sys.exit(main())