#!/usr/bin/env python
# Copyright (c) 2014-2015, András Wacha <awacha@gmail.com>
# All rights reserved.
#
# This software including the files in this directory is provided under
# the following license.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.

import itertools
import collections
import sys
_PY2 = sys.version_info.major == 2
if _PY2:
    import Tkinter as tkinter
    import ttk as _ttk
    tkinter.ttk = _ttk
    import tkFont as _tkFont
    tkinter.font = _tkFont
    import tkMessageBox as _tkMessageBox
    tkinter.messagebox = _tkMessageBox
else:
    import tkinter
    import tkinter.ttk
    import tkinter.font
    import tkinter.messagebox
import math
import numpy as np
import matplotlib.figure
import threading
import queue
import time
import warnings
if _PY2:
    from backports import configparser
else:
    import configparser
import os
import matplotlib
matplotlib.use('TkAgg')
matplotlib.rcParams['savefig.dpi'] = 600

import matplotlib.backends.backend_tkagg
try:
    import pyperclip
except ImportError:
    warnings.warn('Cannot import pyperclip, exporting to HTML won\'t work.')


def all_possible_selections(set_):
    """Yield every distinct non-empty selection of the items in `set_`.

    Selections are produced as tuples, grouped by increasing length
    (1 up to ``len(set_)``).  Within one length, duplicate tuples (which
    arise when `set_` holds repeated values) are collapsed and the remaining
    tuples are yielded in sorted order.
    """
    largest = len(set_)
    size = 0
    while size < largest:
        size += 1
        distinct = set(itertools.combinations(set_, size))
        for selection in sorted(distinct):
            yield selection


class PinholeConfiguration(object):
    """Geometry of a three-pinhole collimation set-up and derived quantities.

    The instance stores the two collimation distances (either as plain
    numbers or as collections of spacer lengths), the radii of the first two
    pinholes and the sample / beamstop / detector distances.  Everything else
    (third pinhole size, beam sizes, qmin, intensity, ...) is computed on
    demand through properties.

    Units follow the labels used by ``__str__``: lengths in mm, pinhole
    diameters ``D1``/``D2``/``D3`` in micrometres, ``qmin`` in 1/nm (hence the
    wavelength is presumably in nm).  Radii (``r1``, ``r2``, ``r3``, ...) are
    kept in the same unit as the lengths.
    """

    def __init__(self, L1, L2, D1, D2, ls, lbs, sd, mindist_l1=0, mindist_l2=0,
                 sealringwidth=0, wavelength=0.15418):
        """Store the defining parameters.

        Parameters:
            L1, L2: PH#1-PH#2 and PH#2-PH#3 distances; either a number or an
                iterable of spacer lengths (see the `l1` / `l2` properties).
            D1, D2: diameters of the first and second pinhole.
            ls: distance added to `sd` to obtain `l3`.
            lbs: distance subtracted from `sd` where the beamstop is involved.
            sd: sample-to-detector distance.
            mindist_l1, mindist_l2: constant offsets added to the summed
                spacer lengths of L1 and L2, respectively.
            sealringwidth: width counted once per spacer plus once more.
            wavelength: wavelength used for the `qmin` calculation.
        """
        self.l1_elements = L1
        self.l2_elements = L2
        self.mindist_l1 = mindist_l1
        self.mindist_l2 = mindist_l2
        self.sealringwidth = sealringwidth
        # diameter -> radius, scaled by 1e-3 (um -> mm per the unit labels)
        self.r1 = D1 * 0.5e-3
        self.r2 = D2 * 0.5e-3
        self.ls = ls
        self.lbs = lbs
        self.sd = sd
        self.wavelength = wavelength

    def __copy__(self):
        """Return a new configuration built from the same parameters.

        The element collections themselves are shared, not duplicated.
        """
        return PinholeConfiguration(
            self.l1_elements, self.l2_elements, self.D1, self.D2, self.ls,
            self.lbs, self.sd, self.mindist_l1, self.mindist_l2,
            self.sealringwidth, self.wavelength)

    copy = __copy__

    @staticmethod
    def _is_iterable(obj):
        """Tell whether `obj` is an iterable (a collection of spacers)."""
        # `collections.Iterable` was removed in Python 3.10 while
        # `collections.abc` does not exist on Python 2, so try both.
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable
        return isinstance(obj, Iterable)

    def _total_length(self, elements, mindist):
        """Return the full length described by `elements`.

        For an iterable the result is the sum of its items, plus one sealing
        ring width per item and one extra, plus `mindist`, as a float.  Any
        other value is returned unchanged.
        """
        if self._is_iterable(elements):
            return float(sum(elements) +
                         self.sealringwidth * (1 + len(elements)) +
                         mindist)
        return elements

    @property
    def l1(self):
        """Total PH#1-PH#2 distance (see `_total_length`)."""
        return self._total_length(self.l1_elements, self.mindist_l1)

    @l1.setter
    def l1(self, value):
        self.l1_elements = value

    @property
    def l2(self):
        """Total PH#2-PH#3 distance (see `_total_length`)."""
        return self._total_length(self.l2_elements, self.mindist_l2)

    @l2.setter
    def l2(self, value):
        self.l2_elements = value

    @property
    def D1(self):
        """Diameter of the first pinhole (inverse of the scaling in r1)."""
        return 2000 * self.r1

    @D1.setter
    def D1(self, value):
        self.r1 = value * 0.5e-3

    @property
    def D2(self):
        """Diameter of the second pinhole (inverse of the scaling in r2)."""
        return 2000 * self.r2

    @D2.setter
    def D2(self, value):
        self.r2 = value * 0.5e-3

    @property
    def D3(self):
        """Diameter of the third pinhole, in the same unit as D1 and D2."""
        return 2000 * self.r3

    @property
    def Dsample_direct(self):
        """Twice `rs_direct`."""
        return 2 * self.rs_direct

    @property
    def Dsample_parasitic(self):
        """Twice `rs_parasitic`."""
        return 2 * self.rs_parasitic

    # the un-suffixed name refers to the direct-beam value
    Dsample = Dsample_direct

    @property
    def Dbs_parasitic(self):
        """Twice `rbs_parasitic`."""
        return 2 * self.rbs_parasitic

    @property
    def Dbs_direct(self):
        """Twice `rbs_direct`."""
        return 2 * self.rbs_direct

    # the un-suffixed name refers to the parasitic-scattering value
    Dbs = Dbs_parasitic

    @property
    def Ddet_parasitic(self):
        """Twice `rdet_parasitic`."""
        return 2 * self.rdet_parasitic

    @property
    def Ddet_direct(self):
        """Twice `rdet_direct`."""
        return 2 * self.rdet_direct

    # the un-suffixed name refers to the parasitic-scattering value
    Ddet = Ddet_parasitic

    @property
    def l3(self):
        """Sum of `sd` and `ls`."""
        return self.sd + self.ls

    @property
    def r3(self):
        """Radius of the third pinhole: r2 + (r1 + r2) * l2 / l1."""
        return self.r2 + (self.r1 + self.r2) * self.l2 / self.l1

    @property
    def rs_direct(self):
        """Radius at distance l2 + ls after PH#2, using the r1/r2 cone."""
        return self.r2 + (self.r1 + self.r2) * (self.l2 + self.ls) / self.l1

    @property
    def rs_parasitic(self):
        """Radius at distance ls after PH#3, using the r2/r3 cone."""
        return self.r3 + (self.r2 + self.r3) * (self.ls / self.l2)

    rs = rs_direct

    @property
    def rbs_direct(self):
        """Radius at distance l2 + l3 - lbs after PH#2, r1/r2 cone."""
        return (self.r2 + (self.r1 + self.r2) * (self.l2 + self.l3 -
                                                 self.lbs) / self.l1)

    @property
    def rbs_parasitic(self):
        """Radius at distance l2 + l3 - lbs after PH#2, r2/r3 cone."""
        return ((self.r2 + self.r3) * (self.l2 + self.l3 - self.lbs) /
                self.l2 - self.r2)

    rbs = rbs_parasitic

    @property
    def rdet_direct(self):
        """Radius at distance l2 + l3 after PH#2, r1/r2 cone."""
        return (self.r2 + (self.r1 + self.r2) * (self.l2 + self.l3) /
                self.l1)

    @property
    def rdet_parasitic(self):
        """Radius at distance l2 + l3 after PH#2, r2/r3 cone."""
        return ((self.r2 + self.r3) * (self.l2 + self.l3) /
                self.l2 - self.r2)

    rdet = rdet_parasitic

    @property
    def tantthmin(self):
        """`rbs` divided by (sd - lbs); used as tan(2*theta_min) by qmin."""
        return (self.rbs / (self.sd - self.lbs))

    @property
    def qmin(self):
        """4*pi*sin(atan(tantthmin)/2) / wavelength."""
        return (4 * math.pi * math.sin(0.5 * math.atan(self.tantthmin)) /
                self.wavelength)

    @property
    def dmax(self):
        """2*pi / qmin."""
        return 2 * math.pi / self.qmin

    @property
    def intensity(self):
        """Figure of merit: D1**2 * D2**2 / l1**2 * pi / 64."""
        return self.D1**2 * self.D2**2 / self.l1**2 / 64 * math.pi

    @property
    def alpha(self):
        """atan2(r1 + r2, l1), in radians."""
        return math.atan2((self.r2 + self.r1), self.l1)

    @property
    def dspheremax(self):
        """2 * sqrt(5/3) * Rgmin."""
        return (5 / 3.)**0.5 * 2 * self.Rgmin

    @property
    def Rgmin(self):
        """1 / qmin."""
        return 1 / self.qmin

    def __str__(self):
        """One-line summary of the main parameters with their units."""
        # Adjacent literals instead of a backslash continuation: the original
        # continuation silently dropped the space between "um;" and "D3".
        return ('l1: %.2f mm; l2: %.2f mm; D1: %.0f um; D2: %.0f um; '
                'D3: %.0f um; I: %.2f; qmin: %.5f 1/nm' % (
                    self.l1, self.l2, self.D1, self.D2, self.D3,
                    self.intensity, self.qmin))

    @property
    def dominant_constraint(self):
        """Classify the set-up as 'sample', 'beamstop' or 'neither'.

        Two threshold distances are computed from the length ratios and from
        rs / rbs.  The result is 'sample' when `sd` is below the first
        threshold, 'beamstop' when it is below the second one, and 'neither'
        otherwise.
        """
        lambda1 = (self.ls + self.l2) / self.l1
        lambda0 = (self.l2 / self.l1)
        rho = self.rs / self.rbs
        d_SsBS = (
            (2 * lambda1 * (lambda1 + 1) - rho * (lambda0 + lambda1 +
                                                  2 * lambda0 * lambda1)) /
            (rho * (lambda0 + 2 * lambda1 + 2 * lambda0 * lambda1)) *
            self.l2 + self.lbs - self.ls)
        Discr = 4 * rho**2 * lambda0**2 + rho * \
            (4 * lambda0**2 - 8 * lambda0 * lambda1) + \
            (lambda0 + 2 * lambda1 + 2 * lambda0 * lambda1)**2
        lam2plus = ((2 * lambda1 + lambda0 + 2 * lambda0 * lambda1 -
                     6 * rho * lambda0 - 4 * rho * lambda0**2 + Discr**0.5) /
                    (8 * rho * lambda0 + 4 * rho * lambda0**2))
        d_BSsSplus = lam2plus * self.l2 + self.lbs - self.ls
        if self.sd < d_SsBS:
            return 'sample'
        elif self.sd < d_BSsSplus:
            return 'beamstop'
        else:
            return 'neither'


class ListboxWithScroll(tkinter.ttk.Frame):
    """This is a listbox with the following functionality:
            - working horizontal and vertical scrollbars
            - entry row at the top for adding new items
            - buttons on the right for adding, ordering, removing items

    A subset of the `tkinter.Listbox` interface (`insert`, `delete`, `get`,
    `index`, `curselection`, ...) is forwarded to the embedded listbox so the
    frame can be used much like the listbox itself.
    """
    def __init__(self, *args, **kwargs):
        """Build the widgets; all arguments go to `ttk.Frame.__init__`."""
        tkinter.ttk.Frame.__init__(self, *args, **kwargs)
        # row 0: entry for new items and the 'Add' button
        self._entry = tkinter.ttk.Entry(master=self)
        self._entry.grid(row=0, column=0, columnspan=2, sticky=tkinter.NSEW)
        self._button_add = tkinter.ttk.Button(
            master=self, text='Add', command=self.do_add)
        self._button_add.grid(row=0, column=2, sticky=tkinter.NSEW)
        # row 1: the listbox itself with its scrollbars around it
        self._listbox = tkinter.Listbox(
            master=self, height=0, selectmode=tkinter.EXTENDED)
        self._listbox.grid(row=1, column=0, sticky=tkinter.NSEW)
        self._xscroll = tkinter.ttk.Scrollbar(master=self,
                                              orient=tkinter.HORIZONTAL,
                                              command=self._listbox.xview)
        self._xscroll.grid(
            row=2, column=0, sticky=tkinter.S + tkinter.E + tkinter.W)
        self._listbox['xscrollcommand'] = self._xscroll.set
        self._yscroll = tkinter.ttk.Scrollbar(master=self,
                                              orient=tkinter.VERTICAL,
                                              command=self._listbox.yview)
        self._yscroll.grid(
            row=1, column=1, sticky=tkinter.N + tkinter.S + tkinter.E)
        self._listbox['yscrollcommand'] = self._yscroll.set

        # right-hand column: one button per list operation
        self._buttonbox = tkinter.ttk.Frame(master=self)
        self._buttonbox.grid(row=1, column=2, rowspan=2, sticky=tkinter.NSEW)
        self._buttons = {}
        for label, cmd in [('Remove', self.do_remove),
                           ('Clear', self.do_clear),
                           ('To top', self.do_movetotop),
                           ('Up', self.do_moveup),
                           ('Down', self.do_movedown),
                           ('To bottom', self.do_movetobottom)]:
            self._buttons[label] = tkinter.ttk.Button(
                master=self._buttonbox, text=label, command=cmd)
            self._buttons[label].grid(sticky=tkinter.NW)
        # only the listbox cell absorbs extra space
        self.grid_columnconfigure(0, weight=1)
        self.grid_rowconfigure(1, weight=1)

    def do_movetotop(self):
        """Move the first selected item to the top and select it there."""
        try:
            idx = self.curselection()[0]
        except IndexError:
            # nothing is selected
            return
        value = self.get(idx)
        self.delete(idx)
        self.insert(0, value)
        self.selection_set(0)

    def do_movetobottom(self):
        """Move the first selected item to the end and select it there."""
        try:
            idx = self.curselection()[0]
        except IndexError:
            return
        value = self.get(idx)
        self.delete(idx)
        self.insert(tkinter.END, value)
        self.selection_set(tkinter.END)

    def do_moveup(self):
        """Move the first selected item one position up (no-op at the top)."""
        try:
            idx = self.curselection()[0]
        except IndexError:
            return
        if idx == 0:
            return
        value = self.get(idx)
        self.delete(idx)
        self.insert(idx - 1, value)
        self.selection_set(idx - 1)

    def do_movedown(self):
        """Move the first selected item one position down."""
        try:
            idx = self.curselection()[0]
        except IndexError:
            return
        value = self.get(idx)
        self.delete(idx)
        self.insert(idx + 1, value)
        # the item may already have been the last one
        self.selection_set(min(idx + 1, self.size() - 1))

    def do_add(self):
        """Insert the entry text before the first selected item.

        Without a selection the text is appended to the end of the list.
        """
        try:
            idx = self.curselection()[0]
        except IndexError:
            idx = tkinter.END
        self.insert(idx, self._entry.get())

    def do_remove(self):
        """Delete every selected item."""
        try:
            # the selection shrinks with each deletion; the loop ends when
            # indexing the then-empty selection raises IndexError
            while True:
                self.delete(self.curselection()[0])
        except IndexError:
            pass

    def do_clear(self):
        """Delete all items."""
        self.delete(0, tkinter.END)

    def insert(self, *args, **kwargs):
        """Forward to `Listbox.insert`."""
        return self._listbox.insert(*args, **kwargs)

    def activate(self, *args, **kwargs):
        """Forward to `Listbox.activate`."""
        return self._listbox.activate(*args, **kwargs)

    def bbox(self, *args, **kwargs):
        """Forward to `Listbox.bbox`."""
        return self._listbox.bbox(*args, **kwargs)

    def curselection(self):
        """Forward to `Listbox.curselection`."""
        return self._listbox.curselection()

    def delete(self, *args, **kwargs):
        """Forward to `Listbox.delete`."""
        return self._listbox.delete(*args, **kwargs)

    def get(self, *args, **kwargs):
        """Forward to `Listbox.get`."""
        return self._listbox.get(*args, **kwargs)

    def index(self, *args, **kwargs):
        """Forward to `Listbox.index`."""
        # This used to call `Listbox.get`, returning the item instead of its
        # numeric index.
        return self._listbox.index(*args, **kwargs)

    def selection_set(self, *args, **kwargs):
        """Forward to `Listbox.selection_set`."""
        return self._listbox.selection_set(*args, **kwargs)

    def size(self):
        """Forward to `Listbox.size`."""
        return self._listbox.size()


class SearchFrame(tkinter.ttk.Frame):
    """Base class of the search panels.

    It remembers the frame passed as `appframe` and keeps the most recent
    optimum in `_optimum` (initially None).  `set_optimum` and
    `save_defaults` have no implementation here and raise
    `NotImplementedError`.
    """

    def __init__(self, appframe, *args, **kwargs):
        """Create the frame; the remaining arguments go to `ttk.Frame`."""
        tkinter.ttk.Frame.__init__(self, *args, **kwargs)
        self._appframe = appframe
        self._optimum = None

    def execute(self):
        """Run the search; the base implementation does nothing."""
        return None

    def get_optimum(self):
        """Return the stored optimum (None if there is none)."""
        optimum = self._optimum
        return optimum

    def set_optimum(self, phconf):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def save_defaults(self):
        """Not implemented in the base class."""
        raise NotImplementedError()


class BruteForceSearchFrame(SearchFrame):
    """Search panel which tries every spacer / pinhole combination.

    The panel shows the list of available distance elements (spacers), the
    fixed geometry parameters, the accepted ranges of the sample and beamstop
    diameters and the available PH#1 / PH#2 diameters.  `execute` builds a
    `PinholeConfiguration` for each combination and lists those fulfilling
    the criteria in a `BruteForceResultsFrame`.

    Initial widget values are read from the module-level `CONFIG` mapping and
    written back to it by `save_defaults`.
    """

    def __init__(self, *args, **kwargs):
        """Build the widgets; all arguments go to `SearchFrame.__init__`."""
        SearchFrame.__init__(self, *args, **kwargs)

        # ---- left column: spacers, geometry and search criteria ----
        f1 = tkinter.ttk.Frame(master=self)
        f1.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
        f1.grid_columnconfigure(1, weight=1)
        f1.grid_rowconfigure(0, weight=1)
        row = 0
        self._distelements = self._add_value_list(
            f1, row, 'Distance elements (mm):', 'L_elements', columnspan=2)
        row += 1
        self._sealringwidth = self._add_spinbox(
            f1, row, 'Sealing ring width (mm):', CONFIG['sealring'],
            from_=0, to=50, increment=0.1)
        row += 1
        self._mindist_l1 = self._add_spinbox(
            f1, row, 'L1 without spacers (mm):', CONFIG['L1_bare'])
        row += 1
        self._mindist_l2 = self._add_spinbox(
            f1, row, 'L2 without spacers (mm):', CONFIG['L2_bare'])
        row += 1
        self._lbs = self._add_spinbox(
            f1, row, 'Det-beamstop dist. (mm):', CONFIG['lbs'])
        row += 1
        self._ls = self._add_spinbox(
            f1, row, 'PH#3-sample dist. (mm):', CONFIG['ls'])
        row += 1
        self._sd = self._add_spinbox(
            f1, row, 'Sample-det. dist. (mm):', CONFIG['SD'])
        row += 1
        self._wavelength = self._add_spinbox(
            f1, row, 'Wavelength (nm):', CONFIG['wavelength'])
        row += 1
        lf = tkinter.ttk.LabelFrame(master=f1, text='Search criteria:')
        lf.grid(row=row, column=0, columnspan=2, sticky=tkinter.NSEW)

        lf1 = tkinter.ttk.LabelFrame(master=lf, text='Sample diameter:')
        lf1.pack(side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        lf1.rowconfigure(0, weight=1)
        lf1.rowconfigure(1, weight=1)
        lf1.columnconfigure(1, weight=1)
        self._dsamplemin = self._add_spinbox(
            lf1, 0, 'Minimum (mm):', CONFIG['dsample_minimum'],
            from_=0, to=1000, increment=0.1)
        self._dsamplemax = self._add_spinbox(
            lf1, 1, 'Maximum (mm):', CONFIG['dsample_maximum'],
            from_=0, to=100, increment=0.1)

        lf1 = tkinter.ttk.LabelFrame(master=lf, text='Beamstop diameter:')
        lf1.pack(side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        lf1.grid_rowconfigure(0, weight=1)
        lf1.grid_rowconfigure(1, weight=1)
        lf1.grid_columnconfigure(1, weight=1)
        self._dbsmin = self._add_spinbox(
            lf1, 0, 'Minimum (mm):', CONFIG['dbs_minimum'],
            from_=0, to=1000, increment=0.1)
        self._dbsmax = self._add_spinbox(
            lf1, 1, 'Maximum (mm):', CONFIG['dbs_maximum'],
            from_=0, to=100, increment=0.1)

        # ---- right column: the available pinhole diameters ----
        f2 = tkinter.ttk.Frame(master=self)
        f2.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
        f2.grid_columnconfigure(0, weight=1)
        f2.grid_rowconfigure(0, weight=1)
        f2.grid_rowconfigure(1, weight=1)
        self._PH1choices = self._add_value_list(
            f2, 0, 'PH#1 diameters (um):', 'D1_choices')
        self._PH2choices = self._add_value_list(
            f2, 1, 'PH#2 diameters (um):', 'D2_choices')
        self.pack_propagate(True)

    def _add_spinbox(self, master, row, text, value, **spinbox_kwargs):
        """Grid a label (column 0) and a spinbox showing `value` (column 1).

        Extra keyword arguments are passed to `tkinter.Spinbox`.  Returns the
        spinbox.
        """
        label = tkinter.ttk.Label(master=master, text=text)
        label.grid(row=row, column=0, sticky=tkinter.NSEW)
        spinbox = tkinter.Spinbox(master=master, **spinbox_kwargs)
        spinbox.delete(0, tkinter.END)
        spinbox.insert(0, value)
        spinbox.grid(row=row, column=1, sticky=tkinter.NSEW)
        return spinbox

    def _add_value_list(self, master, row, title, configkey, columnspan=1):
        """Grid a titled, editable list filled with the floats in CONFIG.

        `CONFIG[configkey]` is read as a comma-separated list of numbers.
        Returns the created `ListboxWithScroll`.
        """
        labelframe = tkinter.ttk.LabelFrame(master=master, text=title)
        labelframe.grid(row=row, column=0, columnspan=columnspan,
                        sticky=tkinter.NSEW)
        listbox = ListboxWithScroll(master=labelframe)
        listbox.grid(sticky=tkinter.NSEW)
        # Empty tokens are skipped: save_defaults() stores '' for an empty
        # list, which must not crash the next start-up with float('').
        for token in CONFIG[configkey].split(','):
            if token.strip():
                listbox.insert(tkinter.END, float(token))
        labelframe.grid_columnconfigure(0, weight=1)
        labelframe.grid_rowconfigure(0, weight=1)
        return listbox

    def execute(self):
        """Run the brute-force search and show the hits in a new window.

        Every way of distributing the distance elements between L1 and L2
        (each non-empty selection for L1, each non-empty selection of the
        rest for L2) is combined with every PH#1 / PH#2 diameter.  A
        configuration is kept if its `Dsample` and `Dbs` both lie within the
        ranges given on the panel.  The kept configuration of highest
        `intensity` becomes the optimum; if nothing was kept the optimum is
        set to None.

        Raises:
            ValueError: if one of the input fields does not hold a number.
        """
        dsamplemin = float(self._dsamplemin.get())
        dsamplemax = float(self._dsamplemax.get())
        dbsmin = float(self._dbsmin.get())
        dbsmax = float(self._dbsmax.get())
        # each range is expressed as centre +/- tolerance
        dsample = 0.5 * (dsamplemin + dsamplemax)
        dbs = 0.5 * (dbsmin + dbsmax)
        dsampletolerance = abs(dsamplemax - dsamplemin) * 0.5
        dbstolerance = abs(dbsmax - dbsmin) * 0.5
        wavelength = float(self._wavelength.get())
        criteria = [('Sample diameter',
                     lambda pc:(abs(pc.Dsample - dsample) <=
                                dsampletolerance)),
                    ('Beamstop diameter', lambda pc:abs(
                        pc.Dbs - dbs) <= dbstolerance),
                    ]
        mindist_l1 = float(self._mindist_l1.get())
        mindist_l2 = float(self._mindist_l2.get())
        sealringwidth = float(self._sealringwidth.get())
        L_elements = [float(x) for x in self._distelements.get(0, tkinter.END)]
        sd = float(self._sd.get())
        ls = float(self._ls.get())
        pinholes_stage1 = [float(x)
                           for x in self._PH1choices.get(0, tkinter.END)]
        pinholes_stage2 = [float(x)
                           for x in self._PH2choices.get(0, tkinter.END)]
        lbs = float(self._lbs.get())

        found = []
        # (c1, c2) pairs already evaluated.  Repeated element values make the
        # same pair turn up several times (e.g. (100, 200) and (200, 100)).
        seen = set()
        for c1 in all_possible_selections(L_elements):
            # whatever is not used for L1 remains available for L2
            remaining = L_elements[:]
            for e in c1:
                remaining.remove(e)
            c1 = tuple(sorted(c1))
            for c2 in all_possible_selections(remaining):
                c2 = tuple(sorted(c2))
                if (c1, c2) in seen:
                    continue
                seen.add((c1, c2))
                for D1 in pinholes_stage1:
                    for D2 in pinholes_stage2:
                        pc = PinholeConfiguration(c1, c2, D1, D2, ls, lbs,
                                                  sd, mindist_l1, mindist_l2,
                                                  sealringwidth, wavelength)
                        if all(critfunc(pc)
                               for critname, critfunc in criteria):
                            found.append(pc)

        toplevel = tkinter.Toplevel(master=self)
        toplevel.wm_title('Results of brute-force search')
        f = BruteForceResultsFrame(appframe=self._appframe, master=toplevel)
        f.pack(side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        for pc in found:
            f.add(pc)
        f.redraw()
        if found:
            # last element of the ascending sort = highest intensity
            self._optimum = sorted(found, key=lambda pc: pc.intensity)[-1]
        else:
            # previously this case raised StopIteration
            self._optimum = None

    def set_optimum(self, phconf):
        """Accepted for interface compatibility; does nothing."""
        pass

    def save_defaults(self):
        """Store the current contents of all input widgets in `CONFIG`."""
        CONFIG['dsample_minimum'] = self._dsamplemin.get()
        CONFIG['dsample_maximum'] = self._dsamplemax.get()
        CONFIG['dbs_minimum'] = self._dbsmin.get()
        CONFIG['dbs_maximum'] = self._dbsmax.get()
        CONFIG['L1_bare'] = self._mindist_l1.get()
        CONFIG['L2_bare'] = self._mindist_l2.get()
        CONFIG['sealring'] = self._sealringwidth.get()
        CONFIG['L_elements'] = ', '.join(
            str(x) for x in self._distelements.get(0, tkinter.END))
        CONFIG['SD'] = self._sd.get()
        CONFIG['wavelength'] = self._wavelength.get()
        CONFIG['ls'] = self._ls.get()
        CONFIG['D1_choices'] = ', '.join(
            str(x) for x in self._PH1choices.get(0, tkinter.END))
        CONFIG['D2_choices'] = ', '.join(
            str(x) for x in self._PH2choices.get(0, tkinter.END))
        CONFIG['lbs'] = self._lbs.get()


class BruteForceResultsFrame(tkinter.ttk.Frame):
    """Sortable table of pinhole configurations with a row of buttons.

    Rows are added with `add` and become visible on `redraw`.  Clicking a
    column heading sorts by that column; clicking it again reverses the
    order.  `_pinholeconfigs` and `_formatted_strings` are always kept in the
    order in which the rows are displayed.
    """

    # (heading text, PinholeConfiguration attribute, number of decimals);
    # None as the number of decimals means "show str(value)".
    _columns = [('L1 (mm)', 'l1', 0),
                ('L2 (mm)', 'l2', 0),
                ('L1 parts', 'l1_elements', 0),
                ('L2 parts', 'l2_elements', 0),
                ('S-D dist (mm)', 'sd', 2),
                ('PH#1 (um)', 'D1', 0),
                ('PH#2 (um)', 'D2', 0),
                ('alpha (rad)', 'alpha', 5),
                ('Sample diam. (mm)', 'Dsample', 3),
                ('BS diam. (mm)', 'Dbs', 3),
                ('qmin (1/nm)', 'qmin', 5),
                ('dmax (nm)', 'dmax', 1),
                ('Max. sphere diam. (nm)', 'dspheremax', 1),
                ('PH#3 (um)', 'D3', 0),
                ('Intensity (um^4/mm^2)', 'intensity', 2),
                ('Dominant constraint', 'dominant_constraint', None),
                ]

    def __init__(self, appframe, *args, **kwargs):
        """Build the table and the button row.

        `appframe` is the object whose `_optimum` attribute is set by the
        'Store selected as optimum' button; the other arguments go to
        `ttk.Frame.__init__`.
        """
        self._sortkey = 'intensity'
        self._sortdescending = True
        tkinter.ttk.Frame.__init__(self, *args, **kwargs)
        self._appframe = appframe
        self._treeview = tkinter.ttk.Treeview(master=self,
                                              selectmode=tkinter.EXTENDED,
                                              show=['headings'],
                                              displaycolumns='#all',
                                              columns=[c[1] for c in
                                                       self._columns])
        for heading, c, digits in self._columns:
            self._treeview.heading(
                c, text=heading, anchor=tkinter.CENTER,
                # c=c binds the current column id to this heading's callback
                command=lambda c=c: self._heading_clicked(c))
            self._treeview.column(
                c, anchor=tkinter.W, minwidth=10, width=10)
        self._treeview.grid(sticky=tkinter.NSEW)
        self._xscroll = tkinter.ttk.Scrollbar(master=self,
                                              orient=tkinter.HORIZONTAL,
                                              command=self._treeview.xview)
        self._xscroll.grid(row=1, column=0, sticky=tkinter.NSEW)
        self._yscroll = tkinter.ttk.Scrollbar(master=self,
                                              orient=tkinter.VERTICAL,
                                              command=self._treeview.yview)
        self._yscroll.grid(row=0, column=1, sticky=tkinter.NSEW)
        self._treeview['xscrollcommand'] = self._xscroll.set
        self._treeview['yscrollcommand'] = self._yscroll.set
        self._treeview.tag_configure('regularrow', font='TkTextFont')
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)
        self._pinholeconfigs = []
        self._formatted_strings = []
        # rows added since the column widths were last updated
        self._unmeasured_rows = []
        font = tkinter.font.Font(font='TkHeadingFont')
        # start with the widths of the heading texts
        self._columnwidths = {
            k: font.measure(self._treeview.heading(k, 'text'))
            for k in self._treeview['columns']}
        f = tkinter.ttk.Frame(master=self)
        f.grid(columnspan=2, sticky=tkinter.NSEW)
        b = tkinter.ttk.Button(master=f, text='Select all',
                               command=self.select_all)
        b.pack(side=tkinter.LEFT)
        b = tkinter.ttk.Button(master=f, text='Deselect all',
                               command=self.deselect_all)
        b.pack(side=tkinter.LEFT)
        b = tkinter.ttk.Button(master=f, text='Copy as HTML',
                               command=self.copy_as_html)
        b.pack(side=tkinter.LEFT)
        b = tkinter.ttk.Button(master=f, text='Store selected as optimum',
                               command=self.save_as_optimum)
        b.pack(side=tkinter.LEFT)
        b = tkinter.ttk.Button(master=f, text='Close window',
                               command=lambda: (
                                   self.winfo_toplevel().destroy()))
        b.pack(side=tkinter.LEFT)

    def select_all(self):
        """Select every row."""
        self._treeview.selection_set(self._treeview.get_children())

    def deselect_all(self):
        """Clear the selection."""
        self._treeview.selection_remove(self._treeview.get_children())

    def copy_as_html(self):
        """Put the selected rows on the clipboard as an HTML table.

        `pyperclip` is used when the module could be imported; otherwise the
        text goes to the Tk clipboard.
        """
        parts = ['<table border="1" cellpadding="1" cellspacing="1">\n  '
                 '<thead>\n    <tr>\n']
        for col in self._treeview['columns']:
            parts.append('      <th scope="col">%s</th>\n' %
                         self._treeview.heading(col)['text'])
        parts.append('    </tr>\n  </thead>\n  <tbody>\n')
        for item in self._treeview.selection():
            parts.append('    <tr>\n')
            # Use our own formatted strings: values read back from the
            # Treeview are type-converted by tkinter and may lose formatting.
            for v in self._formatted_strings[self._treeview.index(item)]:
                parts.append('      <td>%s</td>\n' % str(v))
            parts.append('    </tr>\n')
        parts.append('  </tbody>\n</table>')
        txt = ''.join(parts)
        # the module-level import of pyperclip is optional and may have failed
        clipboard_module = globals().get('pyperclip')
        if clipboard_module is not None:
            clipboard_module.copy(txt)
        else:
            self.clipboard_clear()
            self.clipboard_append(txt)

    def save_as_optimum(self):
        """Store the first selected configuration as `appframe._optimum`."""
        try:
            item = self._treeview.selection()[0]
        except IndexError:
            # nothing is selected
            return
        self._appframe._optimum = self._pinholeconfigs[
            self._treeview.index(item)]

    def add(self, pc):
        """Append the configuration `pc`; call `redraw` to display it."""
        self._pinholeconfigs.append(pc)
        formatted = [_format_text(getattr(pc, col), digs)
                     for c, col, digs in self._columns]
        self._formatted_strings.append(formatted)
        self._unmeasured_rows.append(formatted)

    def _recalculate_columnwidths(self):
        """Widen the columns to fit all rows added since the last call.

        Every pending row is measured, not just the last one, so that
        several `add` calls followed by a single `redraw` are handled.
        """
        font = tkinter.font.Font(font='TkTextFont')
        columns = self._treeview['columns']
        for row in self._unmeasured_rows:
            for fs, col in zip(row, columns):
                wid = font.measure(fs)
                if self._columnwidths[col] < wid:
                    self._columnwidths[col] = wid
        self._unmeasured_rows = []

    def redraw(self):
        """Refill the table using the current sort settings."""
        self._redraw_sorted()

    def _heading_clicked(self, columnid):
        """Sort by `columnid`; a repeated click toggles the direction."""
        if self._sortkey == columnid:
            self._sortdescending = not self._sortdescending
        else:
            self._sortkey = columnid
            self._sortdescending = False
        self._redraw_sorted()

    def _redraw_sorted(self):
        """Rebuild the table rows sorted by the current key and direction."""
        if self._unmeasured_rows:
            self._recalculate_columnwidths()
        for item in self._treeview.get_children():
            self._treeview.delete(item)
        # `reverse=` instead of multiplying the key by -1: the latter only
        # works for numbers (a tuple or string times -1 is empty).
        ordered = sorted(zip(self._pinholeconfigs, self._formatted_strings),
                         key=lambda pcfs: getattr(pcfs[0], self._sortkey),
                         reverse=self._sortdescending)
        pinholeconfigs_sorted = []
        formatted_strings_sorted = []
        for f, fs in ordered:
            self._treeview.insert('', tkinter.END,
                                  values=fs,
                                  tags=['regularrow'])
            pinholeconfigs_sorted.append(f)
            formatted_strings_sorted.append(fs)
        # keep the lists in display order so row indices map to both
        self._pinholeconfigs = pinholeconfigs_sorted
        self._formatted_strings = formatted_strings_sorted
        for col in self._treeview['columns']:
            self._treeview.column(
                col, width=self._columnwidths[col] + 5,
                stretch=False)


def _format_text(value, digits):
    if digits is None:
        return str(value)
    try:
        text = ' + '.join('%%.%df' % digits % i for i in value)
    except TypeError:
        text = '%%.%df' % digits % value
    return text


class FixedApertureSearchFrame(SearchFrame):
    """Search frame that finds the optimum l1 and l2 for fixed apertures.

    The user enters the apertures of pinholes #1 and #2, the PH#3-sample
    and detector-beamstop distances, the sample and beamstop diameters, the
    sample-detector distance and the wavelength.  ``execute()`` computes the
    optimum pinhole distances from these values and plots the intensity
    factor over the (l1, l2) plane together with the sample and beamstop
    limit curves.
    """

    def __init__(self, *args, **kwargs):
        """Build the input spin boxes and the embedded Matplotlib canvas.

        All arguments are passed on unchanged to ``SearchFrame.__init__``.
        """
        SearchFrame.__init__(self, *args, **kwargs)
        frame = tkinter.ttk.Frame(master=self)
        frame.pack(side=tkinter.TOP, fill=tkinter.X, expand=False)

        def add_group(title, rows):
            # One labelled group of "caption + spin box" rows.  Every spin
            # box is pre-filled from CONFIG and stored on self under the
            # given attribute name.
            group = tkinter.ttk.LabelFrame(master=frame, text=title)
            group.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
            group.columnconfigure(1, weight=1)
            for row, (attrname, caption, configkey) in enumerate(rows):
                label = tkinter.ttk.Label(master=group, text=caption)
                label.grid(row=row, column=0, sticky=tkinter.NSEW)
                spinbox = tkinter.Spinbox(master=group, from_=0, to=500000)
                spinbox.delete(0, tkinter.END)
                spinbox.insert(0, CONFIG[configkey])
                spinbox.grid(row=row, column=1, sticky=tkinter.NSEW)
                setattr(self, attrname, spinbox)

        add_group('Distances & apertures', [
            ('_d1', 'PH#1 aperture (um):', 'D1'),
            ('_d2', 'PH#2 aperture (um):', 'D2'),
            ('_ls', 'PH#3-sample distance (mm):', 'ls'),
            ('_lbs', 'Detector-beamstop distance (mm):', 'lbs'),
        ])
        add_group('Criteria', [
            ('_dsample', 'Sample diameter (mm):', 'dsample'),
            ('_dbs', 'Beamstop diameter (mm):', 'dbs'),
            ('_sd', 'Sample-detector dist. (mm):', 'SD'),
            ('_wavelength', 'Wavelength (nm):', 'wavelength'),
        ])

        def on_resize(event):
            # Keep the figure size (in inches) in step with the widget size
            # (in pixels), then schedule a redraw.
            dpi = self._figure.get_dpi()
            self._figure.set_figheight(event.height / dpi)
            self._figure.set_figwidth(event.width / dpi)
            self._canvas.draw_idle()

        self._figure = matplotlib.figure.Figure(figsize=(1, 1), dpi=100)
        # NOTE(review): ``resize_callback`` and ``NavigationToolbar2TkAgg``
        # are understood to be deprecated/removed in later Matplotlib
        # releases -- confirm against the Matplotlib version targeted here.
        self._canvas = matplotlib.backends.backend_tkagg.FigureCanvasTkAgg(
            self._figure, master=self, resize_callback=on_resize)
        # draw() instead of the show() alias, as
        # FixedDistanceSearchFrame.execute() already does.
        self._canvas.draw()
        self._canvas.get_tk_widget().pack(
            side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        self._toolbar = \
            matplotlib.backends.backend_tkagg.NavigationToolbar2TkAgg(
                self._canvas, self)
        self._toolbar.update()
        self._optimum = None

    def set_optimum(self, phconf):
        """Load a pinhole configuration into the input fields.

        Copies ``D1``, ``D2``, ``ls``, ``lbs``, ``sd`` and ``wavelength``
        of ``phconf`` into the corresponding spin boxes (the sample and
        beamstop diameter fields are left untouched) and stores ``phconf``
        as the current optimum.
        """
        for entry, attr in [('d1', 'D1'), ('d2', 'D2'), ('ls', 'ls'),
                            ('lbs', 'lbs'), ('sd', 'sd'),
                            ('wavelength', 'wavelength')]:
            spinbox = getattr(self, '_' + entry)
            spinbox.delete(0, tkinter.END)
            spinbox.insert(0, str(getattr(phconf, attr)))
        self._optimum = phconf

    def execute(self):
        """Compute the optimum l1 and l2 and plot the result.

        Stores the resulting ``PinholeConfiguration`` in ``self._optimum``
        and redraws the figure.

        :raises ValueError: if pinhole #2 is not smaller than the sample or
            the beamstop, if the beamstop would lie before the sample, or
            if a field does not hold a valid number
        """
        # Radii in mm: the diameters of the sample and the beamstop are
        # entered in mm, the apertures in um (hence the 1e-3 factor).
        rs = 0.5 * float(self._dsample.get())
        rbs = 0.5 * float(self._dbs.get())
        r1 = 0.5e-3 * float(self._d1.get())
        r2 = 0.5e-3 * float(self._d2.get())
        ls = float(self._ls.get())
        lbs = float(self._lbs.get())
        sd = float(self._sd.get())
        wavelength = float(self._wavelength.get())

        if rs <= r2:
            raise ValueError('Pinhole #2 must be smaller than the sample')

        if rbs <= r2:
            raise ValueError('Pinhole #2 must be smaller than the beamstop')

        l3 = sd + ls
        l3prime = l3 - lbs

        if l3prime < 0:
            raise ValueError('The beamstop is before the sample')

        # l2opt is a root of A*l2**2 + B*l2 + C = 0.
        A = rbs - rs
        B = - (l3prime * r2 - ls * rbs + ls * r2 + l3prime * rs)
        C = - 2 * r2 * ls * l3prime
        l2opt = (-B + (B**2 - 4 * A * C)**0.5) / (2 * A)
        l1opt = (r1 + r2) * (l2opt + ls) / (rs - r2)
        # PinholeConfiguration takes the apertures as diameters in um.
        self._optimum = PinholeConfiguration(
            l1opt, l2opt, r1 * 2000, r2 * 2000, ls, lbs, sd, 0, 0, 0,
            wavelength)

        l2min = 2 * r2 * l3prime / (rbs - r2)
        # Drop the first grid point (l2 == l2min): the denominator of l1_bs
        # vanishes there.
        l2_ = np.linspace(l2min, 3 * self._optimum.l2, 100)[1:]
        l1_bs = ((r1 + r2) * (l2_ + l3prime) /
                 (rbs - r2 * (1 + 2 * l3prime / l2_)))
        l1_sample = (l2_ + ls) * (r1 + r2) / (rs - r2)
        l1_min = min(l1_bs.min(), l1_sample.min())
        l1_max = self._optimum.l1 * 2
        l1_ = np.linspace(l1_min, l1_max, 100)
        intensity_ = 1e12 * r1**2 * r2**2 / \
            np.outer(np.ones_like(l2_), l1_)**2 * np.pi / 4
        self._figure.clear()
        ax = self._figure.add_subplot(1, 1, 1)
        img = ax.imshow(intensity_, aspect='auto', cmap=matplotlib.cm.gray_r,
                        origin='lower', extent=(l1_.min(), l1_.max(),
                                                l2_.min(), l2_.max()))
        self._figure.colorbar(img)
        # Raw strings: '\,' and '\m' are not valid Python escape sequences.
        ax.set_xlabel(r'$l_1\, (\mathrm{mm})$')
        ax.set_ylabel(r'$l_2\, (\mathrm{mm})$')
        # Remember the image limits; they are restored after the overlays.
        lims = ax.axis()
        ax.plot(l1_sample, l2_, 'b-', label='Sample')
        ax.plot(l1_bs, l2_, 'r-', label='Beamstop')
        ax.fill_betweenx(l2_, l1_sample, l1_.max(), color='b', alpha=0.2)
        ax.fill_betweenx(l2_, l1_bs, l1_.max(), color='r', alpha=0.2)
        ax.plot([self._optimum.l1], [self._optimum.l2], 'wo', markersize=10)
        text = u"""\
Optimum l1: %.2f mm
Optimum l2: %.2f mm
l1+l2: %.2f mm
Intensity: %.2f um^4/mm^2
D3: %.2f um
qmin: %.5f 1/nm
dmax: %.2f nm""" % (self._optimum.l1, self._optimum.l2,
                    self._optimum.l1 + self._optimum.l2,
                    self._optimum.intensity, self._optimum.D3,
                    self._optimum.qmin, self._optimum.dmax)
        ax.text(0.01, 0.97,
                text,
                transform=ax.transAxes, ha='left', va='top', alpha=1,
                bbox={'boxstyle': 'round', 'alpha': 0.3, 'color': 'white'},
                fontsize='xx-small')
        ax.legend(loc='best', fontsize='xx-small')
        ax.axis(lims)
        self._figure.tight_layout()
        self._canvas.draw()

    def save_defaults(self):
        """Store the current contents of all input fields in CONFIG."""
        CONFIG['dsample'] = self._dsample.get()
        CONFIG['dbs'] = self._dbs.get()
        CONFIG['D1'] = self._d1.get()
        CONFIG['D2'] = self._d2.get()
        CONFIG['ls'] = self._ls.get()
        CONFIG['lbs'] = self._lbs.get()
        CONFIG['SD'] = self._sd.get()
        CONFIG['wavelength'] = self._wavelength.get()


class FixedDistanceSearchFrame(SearchFrame):
    """Search frame that finds the optimum apertures for fixed distances.

    The user enters the PH#1-PH#2, PH#2-PH#3, PH#3-sample and
    detector-beamstop distances, the sample and beamstop diameters, the
    sample-detector distance and the wavelength.  ``execute()`` computes
    the optimum apertures of the first two pinholes and plots the intensity
    factor over the (d1, d2) plane together with the sample and beamstop
    limit lines.
    """

    def __init__(self, *args, **kwargs):
        """Build the input spin boxes and the embedded Matplotlib canvas.

        All arguments are passed on unchanged to ``SearchFrame.__init__``.
        """
        SearchFrame.__init__(self, *args, **kwargs)
        frame = tkinter.ttk.Frame(master=self)
        frame.pack(side=tkinter.TOP, fill=tkinter.X, expand=False)

        def add_group(title, rows):
            # One labelled group of "caption + spin box" rows.  Every spin
            # box is pre-filled from CONFIG and stored on self under the
            # given attribute name.
            group = tkinter.ttk.LabelFrame(master=frame, text=title)
            group.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
            group.columnconfigure(1, weight=1)
            for row, (attrname, caption, configkey) in enumerate(rows):
                label = tkinter.ttk.Label(master=group, text=caption)
                label.grid(row=row, column=0, sticky=tkinter.NSEW)
                spinbox = tkinter.Spinbox(master=group, from_=0, to=500000)
                spinbox.delete(0, tkinter.END)
                spinbox.insert(0, CONFIG[configkey])
                spinbox.grid(row=row, column=1, sticky=tkinter.NSEW)
                setattr(self, attrname, spinbox)

        add_group('Distances', [
            ('_l1', 'PH#1-PH#2 distance (mm):', 'L1'),
            ('_l2', 'PH#2-PH#3 distance (mm):', 'L2'),
            ('_ls', 'PH#3-sample distance (mm):', 'ls'),
            ('_lbs', 'Detector-beamstop distance (mm):', 'lbs'),
        ])
        add_group('Criteria', [
            ('_dsample', 'Sample diameter (mm):', 'dsample'),
            ('_dbs', 'Beamstop diameter (mm):', 'dbs'),
            ('_sd', 'Sample-detector dist. (mm):', 'SD'),
            ('_wavelength', 'Wavelength (nm):', 'wavelength'),
        ])

        def on_resize(event):
            # Keep the figure size (in inches) in step with the widget size
            # (in pixels), then schedule a redraw.
            dpi = self._figure.get_dpi()
            self._figure.set_figheight(event.height / dpi)
            self._figure.set_figwidth(event.width / dpi)
            self._canvas.draw_idle()

        self._figure = matplotlib.figure.Figure(figsize=(1, 1), dpi=100)
        # NOTE(review): ``resize_callback`` and ``NavigationToolbar2TkAgg``
        # are understood to be deprecated/removed in later Matplotlib
        # releases -- confirm against the Matplotlib version targeted here.
        self._canvas = matplotlib.backends.backend_tkagg.FigureCanvasTkAgg(
            self._figure, master=self, resize_callback=on_resize)
        # draw() instead of the show() alias, as execute() already does.
        self._canvas.draw()
        self._canvas.get_tk_widget().pack(
            side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        self._toolbar = \
            matplotlib.backends.backend_tkagg.NavigationToolbar2TkAgg(
                self._canvas, self)
        self._toolbar.update()
        self._optimum = None

    def set_optimum(self, phconf):
        """Load a pinhole configuration into the input fields.

        Copies ``l1``, ``l2``, ``ls``, ``lbs``, ``sd`` and ``wavelength``
        of ``phconf`` into the corresponding spin boxes (the sample and
        beamstop diameter fields are left untouched) and stores ``phconf``
        as the current optimum.
        """
        for entry, attr in [('l1', 'l1'), ('l2', 'l2'), ('ls', 'ls'),
                            ('lbs', 'lbs'), ('sd', 'sd'),
                            ('wavelength', 'wavelength')]:
            spinbox = getattr(self, '_' + entry)
            spinbox.delete(0, tkinter.END)
            spinbox.insert(0, str(getattr(phconf, attr)))
        self._optimum = phconf

    def execute(self):
        """Compute the optimum D1 and D2 and plot the result.

        Stores the resulting ``PinholeConfiguration`` in ``self._optimum``
        and redraws the figure.

        :raises ValueError: if a pinhole distance or a diameter is not
            positive, if the beamstop would lie before the sample, or if a
            field does not hold a valid number
        """
        l1 = float(self._l1.get())
        l2 = float(self._l2.get())
        ls = float(self._ls.get())
        lbs = float(self._lbs.get())
        rs = float(self._dsample.get()) * 0.5
        rbs = float(self._dbs.get()) * 0.5
        wavelength = float(self._wavelength.get())
        sd = float(self._sd.get())

        # All four quantities end up in a denominator below; report bad
        # input as ValueError, like the other checks, instead of letting a
        # bare ZeroDivisionError escape.
        if l1 <= 0 or l2 <= 0:
            raise ValueError('The pinhole distances must be positive')
        if rs <= 0 or rbs <= 0:
            raise ValueError(
                'The sample and beamstop diameters must be positive')

        # Distances relative to l1 and l2.
        lambda1 = (ls + l2) / l1
        lambda0 = (l2 / l1)
        lambda2 = (ls + sd - lbs) / l2

        if lambda2*l2 < 0:
            raise ValueError('The beamstop is before the sample')

        # Each criterion is drawn below as the limit line
        # r2 = (A - B * r1) / C, with the region under it shaded.
        A_s = rs
        B_s = lambda1
        C_s = 1 + lambda1
        A_bs = rbs
        B_bs = lambda0 * (1 + lambda2)
        C_bs = (1 + lambda0) * (1 + lambda2) + lambda2
        r2max_s = A_s / C_s * 0.5
        r1max_s = A_s / B_s * 0.5
        r2max_bs = A_bs / C_bs * 0.5
        r1max_bs = A_bs / B_bs * 0.5
        # Does the candidate point of one criterion also lie within the
        # limit line of the other?
        S_satisfies_BS = (r2max_s <= (A_bs - B_bs * r1max_s) / C_bs)
        BS_satisfies_S = (r2max_bs <= (A_s - B_s * r1max_bs) / C_s)
        if S_satisfies_BS:
            r1opt = r1max_s
            r2opt = r2max_s
            strictercriterion = 'sample-limited'
        elif BS_satisfies_S:
            r1opt = r1max_bs
            r2opt = r2max_bs
            strictercriterion = 'beamstop-limited'
        else:
            denominator = ((lambda1 - lambda0) * (1 + lambda2) +
                           lambda1 * lambda2)
            r1opt = ((rs * ((1 + lambda0) * (1 + lambda2) + lambda2) -
                      rbs * (1 + lambda1)) / denominator)
            r2opt = ((rbs * lambda1 - rs * lambda0 * (1 + lambda2)) /
                     denominator)
            strictercriterion = 'transitory case'
        # PinholeConfiguration takes the apertures as diameters in um.
        self._optimum = PinholeConfiguration(
            l1, l2, 2000 * r1opt, 2000 * r2opt, ls, lbs, sd, 0, 0, 0,
            wavelength)
        r1max = max(r1max_s, r1max_bs)
        r2max = max(r2max_s, r2max_bs)
        rho = rs / rbs
        # Sample-detector distances reported in the text box as the bounds
        # of the sample-limited (SD < d_SsBS) and beamstop-limited
        # (SD > d_BSsSplus) modes.
        d_SsBS = (
            (2 * lambda1 * (lambda1 + 1) - rho * (lambda0 + lambda1 +
                                                  2 * lambda0 * lambda1)) /
            (rho * (lambda0 + 2 * lambda1 + 2 * lambda0 * lambda1)) *
            l2 + lbs - ls)
        Discr = 4 * rho**2 * lambda0**2 + rho * \
            (4 * lambda0**2 - 8 * lambda0 * lambda1) + \
            (lambda0 + 2 * lambda1 + 2 * lambda0 * lambda1)**2
        lam2plus = ((2 * lambda1 + lambda0 + 2 * lambda0 * lambda1 -
                     6 * rho * lambda0 - 4 * rho * lambda0**2 + Discr**0.5) /
                    (8 * rho * lambda0 + 4 * rho * lambda0**2))
        d_BSsSplus = lam2plus * l2 + lbs - ls
        r1_ = np.linspace(0, r1max * 1.5, 100)
        R1 = r1_[np.newaxis, :]
        r2_ = np.linspace(0, r2max * 1.5, 200)
        R2 = r2_[:, np.newaxis]
        intensity_ = 1e12 * R1**2 * R2**2 / l1**2 * \
            np.pi / 4.  # D1^2*D2^2/l1^2, in um^4/mm^2 units
        self._figure.clear()
        ax = self._figure.add_subplot(1, 1, 1)
        img = ax.imshow(intensity_, aspect='auto', cmap=matplotlib.cm.gray_r,
                        origin='lower', extent=(2000 * R1.min(),
                                                2000 * R1.max(),
                                                2000 * R2.min(),
                                                2000 * R2.max()))
        self._figure.colorbar(img)
        # Raw strings: '\,' and '\m' are not valid Python escape sequences.
        ax.set_xlabel(r'$d_1\, (\mu\mathrm{m})$')
        ax.set_ylabel(r'$d_2\, (\mu\mathrm{m})$')
        # Remember the image limits; they are restored after the overlays.
        lims = ax.axis()
        ax.plot(r1_ * 2000, 2000 * (A_s - r1_ * B_s) / C_s,
                'b-', lw=2, label='Sample')
        ax.fill_between(
            r1_ * 2000, 2000 * (A_s - r1_ * B_s) / C_s, color='b', alpha=0.2)
        ax.plot(r1_ * 2000, 2000 * (A_bs - r1_ * B_bs) / C_bs,
                'r-', lw=2, label='Beamstop')
        ax.fill_between(
            r1_ * 2000, 2000 * (A_bs - r1_ * B_bs) / C_bs, color='r',
            alpha=0.2)
        ax.plot([r1max_s * 2000], [r2max_s * 2000], 'bo', markersize=10)
        ax.plot([r1max_bs * 2000], [r2max_bs * 2000], 'ro', markersize=10)
        ax.plot([self._optimum.D1], [self._optimum.D2], 'wo', markersize=5)
        text = u"""\
Optimum D1: %.2f um
Optimum D2: %.2f um
Optimum D3: %.2f um
Intensity: %.2f um^4/mm^2
Qmin: %.3f 1/nm
Dmax: %.3f nm
Sample-limited: SD<%.2f mm
Beamstop-limited: SD>%.2f mm
Current mode: %s""" % (self._optimum.D1, self._optimum.D2,
                             self._optimum.D3, self._optimum.intensity,
                             self._optimum.qmin, self._optimum.dmax, d_SsBS,
                             d_BSsSplus, strictercriterion)
        ax.text(0.01, 0.97,
                text,
                transform=ax.transAxes, ha='left', va='top', alpha=1,
                bbox={'boxstyle': 'round', 'alpha': 0.3, 'color': 'white'},
                fontsize='xx-small')
        ax.axis(lims)
        ax.legend(loc='lower right', fontsize='xx-small')
        self._figure.tight_layout()
        self._canvas.draw()

    def save_defaults(self):
        """Store the current contents of all input fields in CONFIG."""
        CONFIG['L1'] = self._l1.get()
        CONFIG['L2'] = self._l2.get()
        CONFIG['ls'] = self._ls.get()
        CONFIG['lbs'] = self._lbs.get()
        CONFIG['dsample'] = self._dsample.get()
        CONFIG['dbs'] = self._dbs.get()
        CONFIG['wavelength'] = self._wavelength.get()
        CONFIG['SD'] = self._sd.get()


class CalculatorFrame(SearchFrame):
    """Calculator frame: derives beam properties from a pinhole geometry.

    Input quantities are edited in spin boxes.  Every derived quantity has
    a button that triggers its calculation and a label showing the result.
    """

    def __init__(self, *args, **kwargs):
        """Build one row per quantity: spin box (input) or button + label.

        All arguments are passed on unchanged to ``SearchFrame.__init__``.
        """
        SearchFrame.__init__(self, *args, **kwargs)
        f = tkinter.ttk.Frame(master=self)
        f.pack(side=tkinter.LEFT, fill=tkinter.Y, expand=False)
        self._entries = {}
        self._outputlabels = {}
        tkinter.ttk.Style().configure(
            'leftjustified.TButton', anchor=tkinter.W)
        # (name, caption template, initial value); None as initial value
        # marks a calculated (output) quantity.  The units of ls, qmin and
        # dmax follow the other search frames of this file (mm, 1/nm, nm).
        rows = [
            ('d1', 'PH#1 aperture (%s, um)', CONFIG['D1']),
            ('d2', 'PH#2 aperture (%s, um)', CONFIG['D2']),
            ('l1', 'PH#1-PH#2 distance (%s, mm)', CONFIG['L1']),
            ('intensity', 'Intensity factor (%s, um^4/mm^2)', None),
            ('l2', 'PH#2-PH#3 distance (%s, mm)', CONFIG['L2']),
            ('d3', 'PH#3 aperture (%s, um)', None),
            ('ls', 'PH#3-Sample distance (%s, mm)', CONFIG['ls']),
            ('dsprime', 'Beam size at sample (%s, mm)', None),
            ('ds', 'Extent of parasitic scattering at sample (%s, mm)',
             None),
            ('d', 'Sample-to-detector distance (%s, mm)', CONFIG['SD']),
            ('l3', 'PH#3-Detector distance (%s, mm)', None),
            ('ddetprime', 'Beam size at detector (%s, mm)', None),
            ('ddet', 'Extent of parasitic scattering at detector (%s, mm)',
             None),
            ('wavelength', 'Wavelength (%s, nm)', CONFIG['wavelength']),
            ('qmin', 'Lowest attainable q (%s, 1/nm)', None),
            ('dmax', 'Maximum lattice spacing (%s, nm)', None),
            ('lbs', 'Detector-Beamstop distance (%s, mm)', CONFIG['lbs']),
            ('dbsprime', 'Beam size at beamstop (%s, mm)', None),
            ('dbs', 'Extent of parasitic scattering at beamstop (%s, mm)',
             None),
        ]
        for row, (name, label, initvalue) in enumerate(rows):
            caption = label % name + ':'
            if initvalue is not None:
                l = tkinter.ttk.Label(master=f, text=caption)
                l.grid(row=row, column=0, sticky=tkinter.NSEW)
                self._entries[name] = tkinter.Spinbox(
                    master=f, from_=0, to=500000)
                self._entries[name].delete(0, tkinter.END)
                self._entries[name].insert(0, str(initvalue))
                self._entries[name].grid(
                    row=row, column=1, sticky=tkinter.NSEW)
            else:
                b = tkinter.ttk.Button(
                    master=f, text=caption,
                    # Bind the current name as a default argument so that
                    # each button calculates its own quantity.
                    command=lambda n=name: self.do_calculate(n),
                    style='leftjustified.TButton')
                b.grid(row=row, column=0, sticky=tkinter.NSEW)
                self._outputlabels[name] = tkinter.Label(
                    master=f, text='--', anchor=tkinter.W)
                self._outputlabels[name].grid(
                    row=row, column=1, sticky=tkinter.NSEW)
        f.columnconfigure(1, weight=1)
        # Initialised here like in the other search frames of this file.
        self._optimum = None

    def set_optimum(self, phconf):
        """Load a pinhole configuration into the inputs and recalculate.

        Copies ``D1``, ``D2``, ``l1``, ``l2``, ``ls``, ``lbs``, ``sd`` and
        ``wavelength`` of ``phconf`` into the spin boxes, stores ``phconf``
        as the current optimum and schedules ``execute()`` to run when the
        event loop is idle.
        """
        for entry, attr in [('d1', 'D1'), ('d2', 'D2'), ('l1', 'l1'),
                            ('l2', 'l2'), ('ls', 'ls'), ('lbs', 'lbs'),
                            ('d', 'sd'), ('wavelength', 'wavelength')]:
            self._entries[entry].delete(0, tkinter.END)
            self._entries[entry].insert(0, str(getattr(phconf, attr)))
        self._optimum = phconf
        self.after_idle(self.execute)

    def do_calculate(self, name):
        """Calculate the derived quantities up to and including ``name``.

        A ``PinholeConfiguration`` is built from the current inputs (an
        input that cannot be read as a number counts as 0).  The derived
        quantities are then evaluated in their fixed order, each result is
        written to its output label, and the evaluation stops after the
        quantity called ``name``.  With ``name`` set to None (or to any
        other unknown name) all quantities are evaluated.  The
        configuration is stored in ``self._optimum``.

        :param name: name of the last quantity to calculate, or None
        :return: the value of the last quantity calculated
        """
        inputs = {}
        for key, entry in self._entries.items():
            try:
                inputs[key] = float(entry.get())
            except (ValueError, tkinter.TclError):
                # Empty or non-numeric field (or an unreadable widget):
                # fall back to 0 instead of failing the whole calculation.
                inputs[key] = 0
        phconf = PinholeConfiguration(
            inputs['l1'], inputs['l2'], inputs['d1'], inputs['d2'],
            inputs['ls'], inputs['lbs'], inputs['d'], 0, 0, 0,
            inputs['wavelength'])
        # Output name -> PinholeConfiguration attribute, in display order.
        outputs = [
            ('intensity', 'intensity'),
            ('d3', 'D3'),
            ('dsprime', 'Dsample_direct'),
            ('ds', 'Dsample_parasitic'),
            ('l3', 'l3'),
            ('ddetprime', 'Ddet_direct'),
            ('ddet', 'Ddet_parasitic'),
            ('qmin', 'qmin'),
            ('dmax', 'dmax'),
            ('dbsprime', 'Dbs_direct'),
            ('dbs', 'Dbs_parasitic'),
        ]
        result = None
        for nm, attr in outputs:
            result = getattr(phconf, attr)
            self._outputlabels[nm]['text'] = str(result)
            if name == nm:
                break
        self._optimum = phconf
        return result

    def execute(self):
        """Calculate all derived quantities; returns the last one."""
        return self.do_calculate(None)

    def save_defaults(self):
        """Store the current contents of all input fields in CONFIG."""
        CONFIG['D1'] = self._entries['d1'].get()
        CONFIG['D2'] = self._entries['d2'].get()
        CONFIG['L1'] = self._entries['l1'].get()
        CONFIG['L2'] = self._entries['l2'].get()
        CONFIG['ls'] = self._entries['ls'].get()
        CONFIG['lbs'] = self._entries['lbs'].get()
        CONFIG['SD'] = self._entries['d'].get()
        CONFIG['wavelength'] = self._entries['wavelength'].get()


class DeadlockException(Exception):
    """Plain ``Exception`` subclass adding no state or behaviour.

    NOTE(review): presumably raised by the threaded Monte Carlo search
    defined below when it cannot make progress -- confirm at the raise
    sites.
    """


class MonteCarloSearchFrame(SearchFrame):
    """Search tab that optimizes the collimation by a Monte Carlo walk.

    The pinhole distances (l1, l2) and pinhole radii (r1, r2) of a
    PinholeConfiguration are varied one at a time by Gaussian random
    steps. A trial geometry is accepted with the Metropolis probability
    min(1, exp(dI / kbT)), dI being the change of the intensity. The walk
    runs in a background thread; its progress messages travel through a
    queue which is periodically emptied into a read-only text widget from
    the Tk event loop.
    """

    _printqueue_interval = 100  # ms

    def __init__(self, *args, **kwargs):
        """Build the parameter, constraint, Monte Carlo and result panes."""
        SearchFrame.__init__(self, *args, **kwargs)

        def add_spinbox(parent, row, column, caption, upper, configkey):
            """Grid a caption label and, right of it, a spin box
            pre-filled with CONFIG[configkey]; return the spin box."""
            label = tkinter.ttk.Label(master=parent, text=caption)
            label.grid(row=row, column=column, sticky=tkinter.NSEW)
            spinbox = tkinter.Spinbox(master=parent, from_=0, to=upper)
            spinbox.delete(0, tkinter.END)
            spinbox.insert(0, CONFIG[configkey])
            spinbox.grid(row=row, column=column + 1, sticky=tkinter.NSEW)
            return spinbox

        frame = tkinter.ttk.Frame(master=self)
        frame.pack(side=tkinter.TOP, fill=tkinter.X, expand=False)

        # --- fixed parameters of the instrument ---
        lf = tkinter.ttk.LabelFrame(master=frame, text='Parameters')
        lf.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
        lf.columnconfigure(1, weight=1)
        self._sd = add_spinbox(
            lf, 0, 0, 'Sample-detector dist. (mm):', 500000, 'SD')
        self._wavelength = add_spinbox(
            lf, 1, 0, 'Wavelength (nm):', 500000, 'wavelength')
        self._ls = add_spinbox(
            lf, 2, 0, 'PH#3-sample distance (mm):', 500000, 'ls')
        self._lbs = add_spinbox(
            lf, 3, 0, 'Detector-beamstop distance (mm):', 500000, 'lbs')

        # --- upper limits the optimized geometry has to respect ---
        lf = tkinter.ttk.LabelFrame(master=frame, text='Constraints')
        lf.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
        lf.columnconfigure(1, weight=1)
        self._l12 = add_spinbox(
            lf, 0, 0, 'Maximum collimation length (mm):', 500000, 'L12_max')
        self._D1max = add_spinbox(
            lf, 1, 0, 'Maximum PH#1 aperture (um):', 500000, 'D1_max')
        self._D2max = add_spinbox(
            lf, 2, 0, 'Maximum PH#2 aperture (um):', 500000, 'D2_max')
        self._dsample = add_spinbox(
            lf, 3, 0, 'Maximum sample diameter (mm):', 500000, 'dsample')
        self._dbs = add_spinbox(
            lf, 4, 0, 'Maximum beamstop diameter (mm):', 500000, 'dbs')

        # --- control parameters of the random walk (two column pairs) ---
        lf = tkinter.ttk.LabelFrame(master=self, text='Monte Carlo parameters')
        lf.pack(side=tkinter.TOP, fill=tkinter.X, expand=False)
        lf.columnconfigure(1, weight=1)
        lf.columnconfigure(3, weight=1)
        self._NMC = add_spinbox(
            lf, 0, 0, 'Monte Carlo steps per run:', 5000000000, 'MCsteps')
        self._NMCiters = add_spinbox(
            lf, 1, 0, 'Monte Carlo runs:', 5000000000, 'MCruns')
        self._Ndeadlock = add_spinbox(
            lf, 2, 0, 'Maximum retries for deadlock:', 5000000000,
            'Ndeadlock')
        self._kbT = add_spinbox(
            lf, 0, 2, 'Initial relative temperature (kbT):', 5000000000,
            'kbT')
        self._relstep = add_spinbox(
            lf, 1, 2, 'Relative stepsize:', 1, 'relstepsize')
        self._acceptance = add_spinbox(
            lf, 2, 2, 'Aimed acceptance ratio:', 1, 'accratio')

        # --- scrollable, read-only log of the simulation ---
        lf = tkinter.ttk.LabelFrame(master=self, text='Results')
        lf.pack(side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        lf.rowconfigure(0, weight=1)
        lf.columnconfigure(0, weight=1)
        f = tkinter.ttk.Frame(master=lf)
        f.grid(sticky=tkinter.NSEW)
        f.rowconfigure(0, weight=1)
        f.columnconfigure(0, weight=1)
        self._text = tkinter.Text(
            master=f, width=30, height=15, state=tkinter.DISABLED)
        self._text.grid(row=0, column=0, sticky=tkinter.NSEW)
        hs = tkinter.ttk.Scrollbar(
            master=f, orient=tkinter.HORIZONTAL, command=self._text.xview)
        hs.grid(row=1, column=0, sticky=tkinter.NSEW)
        vs = tkinter.ttk.Scrollbar(
            master=f, orient=tkinter.VERTICAL, command=self._text.yview)
        vs.grid(row=0, column=1, sticky=tkinter.NSEW)
        self._text['xscrollcommand'] = hs.set
        self._text['yscrollcommand'] = vs.set

        # Messages of the worker thread are handed over through this queue
        # and written to the text widget by _queue_consumer().
        self._queue = queue.Queue()
        self.after(self._printqueue_interval, self._queue_consumer)

    def _queue_consumer(self):
        """Append all pending messages to the log, then re-arm the timer."""
        # The widget is kept disabled so the user cannot edit it; it has to
        # be enabled while text is inserted.
        self._text['state'] = tkinter.NORMAL
        try:
            while True:
                self._text.insert(
                    tkinter.END, self._queue.get_nowait() + '\n')
        except queue.Empty:
            pass
        self.after(self._printqueue_interval, self._queue_consumer)
        self._text['state'] = tkinter.DISABLED
        self._text.see(tkinter.END)
        self._text.update_idletasks()

    def _get_new_value(self, oldvalue, steprange, maxvalue, Ndeadlock):
        """Return oldvalue shifted by a Gaussian random step.

        Steps are drawn (standard normal deviate times `steprange`) until
        the result lies in the interval (0, maxvalue].

        Raises DeadlockException if none of `Ndeadlock` attempts succeeds.
        """
        # Integer interval, never zero: avoids float modulo and division by
        # zero for Ndeadlock < 10.
        tenth = max(Ndeadlock // 10, 1)
        for k in range(Ndeadlock):
            if k % tenth == 0 and k // tenth > 2:
                self.write(
                    "Inner deadlock suspicion:", k // tenth + 1, '/10')
            newvalue = oldvalue + np.random.randn() * steprange
            if 0 < newvalue <= maxvalue:
                return newvalue
        raise DeadlockException

    def write(self, *args, **kwargs):
        """Queue the space-joined string forms of `args` for the log.

        Keyword arguments are accepted but ignored. Safe to call from the
        worker thread: only the queue is touched, not the widgets.
        """
        self._queue.put_nowait(' '.join(str(a) for a in args))

    def _runMC(self, starting_state, l1_step, l2_step, r1_step, r2_step,
               l1max, l2max, r1max, r2max, NMC, Ndeadlock,
               Ninnerdeadlock, rbs, rs, L12, kbT):
        """Carry out one Monte Carlo run of `NMC` steps.

        In step #i one of l1, l2, r1, r2 (cyclically) of a copy of the
        current state is changed by _get_new_value(). Up to `Ndeadlock`
        trial states are generated until one satisfies
        ``rbs <= rbs_max``, ``rs <= rs_max`` and ``l1 + l2 <= L12``; if
        none does, the step is skipped. A valid trial state is accepted
        with the Metropolis probability at temperature `kbT`.

        Returns a tuple (best state seen, final state, acceptance ratio).
        """
        accepted = 0
        optimum = starting_state.copy()
        state = starting_state.copy()
        intensity = state.intensity
        intensityopt = intensity
        # Reporting intervals: a tenth of the loop length, at least one.
        report_every = max(NMC // 10, 1)
        suspicion_every = max(Ndeadlock // 10, 1)
        # attribute name -> (step size, upper limit)
        moves = {'l1': (l1_step, l1max),
                 'l2': (l2_step, l2max),
                 'r1': (r1_step, r1max),
                 'r2': (r2_step, r2max)}
        for i in range(NMC):
            if i % report_every == 0:
                self.write("  Iteration #%d: Intensity: %f" % (i, intensity))
            what = ('l1', 'l2', 'r1', 'r2')[i % 4]
            step, upper = moves[what]
            for j in range(Ndeadlock):
                if j % suspicion_every == 0 and j // suspicion_every > 7:
                    self.write(
                        "  Iteration #%d: deadlock suspicion: %d/10" %
                        (i, j // suspicion_every + 1))
                state1 = state.copy()
                try:
                    setattr(state1, what,
                            self._get_new_value(getattr(state1, what),
                                                step, upper,
                                                Ninnerdeadlock))
                except DeadlockException:
                    self.write("  Recovering from inner deadlock.")
                    continue
                if ((state1.rbs <= rbs) and (state1.rs <= rs) and
                        (state1.l1 + state1.l2 <= L12)):
                    break
            else:
                # Every retry failed (or no retries were allowed). Using
                # for/else instead of comparing j with Ndeadlock - 1 keeps
                # a valid candidate found on the very last retry.
                self.write("  Deadlock detected at iteration #%d" % i)
                continue
            intensity1 = state1.intensity
            delta = intensity1 - intensity
            # The random number is drawn unconditionally so the random
            # stream does not depend on the outcome; improvements are
            # accepted without evaluating exp(), which could overflow.
            threshold = np.random.rand()
            if delta >= 0 or threshold < np.exp(delta / kbT):
                state = state1
                intensity = intensity1
                accepted += 1
            if intensity > intensityopt:
                optimum = state
                intensityopt = intensity
        ratio = accepted / float(NMC) if NMC > 0 else 0.0
        return optimum, state, ratio

    def run_thread(self, state, L12, r1max, r2max, relstep, NMC,
                   NMCiters, Ndeadlock, rbs, rs, kbT, acceptance):
        """Body of the worker thread: do `NMCiters` Monte Carlo runs.

        After each run the temperature is doubled if the acceptance ratio
        was below `acceptance` and halved otherwise. The best state found
        is stored in self._optimum and a summary is written to the log.
        self._thread is removed on exit, even after an error, so that a
        new simulation can be started.
        """
        starttime = time.time()
        try:
            self.write('Monte Carlo simulation started.')
            self.write('Max. collimation length:', L12)
            self.write('Beamstop radius:', rbs)
            self.write('Sample radius:', rs)
            optimum = state.copy()
            # Step sizes are fixed fractions of the allowed ranges.
            l2_step = l1_step = L12 * relstep
            r1_step = r1max * relstep
            r2_step = r2max * relstep
            for k in range(NMCiters):
                optimum1, state, acc = self._runMC(state, l1_step, l2_step,
                                                   r1_step, r2_step, L12,
                                                   L12, r1max, r2max, NMC,
                                                   Ndeadlock, Ndeadlock, rbs,
                                                   rs, L12, kbT)
                if optimum1.intensity > optimum.intensity:
                    optimum = optimum1
                self.write(
                    "After MC loop #%d/%d, acceptance ratio is: %.3f" %
                    (k + 1, NMCiters, acc))
                if acc < acceptance:
                    kbT *= 2
                else:
                    kbT /= 2
                self.write("New temperature: %s" % kbT)
                self.write("Optimum up to now: %s" % optimum)
            self.write('---------------')
            self.write('Found solution:')
            self.write('---------------')
            self.write('D1: %.2f um' % optimum.D1)
            self.write('D2: %.2f um' % optimum.D2)
            self.write('l1: %.2f mm' % optimum.l1)
            self.write('l2: %.2f mm' % optimum.l2)
            self.write('Ds: %.2f mm' % optimum.Dsample)
            self.write('Dbs: %.2f mm' % optimum.Dbs)
            self.write('Intensity: %.2f um^4/mm^2' %
                       optimum.intensity)
            self.write('---------------')
            elapsed = time.time() - starttime
            minutes, seconds = divmod(elapsed, 60)
            hours, minutes = divmod(minutes, 60)
            whole_seconds = int(seconds)
            milliseconds = int((seconds - whole_seconds) * 1000)
            self.write('Elapsed time: %02d:%02d:%02d.%03d' %
                       (hours, minutes, whole_seconds, milliseconds))
            self._optimum = optimum
        except Exception as exc:
            self.write('Monte Carlo simulation failed:', exc)
            raise
        finally:
            # execute() refuses to start while this attribute exists.
            try:
                del self._thread
            except AttributeError:
                pass

    def execute(self):
        """Read the input fields and start the simulation thread.

        Shows an error box and returns if a simulation is already running.

        Raises ValueError if a field cannot be converted to a number or
        the initial temperature is not positive.
        """
        if hasattr(self, '_thread'):
            tkinter.messagebox.showerror(
                'Error', 'Monte Carlo thread is already running.')
            return

        L12 = float(self._l12.get())
        # The diameters entered by the user are converted to radii.
        rs = 0.5 * float(self._dsample.get())
        rbs = 0.5 * float(self._dbs.get())
        sd = float(self._sd.get())
        ls = float(self._ls.get())
        lbs = float(self._lbs.get())

        # Aperture diameters are entered in um; 0.5e-3 turns them into
        # radii in units 1000 times larger (presumably mm).
        r1max = float(self._D1max.get()) * 0.5e-3
        r2max = float(self._D2max.get()) * 0.5e-3

        NMC = int(self._NMC.get())
        NMCiters = int(self._NMCiters.get())
        Ndeadlock = int(self._Ndeadlock.get())
        acceptance = float(self._acceptance.get())
        relstep = float(self._relstep.get())
        kbT = float(self._kbT.get())
        if not kbT > 0:
            # The acceptance probability divides by kbT.
            raise ValueError(
                'Initial relative temperature (kbT) must be positive.')

        # Initialization: draw random geometries until one satisfies the
        # sample and beamstop size constraints.
        # NOTE(review): this loop does not terminate if the constraints
        # cannot be satisfied.
        while True:
            l1 = np.random.uniform(0, L12)
            state = PinholeConfiguration(l1, np.random.uniform(0, L12 - l1),
                                         2000 * np.random.uniform(0, r1max),
                                         2000 * np.random.uniform(0, r2max),
                                         ls, lbs, sd)
            if (state.rs < rs) and (state.rbs < rbs):
                break

        self._thread = threading.Thread(target=self.run_thread,
                                        name='MC_calculation',
                                        args=(state, L12, r1max, r2max,
                                              relstep, NMC, NMCiters,
                                              Ndeadlock, rbs, rs, kbT,
                                              acceptance), daemon=True)
        self._thread.start()

    def save_defaults(self):
        """Copy the current contents of all input fields into CONFIG."""
        CONFIG['L12_max'] = self._l12.get()
        CONFIG['dsample'] = self._dsample.get()
        CONFIG['dbs'] = self._dbs.get()
        CONFIG['SD'] = self._sd.get()
        CONFIG['wavelength'] = self._wavelength.get()
        CONFIG['ls'] = self._ls.get()
        CONFIG['lbs'] = self._lbs.get()
        CONFIG['D1_max'] = self._D1max.get()
        CONFIG['D2_max'] = self._D2max.get()
        CONFIG['MCsteps'] = self._NMC.get()
        # Must be the key __init__ reads the run count from.
        CONFIG['MCruns'] = self._NMCiters.get()
        CONFIG['Ndeadlock'] = self._Ndeadlock.get()
        CONFIG['accratio'] = self._acceptance.get()
        CONFIG['relstepsize'] = self._relstep.get()
        CONFIG['kbT'] = self._kbT.get()


class MainWindowFrame(tkinter.ttk.Frame):
    """Main frame of the application.

    Holds a notebook with one tab for each search method and a row of
    buttons whose commands act on the tab currently selected. An optimum
    fetched from one tab is kept in self._optimum so that it can be
    handed over to another tab.
    """

    def __init__(self, *args, **kwargs):
        """Create the notebook, its tabs and the command buttons."""
        tkinter.ttk.Frame.__init__(self, *args, **kwargs)
        self._notebook = tkinter.ttk.Notebook(master=self)
        self._notebook.pack(fill=tkinter.BOTH, expand=True)
        self._tabs = []
        for tabframeclass, tabname in [(BruteForceSearchFrame, 'Brute-force'),
                                       (FixedApertureSearchFrame,
                                        'Fixed aperture'),
                                       (FixedDistanceSearchFrame,
                                        'Fixed distance'),
                                       (MonteCarloSearchFrame,
                                        'Global Monte Carlo'),
                                       (CalculatorFrame, 'Calculator'),
                                       ]:
            tab = tabframeclass(appframe=self, master=self._notebook)
            # self._tabs is indexed by the position of the tab in the
            # notebook, see _current_tab().
            self._tabs.append(tab)
            self._notebook.add(tab, text=tabname, sticky=tkinter.NSEW)
        buttonframe = tkinter.ttk.Frame(master=self)
        buttonframe.pack(side=tkinter.BOTTOM, fill=tkinter.BOTH, expand=False)
        for buttontext, buttoncommand in [('Quit', self.do_quit),
                                          ('Execute', self.do_execute),
                                          ('Store optimum',
                                           self.do_store_optimum),
                                          ('Retrieve optimum',
                                           self.do_retrieve_optimum)
                                          ]:
            b = tkinter.ttk.Button(master=buttonframe,
                                   text=buttontext,
                                   command=buttoncommand)
            b.pack(side=tkinter.LEFT, fill=tkinter.BOTH, expand=True)
        self._optimum = None

    def _current_tab(self):
        """Return the tab frame selected in the notebook."""
        return self._tabs[self._notebook.index('current')]

    def do_quit(self):
        """Destroy the toplevel window containing this frame."""
        self.winfo_toplevel().destroy()

    def do_store_optimum(self):
        """Remember the optimum reported by the current tab."""
        self._optimum = self._current_tab().get_optimum()

    def do_retrieve_optimum(self):
        """Pass the remembered optimum, if there is one, to the current
        tab."""
        if self._optimum is None:
            return
        self._current_tab().set_optimum(self._optimum)

    def do_execute(self):
        """Run the current tab and store its inputs as new defaults.

        A ValueError raised by the tab is shown in an error box instead
        of being propagated; the defaults are not saved in that case.
        """
        tab = self._current_tab()
        try:
            tab.execute()
            tab.save_defaults()
        except ValueError as ve:
            # str() copes with exceptions raised without arguments
            # (ve.args[0] would raise IndexError) and with non-string
            # arguments.
            tkinter.messagebox.showerror('Error', str(ve))

# ---------------------------------------------------------------------------
# Configuration and program start-up.
#
# CONFIG is the section of the configuration all frames read their initial
# values from and write them back to. It is first filled with the built-in
# defaults below (in this order); values found in ~/.SASCollOpt then
# override them.
# ---------------------------------------------------------------------------
CONFIGroot = configparser.ConfigParser()
CONFIGroot['CONFIG'] = {}
CONFIG = CONFIGroot['CONFIG']
CONFIG.update([
    ('wavelength', '0.15418'),
    ('SD', '520.0'),
    ('ls', '130.0'),
    ('lbs', '54.0'),
    ('dsample', '0.8'),
    ('dbs', '4.0'),
    ('L_elements', '100, 100, 200, 200, 500, 800'),
    ('D1_choices', '150, 300, 600, 1000, 1250'),
    ('D2_choices', '150, 300, 400, 500, 1000'),
    ('sealring', '4.0'),
    ('L1_bare', '104.0'),
    ('L2_bare', '104.0'),
    ('dsample_minimum', '0.0'),
    ('dsample_maximum', '0.9'),
    ('dbs_minimum', '0.0'),
    ('dbs_maximum', '4.0'),
    ('D1', '500.0'),
    ('D2', '200.0'),
    ('L1', '500.0'),
    ('L2', '800.0'),
    ('L12_max', '3000.0'),
    ('D1_max', '1000.0'),
    ('D2_max', '1000.0'),
    ('MCsteps', '10000'),
    ('MCruns', '10'),
    ('Ndeadlock', '10000'),
    ('kbT', '500.0'),
    ('relstepsize', '0.01'),
    ('accratio', '0.5'),
])
CONFIGroot.read(os.path.expanduser('~/.SASCollOpt'), encoding='utf-8')

# Build the main window and hand control over to Tk.
root = tkinter.Tk()
root.title('SASCollOpt -- Optimum Collimation for Small-Angle Scattering')
mwf = MainWindowFrame(master=root)
mwf.pack(fill=tkinter.BOTH, expand=True)
root.mainloop()

# The main loop has ended: store the configuration for the next session.
with open(os.path.expanduser('~/.SASCollOpt'), 'wt', encoding='utf-8') as f:
    CONFIGroot.write(f)
