#!/usr/bin/env python
#
# main.py - The atlasq program.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module contains the FSL ``atlasq`` program, the successor to
``atlasquery``.
"""
from __future__ import print_function
import itertools as it
import sys
import argparse
import textwrap
import warnings
import logging
import numpy as np
import fsl.data.image as fslimage
import fsl.data.atlases as fslatlases
import fsl.version as fslversion
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Delimiter placed between the fields of each line printed by
# queryShortOutput (the short, machine-friendly query output).
SHORT_QUERY_DELIM = '\t'
[docs]
class IdentifyError(Exception):
    """Exception raised by ``identifyAtlas`` when an atlas cannot be
    identified, either because nothing matches the request or because
    the request matches more than one atlas.
    """
[docs]
def listAtlases(namespace):
    """Print a list of all available atlases.

    :arg namespace: Parsed command line arguments. If its ``extended``
                    attribute is true, a multi-line description is printed
                    for every atlas; otherwise a two-column table of atlas
                    IDs and names is printed.
    """
    available = fslatlases.listAtlases()

    # Compact listing: just an ID/name table
    if not namespace.extended:
        atlasIDs = [desc.atlasID for desc in available]
        atlasNames = [desc.name for desc in available]
        printColumns((atlasIDs, atlasNames), ('ID', 'Full name'))
        return

    # Extended listing: one paragraph per
    # atlas, separated by a blank line
    for desc in available:
        print('{} [{}]'.format(desc.name, desc.atlasID))
        print(' Spec: {}'.format(desc.specPath))
        print(' Type: {}'.format(desc.atlasType))
        print(' Labels: {}'.format(len(desc.labels)))
        for image in desc.images:
            print(' Image: {}'.format(image))
        for image in desc.summaryImages:
            print(' Summary image: {}'.format(image))
        print()
[docs]
def summariseAtlas(namespace):
    """Print information about one atlas, followed by a table of its labels.

    :arg namespace: Parsed command line arguments; ``namespace.atlas`` is
                    the (possibly partial) name or ID of the atlas, which is
                    resolved with ``identifyAtlas``.
    """
    desc = identifyAtlas(namespace.atlas)

    print('{} [{}]'.format(desc.name, desc.atlasID))
    print(' Spec: {}'.format(desc.specPath))
    print(' Type: {}'.format(desc.atlasType))
    print(' Labels: {}'.format(len(desc.labels)))
    for image in desc.images:
        print(' Image: {}'.format(image))
    for image in desc.summaryImages:
        print(' Summary image: {}'.format(image))

    # One column per label attribute, in the
    # same order as the column titles below
    attrs = ('index', 'name', 'x', 'y', 'z')
    columns = tuple([getattr(lbl, attr) for lbl in desc.labels]
                    for attr in attrs)

    printColumns(columns, ('Index', 'Label', 'X', 'Y', 'Z'))
[docs]
def queryAtlas(namespace):
    """Query an atlas with masks, world coordinates and/or voxel coordinates,
    and print the results.

    The atlas named by ``namespace.atlas`` is identified and loaded, every
    mask/coordinate/voxel query is run against it, and the results are
    re-ordered according to the ``mask_order``/``coord_order``/
    ``voxel_order`` indices stored on the namespace before being printed in
    short or long format (``namespace.short``).

    :arg namespace: Parsed command line arguments for the ``query`` command.
    """
    desc = identifyAtlas(namespace.atlas)
    worldCoords = namespace.coord
    voxelCoords = namespace.voxel
    maskImages = [fslimage.Image(m) for m in namespace.mask]
    worldOrder = namespace.coord_order
    voxelOrder = namespace.voxel_order
    maskOrder = namespace.mask_order

    atlas = fslatlases.loadAtlas(desc.atlasID,
                                 loadSummary=namespace.label,
                                 resolution=namespace.resolution)

    maskLabels, maskProps = maskQuery(atlas, maskImages)
    worldLabels, worldProps = coordQuery(atlas, worldCoords, False)
    voxelLabels, voxelProps = coordQuery(atlas, voxelCoords, True)

    # Position of every query, grouped as
    # masks, then coordinates, then voxels
    order = list(it.chain(maskOrder, worldOrder, voxelOrder))

    def inQueryOrder(maskItems, worldItems, voxelItems):
        # Concatenate per-type results in the same grouping as
        # "order", then sort them back into query order
        combined = list(it.chain(maskItems, worldItems, voxelItems))
        return [item for (_, item) in sorted(zip(order, combined))]

    labels = inQueryOrder(maskLabels, worldLabels, voxelLabels)
    props = inQueryOrder(maskProps, worldProps, voxelProps)
    sources = inQueryOrder(maskImages, worldCoords, voxelCoords)
    types = inQueryOrder(['mask'] * len(maskImages),
                         ['coordinate'] * len(worldCoords),
                         ['voxel'] * len(voxelCoords))

    if namespace.short:
        report = queryShortOutput
    else:
        report = queryLongOutput

    report(atlas, sources, types, labels, props)
[docs]
def queryShortOutput(atlas, sources, types, allLabels, allProps):
    """Called by ``queryAtlas`` when short output is requested.

    Prints one line per query, made up of the query type, the query source,
    and the results, joined with ``SHORT_QUERY_DELIM``.

    :arg atlas:     The atlas that was queried.
    :arg sources:   Query sources - coordinates, or mask images.
    :arg types:     Query types - ``'coordinate'``, ``'voxel'`` or ``'mask'``.
    :arg allLabels: Sequence of label lists, one per query.
    :arg allProps:  Sequence of proportion lists, one per query.
    """
    queries = zip(sources, types, allLabels, allProps)

    for source, stype, labels, props in queries:

        if stype == 'coordinate':
            source = '{:0.2f} {:0.2f} {:0.2f}'.format(*source)
        elif stype == 'voxel':
            source = '{:0.0f} {:0.0f} {:0.0f}'.format(*source)
        elif stype == 'mask':
            source = source.dataSource

        names = labelNames(atlas, labels)
        isCoord = stype in ('coordinate', 'voxel')

        # A coordinate lookup on a label
        # atlas yields one label only
        if isCoord and isinstance(atlas, fslatlases.LabelAtlas):
            results = ['{}'.format(names[0])]

        # Everything else yields labels and
        # proportions, which are reported
        # from highest to lowest proportion
        else:
            ranked = sorted(zip(props, names))
            ranked.reverse()
            results = ['{} {:0.4f}'.format(name, prop)
                       for prop, name in ranked]

        print(SHORT_QUERY_DELIM.join([stype, source] + results))
[docs]
def queryLongOutput(atlas, sources, types, allLabels, allProps):
    """Called by ``queryAtlas`` when long output is requested.

    For every query a boxed title is printed, followed by a table of
    results.

    :arg atlas:     The atlas that was queried.
    :arg sources:   Query sources - coordinates, or mask images.
    :arg types:     Query types - ``'coordinate'``, ``'voxel'`` or ``'mask'``.
    :arg allLabels: Sequence of label lists, one per query.
    :arg allProps:  Sequence of proportion lists, one per query.
    """

    def summaryCoord(source, stype, labels, props, names):
        # Single-label result of a coordinate
        # lookup on a label atlas
        if labels[0] is None:
            index = np.nan
        else:
            index = int(labels[0])

        if atlas.desc.atlasType == 'probabilistic':
            fields = ['name', 'index', 'summary value']
            values = [names[0], index, index + 1]
        else:
            fields = ['name', 'label']
            values = [names[0], index]

        printColumns((fields, values))

    def proportions(source, stype, labels, props, names):
        # Label/proportion results, printed
        # from highest to lowest proportion
        if len(labels) == 0:
            print('No results')
            return

        ranked = sorted(zip(props, labels, names))
        ranked.reverse()

        propStrs = ['{:0.4f}'.format(entry[0]) for entry in ranked]
        labels = [entry[1] for entry in ranked]
        names = [entry[2] for entry in ranked]

        if atlas.desc.atlasType == 'probabilistic':
            titles = ['name', 'index', 'summary value', 'proportion']
            columns = [names, labels, [lbl + 1 for lbl in labels], propStrs]
        else:
            titles = ['name', 'label', 'proportion']
            columns = [names, labels, propStrs]

        printColumns(columns, titles)

    queries = zip(sources, types, allLabels, allProps)

    for source, stype, labels, props in queries:

        if stype == 'coordinate':
            sourcestr = '{:0.2f} {:0.2f} {:0.2f}'.format(*source)
        elif stype == 'voxel':
            sourcestr = '{:0.0f} {:0.0f} {:0.0f}'.format(*source)
        elif stype == 'mask':
            sourcestr = '{}'.format(source.dataSource)

        # Boxed title for this query
        title = '{} {}'.format(stype, sourcestr)
        rule = '-' * (4 + len(title))
        print(rule)
        print('| {} |'.format(title))
        print(rule)
        print()

        names = labelNames(atlas, labels)
        isCoord = stype in ('coordinate', 'voxel')

        if isCoord and isinstance(atlas, fslatlases.LabelAtlas):
            summaryCoord(source, stype, labels, props, names)
        else:
            proportions(source, stype, labels, props, names)

        print()
[docs]
def ohi(namespace):
    """Emulates the FSL ``atlasquery`` tool.

    Depending on the arguments, this either prints the names of all
    available atlases, or runs a single mask or coordinate query against
    the named atlas and prints the result in ``atlasquery`` format.

    :arg namespace: Parsed command line arguments for the ``ohi`` command
                    (``dumpatlases``, ``atlas``, ``ohiMask`` and ``coord``
                    attributes are read).
    """

    atlasDesc = None

    def dumpatlases():
        atlases = [a.name for a in fslatlases.listAtlases()]
        print('\n'.join(sorted(atlases)))

    if namespace.dumpatlases:
        dumpatlases()
        return

    # The atlas must be specified by its exact name
    for a in fslatlases.listAtlases():
        if a.name == namespace.atlas:
            atlasDesc = a
            break

    if atlasDesc is None:
        print('Invalid atlas name. Try one of:')
        dumpatlases()
        return

    # Neither a mask nor a coordinate was given - there
    # is nothing to query. Without this check, the
    # coordinate branch below would fail on a None value.
    if namespace.ohiMask is None and namespace.coord is None:
        print('No mask or coordinate specified. Use -m or -c.')
        return

    # atlasquery always uses 2mm atlas
    # versions when a 2mm is available
    reses = [p[0] for p in atlasDesc.pixdims]
    if 2 in reses:
        res = 2
    else:
        res = max(reses)

    # Mask query.
    if namespace.ohiMask is not None:
        mask = fslimage.Image(namespace.ohiMask)
        labels, props = maskQuery(atlasDesc, [mask], resolution=res)
        labels = labels[0]
        props = props[0]

        for lbl, prop in zip(labels, props):
            if atlasDesc.atlasType == 'probabilistic':
                lbl = atlasDesc.labels[int(lbl)].name
            elif atlasDesc.atlasType == 'label':
                lbl = atlasDesc.find(value=int(lbl)).name
            print('{}:{:0.4f}'.format(lbl, prop))

    # Coordinate query
    else:
        # parseArgs wraps the coordinate in double quotes so that
        # argparse accepts negative values - remove them again
        coord = namespace.coord.strip('"')
        coord = [float(c) for c in coord.split(',')]

        labels, props = coordQuery(atlasDesc,
                                   [coord],
                                   False,
                                   resolution=res)
        labels = labels[0]
        props = props[0]

        if atlasDesc.atlasType == 'label':
            labels = labels[0]
            if labels is None:
                label = 'Unclassified'
            else:
                label = atlasDesc.find(value=int(labels)).name
            print('<b>{}</b><br>{}'.format(atlasDesc.name, label))

        elif atlasDesc.atlasType == 'probabilistic':
            labelStrs = []

            # Report from highest to lowest proportion
            if len(labels) > 0:
                props, labels = zip(*reversed(sorted(zip(props, labels))))

            for label, prop in zip(labels, props):
                label = atlasDesc.labels[int(label)].name
                labelStrs.append('{:d}% {}'.format(int(round(prop)), label))

            if len(labelStrs) == 0:
                labels = 'No label found!'
            else:
                labels = ', '.join(labelStrs)

            print('<b>{}</b><br>{}'.format(atlasDesc.name, labels))
[docs]
def atlasOrDesc(aord, *args, **kwargs):
    """Return an ``Atlas`` for ``aord``.

    If ``aord`` is already an ``Atlas``, it is returned unchanged. Otherwise
    it is assumed to be an ``AtlasDescription``, and the atlas with the
    matching ``atlasID`` is loaded, with all other arguments passed through
    to ``fslatlases.loadAtlas``.
    """
    if not isinstance(aord, fslatlases.Atlas):
        aord = fslatlases.loadAtlas(aord.atlasID, *args, **kwargs)
    return aord
[docs]
def labelNames(atlas, labels):
    """Convert the given sequence of ``labels`` into region names.

    Each label is used as an index into ``atlas.desc.labels``; a label of
    ``None`` is converted to ``'No label'``.

    :returns: A list of names, one for each label.
    """
    return ['No label' if lbl is None else atlas.desc.labels[int(lbl)].name
            for lbl in labels]
[docs]
def maskQuery(atlas, masks, *args, **kwargs):
    """Query the ``atlas`` at the given ``masks``.

    :arg atlas: An ``Atlas``, or an ``AtlasDescription`` (in which case the
                atlas is loaded via ``atlasOrDesc``, with all other
                arguments passed through).
    :arg masks: Sequence of mask images.

    :returns:   A tuple containing two lists - the labels, and the
                corresponding proportions, found for each mask.
    """
    atlas = atlasOrDesc(atlas, *args, **kwargs)
    labelsPerMask = []
    propsPerMask = []

    for mask in masks:

        if isinstance(atlas, fslatlases.LabelAtlas):
            labels, props = atlas.maskLabel(mask)

            # Summary images of probabilistic atlases
            # store (label index + 1), so convert the
            # label values back into label indices.
            if atlas.desc.atlasType == 'probabilistic':
                labels = [value - 1 for value in labels]

        elif isinstance(atlas, fslatlases.ProbabilisticAtlas):

            # Only keep regions with a positive value
            hits = [(atlas.desc.labels[idx].index, value)
                    for idx, value in enumerate(atlas.maskValues(mask))
                    if value > 0]
            labels = [hit[0] for hit in hits]
            props = [hit[1] for hit in hits]

        labelsPerMask.append(labels)
        propsPerMask.append(props)

    return labelsPerMask, propsPerMask
[docs]
def coordQuery(atlas, coords, voxel, *args, **kwargs):
    """Query the ``atlas`` at the given ``coords``.

    :arg atlas:  An ``Atlas``, or an ``AtlasDescription`` (in which case the
                 atlas is loaded via ``atlasOrDesc``, with all other
                 arguments passed through).
    :arg coords: Sequence of coordinates.
    :arg voxel:  Passed as the ``voxel`` argument of the atlas lookup
                 methods.

    :returns:    A tuple containing two lists - the labels, and the
                 corresponding proportions, found for each coordinate. For
                 a label atlas, each entry is a single label, paired with a
                 proportion of ``None``.
    """
    atlas = atlasOrDesc(atlas, *args, **kwargs)
    labelsPerCoord = []
    propsPerCoord = []

    for coord in coords:

        if isinstance(atlas, fslatlases.ProbabilisticAtlas):

            # Only keep regions with a non-zero value
            hits = [(atlas.desc.labels[idx].index, value)
                    for idx, value
                    in enumerate(atlas.values(coord, voxel=voxel))
                    if value != 0]

            labelsPerCoord.append([hit[0] for hit in hits])
            propsPerCoord.append([hit[1] for hit in hits])

        elif isinstance(atlas, fslatlases.LabelAtlas):
            label = atlas.label(coord, voxel=voxel)

            # Summary images of probabilistic atlases
            # store (label index + 1), with 0 being
            # the background, so convert the label
            # value back into a label index.
            if atlas.desc.atlasType == 'probabilistic':
                if label is None or label == 0:
                    label = None
                else:
                    label = label - 1

            labelsPerCoord.append([label])
            propsPerCoord.append([None])

    return labelsPerCoord, propsPerCoord
[docs]
def identifyAtlas(idOrName):
    """Given a partial atlas ID or name, tries to find an atlas which
    uniquely matches it.

    Matching is case-insensitive. An exact match against an atlas name or
    ID takes precedence; failing that, ``idOrName`` may be a sub-string of
    the name and/or ID of exactly one atlas.

    :arg idOrName: Full or partial atlas name or ID.

    :returns:      The description of the matching atlas, as returned by
                   ``fslatlases.listAtlases``.

    :raises IdentifyError: If no atlas, or more than one atlas, matches.
    """

    # TODO Use difflib or some fuzzy matching library?

    idOrName = idOrName.lower().strip()
    atlases = fslatlases.listAtlases()
    allNames = [a.name.lower() for a in atlases]
    allIDs = [a.atlasID.lower() for a in atlases]

    # First test for an exact match. Matches are collected as a
    # set of atlas indices, so that an atlas which matches on
    # both its name and its ID is only counted once.
    nameMatches = [i for i, n in enumerate(allNames) if idOrName == n]
    idMatches = [i for i, aid in enumerate(allIDs) if idOrName == aid]
    exact = set(nameMatches) | set(idMatches)

    if len(exact) == 1:
        return atlases[exact.pop()]

    # If no (unique) exact match, test for a partial match
    nameMatches = [i for i, n in enumerate(allNames) if idOrName in n]
    idMatches = [i for i, aid in enumerate(allIDs) if idOrName in aid]
    partial = set(nameMatches) | set(idMatches)

    # No matches
    if len(partial) == 0:
        raise IdentifyError('Could not find any atlas '
                            'matching {}'.format(idOrName))

    # More than one atlas matched
    if len(partial) > 1:
        possible = [allNames[m] for m in nameMatches] + \
                   [allIDs[m] for m in idMatches]

        raise IdentifyError('{} matched multiple atlases! Could match one '
                            'of: {}'.format(idOrName, ', '.join(possible)))

    # Exactly one atlas matched, on
    # its name, its ID, or both
    return atlases[partial.pop()]
[docs]
def printColumns(columns, titles=None, delim=' | ', sep=True, strip=False):
    """Convenience function which pretty-prints a collection of columns in a
    tabular format.

    Every column is left-aligned, and padded to the width of its longest
    entry (including its title, if given).

    :arg columns: A sequence of columns, where each column is a sequence of
                  values. Each value is converted to a string with ``str``.
                  The number of rows is taken from the first column.

    :arg titles:  A sequence of titles, one for each column.

    :arg delim:   String printed between adjacent columns.

    :arg sep:     If ``True`` (and ``titles`` are given), a row of dashes is
                  printed between the titles and the data. If ``False``,
                  the titles are printed like any other row.

    :arg strip:   If ``True``, leading and trailing white space is removed
                  from each data row before it is printed.
    """

    if len(columns) == 0:
        return

    columns = [list(map(str, col)) for col in columns]

    if titles is not None:
        columns = [[titles[i]] + col for i, col in enumerate(columns)]

    # The "or [0]" guards against empty columns (no
    # rows and no titles), which max cannot handle
    colLens = [max([len(r) for r in col] or [0]) for col in columns]

    fmtStr = delim.join(['{{:<{}s}}'.format(l) for l in colLens])

    if titles is not None and sep:
        titles = [col[0] for col in columns]
        columns = [col[1:] for col in columns]
        separator = ['-' * l for l in colLens]

        print(fmtStr.format(*titles))
        print(fmtStr.format(*separator))

    nrows = len(columns[0])
    for i in range(nrows):
        row = [col[i] for col in columns]
        row = fmtStr.format(*row)

        if strip:
            row = row.strip()
        print(row)
def parseArgs(args):
    """Parses command line arguments, returning an ``argparse.Namespace``
    object.

    :arg args: Sequence of command line arguments (not including the
               program name). It is copied, and never modified in place.

    :returns:  The parsed ``argparse.Namespace``. For the ``query`` command,
               the ``mask``, ``coord`` and ``voxel`` attributes, and the
               corresponding ``*_order`` attributes, are always lists.
    """

    # Work on a copy, so that the caller's sequence is
    # never modified, and so that any sequence type
    # (e.g. a tuple) can be passed in
    args = list(args)

    # Show help if no args are provided
    if len(args) == 0 or \
       (len(args) == 1 and args[0] in ('ohi', 'summary', 'query')):
        args = args + ['-h']

    # Hack to make argparse accept
    # coordinates with a negative sign
    # (ohi/atlasquery interface only)
    if args[0] == 'ohi':
        try:
            cidx = args.index('-c')
            coord = args[cidx + 1]
        except (ValueError, IndexError):
            # No -c option, or no value after
            # it - leave it to argparse
            pass
        else:
            args[cidx + 1] = '"{}"'.format(coord)

    prolog = 'FSL atlasq {}'.format(fslversion.__version__)

    helps = {
        'ohi': 'Emulate the FSL atlasquery tool',
        'list': 'List available atlases',
        'summary': 'Print a summary of one atlas',
        'query': 'Query an atlas at specific coordinates'
    }

    usages = {
        'main': 'usage: atlasq [-h] command [options]',
        'ohi': textwrap.dedent("""
        usage: atlasquery -h
        atlasquery --dumpatlases
        atlasquery -a atlas -c X,Y,Z
        atlasquery -a atlas -m mask
        """).strip(),
        'list': 'usage: atlasq list [-e]',
        'summary': 'usage: atlasq summary atlas',
        'query': textwrap.dedent("""
        usage: atlasq query atlas [options] -m mask [-m mask ...]
        usage: atlasq query atlas [options] -c X Y Z [-c X Y Z...]
        usage: atlasq query atlas [options] -v X Y Z [-v X Y Z...]
        usage: atlasq query atlas [options] -v X Y Z \\
        [-c X Y Z [-m mask]...]
        """).strip(),
    }

    for k in usages:
        usages[k] = '{}\n\n{}'.format(prolog, usages[k])

    parser = argparse.ArgumentParser(
        prog='atlasq',
        usage=usages['main'],
        formatter_class=HelpFormatter)
    subParsers = parser.add_subparsers(title='Commands', dest='command')
    ohiParser = subParsers.add_parser(
        'ohi',
        help=helps['ohi'],
        usage=usages['ohi'],
        formatter_class=HelpFormatter)
    listParser = subParsers.add_parser(
        'list',
        help=helps['list'],
        usage=usages['list'],
        formatter_class=HelpFormatter)
    sumParser = subParsers.add_parser(
        'summary',
        help=helps['summary'],
        usage=usages['summary'],
        formatter_class=HelpFormatter)
    queryParser = subParsers.add_parser(
        'query',
        help=helps['query'],
        usage=usages['query'],
        formatter_class=HelpFormatter)

    # This is a custom argparse.Action used by the
    # query command parser which keeps track of the
    # order in which query arguments are passed. The
    # three query types (mask, voxel, coord) are
    # each added to separate lists. For each type, a
    # second list (called mask_order, voxel_order,
    # and coord_order) is maintained which stores
    # the index of each query across all types.
    class QueryAction(argparse.Action):

        queryCount = 0

        def __init__(self, *args, **kwargs):
            argparse.Action.__init__(self, *args, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):
            dest = getattr(namespace, self.dest, None)
            order = getattr(namespace, '{}_order'.format(self.dest), None)

            if dest is None:
                dest = []
            if order is None:
                order = []

            dest.append(values)
            order.append(QueryAction.queryCount)

            setattr(namespace, self.dest, dest)
            setattr(namespace, '{}_order'.format(self.dest), order)

            QueryAction.queryCount += 1

    # OldHorribleInterface parser
    ohiParser.add_argument(
        '-a', '--atlas',
        help='Name of atlas to use')
    ohiParser.add_argument(
        '-V', '--verbose',
        action='store_true',
        help='Switch on diagnostic messages')
    ohiSubParser = ohiParser.add_mutually_exclusive_group()
    ohiSubParser.add_argument(
        '-m', '--mask',
        dest='ohiMask',
        metavar='MASK',
        help='A mask image to use during structural lookups')
    ohiSubParser.add_argument(
        '-c', '--coord',
        help='Coordinate to query')
    ohiParser.add_argument(
        '--dumpatlases',
        action='store_true',
        help='Dump a list of the available atlases')

    # List parser
    listParser.add_argument(
        '-e', '--extended',
        action='store_true',
        help='Print more information about each atlas')

    # Summary parser
    sumParser.add_argument('atlas', help='Name or ID of atlas to summarise')

    # Query parser
    queryParser.add_argument(
        'atlas',
        help='Name or ID of atlas to summarise.')
    queryParser.add_argument(
        '-r', '--resolution',
        type=float,
        help='Desired atlas resolution (mm). Default is highest available '
             'resolution.')
    queryParser.add_argument(
        '-s', '--short',
        action='store_true',
        help='Output in short (machine-friendly) format.')
    queryParser.add_argument(
        '-l', '--label',
        action='store_true',
        help='Query label/maxprob version of atlas (for probabilistic '
             'atlases).')
    queryParser.add_argument(
        '-m', '--mask',
        action=QueryAction,
        help='Mask to query with.')
    queryParser.add_argument(
        '-c', '--coord',
        nargs=3,
        type=float,
        metavar=('X', 'Y', 'Z'),
        action=QueryAction,
        help='World coordinates to look up.')
    queryParser.add_argument(
        '-v', '--voxel',
        nargs=3,
        type=float,
        metavar=('X', 'Y', 'Z'),
        action=QueryAction,
        help='Voxel coordinates to look up. Must be in terms of the atlas '
             'at the specified (or default) --resolution.')

    namespace = parser.parse_args(args)

    if namespace.command != 'query':
        return namespace

    # Make life easier for the queryAtlas code
    if namespace.mask is None:
        namespace.mask = []
    if namespace.coord is None:
        namespace.coord = []
    if namespace.voxel is None:
        namespace.voxel = []

    if not hasattr(namespace, 'mask_order'):
        namespace.mask_order = []
    if not hasattr(namespace, 'coord_order'):
        namespace.coord_order = []
    if not hasattr(namespace, 'voxel_order'):
        namespace.voxel_order = []

    return namespace
[docs]
def main(args=None):
    """Entry point for ``atlasq``. Parses arguments, and runs the requested
    command.

    :arg args: Command line arguments. Defaults to ``sys.argv[1:]``.

    :returns:  ``1`` if the command failed with an ``IdentifyError`` or a
               ``MaskError`` (after printing the error message), ``0``
               otherwise.
    """

    if args is None:
        args = sys.argv[1:]

    # Parse command line arguments
    namespace = parseArgs(args)

    # Initialise the atlas library
    fslatlases.rescanAtlases()

    # Function which implements each command
    commands = {
        'list': listAtlases,
        'query': queryAtlas,
        'summary': summariseAtlas,
        'ohi': ohi,
    }
    run = commands.get(namespace.command)

    # Run the command
    try:
        if run is not None:
            run(namespace)
    except (IdentifyError, fslatlases.MaskError) as e:
        print(str(e))
        return 1

    return 0
[docs]
def atlasquery_emulation(args=None):
    """Entry point for ``atlasquery``. Runs ``atlasq`` in ``ohi`` mode, by
    passing the arguments to ``main`` with ``'ohi'`` inserted in front.

    :arg args: Command line arguments. Defaults to ``sys.argv[1:]``.

    :returns:  The return value of ``main``.
    """
    argv = sys.argv[1:] if args is None else args
    return main(['ohi'] + argv)
# When executed as a script, run atlasq and
# use the result of main as the exit status
if __name__ == '__main__':
    exitCode = main()
    sys.exit(exitCode)