"""
Test self-consistent iteration in simple cases

:Author: Pauli Virtanen <pauli@ltl.tkk.fi>
"""
from __future__ import division, absolute_import, print_function

from testutils import *
from usadel1 import *
from numpy import *

# Version-control keyword string ($Id$-style) recording this file's revision
# in the original repository; not used by the tests themselves.
__revision__ = "$Id: test_selfcons_simple.py 3184 2006-10-02 07:05:45Z pauli $"

def setup():
    """Create the shared geometry used by :func:`test`.

    Sets two module-level globals:

    ``Delta``
        The energy gap the self-consistent solvers should reproduce (60).
    ``g``
        A ``Geometry`` with a single superconducting wire whose end 0 is
        attached to a clean S terminal (node 0) and end 1 to a clean
        N terminal (node 1), seeded with a deliberately poor initial guess
        for the wire gap and phase.
    """
    global g, Delta

    # Bulk
    # ----
    #
    # The value of the energy gap we aim to reproduce ...
    Delta = 60

    # ... in a single S wire connected between an S terminal (node 0) and
    # an N terminal (node 1).  t_t = 0.4*Delta sets the terminal temperature.
    g = Geometry(1, 2)
    g.t_type = [NODE_CLEAN_S_TERMINAL, NODE_CLEAN_N_TERMINAL]
    g.t_delta = Delta
    g.t_inelastic = 1e-9
    g.t_t = 0.4*Delta
    g.t_mu = 0
    g.w_type = [WIRE_TYPE_S]
    g.w_length = 1
    g.w_conductance = 1
    g.w_inelastic = 1e-9
    g.w_spinflip = 0
    g.w_ends[0,:] = [0,1]

    # Pairing parameters: the cutoff omega_D and a coupling constant chosen
    # so that 1/lambda = arccosh(omega_D/Delta).  Presumably this makes the
    # bulk gap come out as Delta in usadel1's gap relation.
    g.omega_D = 4000
    g.coupling_lambda = 1/arccosh(g.omega_D/Delta)

    # A deliberately poor initial guess for the wire gap and phase, so that
    # the self-consistent iteration has real work to do.
    g.w_phase = 0
    g.w_delta = 10


# pytest >= 8 no longer runs nose-style module-level ``setup()``; expose the
# same function under the xunit-style name so ``g`` and ``Delta`` exist when
# test() runs.  nose and pytest each invoke only one module setup function,
# so this alias does not make setup() run twice.
setup_module = setup

def _iterate_until_converged(iteration, tol):
    """Drive a self-consistent iteration until its residual drops below *tol*.

    Parameters
    ----------
    iteration : iterable
        Yields ``(k, d, v)`` triples; ``d.residual_norm()`` must return the
        current residual (as the ``self_consistent_*_iteration`` generators
        used in this module do).
    tol : float
        Convergence threshold for the residual norm.

    Raises
    ------
    RuntimeError
        If *iteration* is exhausted before the residual falls below *tol*.
    """
    for _k, d, _v in iteration:
        residual = d.residual_norm()
        print("Residual:", residual)
        if residual < tol:
            return
    raise RuntimeError("Iteration did not converge")


def test():
    """Compare real-time and Matsubara self-consistent gaps in one S wire.

    Uses the module-level ``g`` and ``Delta`` created by ``setup()``.  The
    wire gap is solved three times, each from the same poor initial guess:

    1. real-time iteration (``CurrentSolver`` with the COLNEW solver),
    2. Matsubara iteration with cutoff elimination (``max_ne=50``),
    3. Matsubara iteration without cutoff elimination, with ``max_ne``
       scaled by ``omega_D / t_t`` (presumably so the frequency sum reaches
       the cutoff).

    The three results must agree within tolerance.  ``delta[0, 0]``
    (presumably the wire end at the S terminal) must be within 5 of
    ``Delta``, and ``delta[0, -1]`` (the N end) within 0.5 of zero.
    """
    tol = 1e-3

    def reset_initial_guess():
        # Start every solver from the same poor guess so the runs are comparable.
        g.w_phase = 0
        g.w_delta = 10

    # 1. Real-time iteration
    reset_initial_guess()
    sol = CurrentSolver(g, ne=400, chunksize=400)
    sol.set_solvers(sp_solver=SP_SOLVER_COLNEW)
    _iterate_until_converged(self_consistent_realtime_iteration(sol), tol)
    realtime_delta = g.w_delta.copy()

    # 2. Matsubara iteration with cutoff elimination
    reset_initial_guess()
    _iterate_until_converged(
        self_consistent_matsubara_iteration(g, cutoff_elimination=True,
                                            max_ne=50),
        tol)
    matsubara_delta = g.w_delta.copy()

    # 3. Matsubara iteration without cutoff elimination
    reset_initial_guess()
    max_ne = 10 + int(g.omega_D.max()/g.t_t.max())
    _iterate_until_converged(
        self_consistent_matsubara_iteration(g, cutoff_elimination=False,
                                            max_ne=max_ne),
        tol)
    matsubara_delta_2 = g.w_delta.copy()

    print(around(realtime_delta, 3))
    print(around(matsubara_delta, 3))
    print(around(matsubara_delta_2, 3))

    # The three methods must agree with each other ...
    assert allclose(realtime_delta, matsubara_delta, atol=0.5, rtol=5e-2)
    assert allclose(matsubara_delta_2, matsubara_delta, atol=0.05, rtol=5e-2)

    # ... and reproduce the expected gap profile at the wire ends.
    assert is_zero(abs(realtime_delta[0,0]) - Delta, tolerance=5)
    assert is_zero(abs(realtime_delta[0,-1]), tolerance=0.5)

    assert is_zero(abs(matsubara_delta[0,0]) - Delta, tolerance=5)
    assert is_zero(abs(matsubara_delta[0,-1]), tolerance=0.5)


def test_selfcons_s():
    """Check that a single S wire reproduces the bulk BCS gap relation.

    The wire sits between two free interfaces.  For every temperature in
    the sweep, the self-consistent wire gap must match
    ``bulk_delta(lambda, omega_D, T)`` to within 0.05.  Each temperature is
    solved twice with the real-gap Matsubara iteration:

    * with default settings, starting from whatever ``w_delta`` the previous
      solve left behind (1.1 for the first temperature);
    * without cutoff elimination (``max_ne=1000``), restarting from
      ``w_delta = 1.1``.
    """
    delta0 = 1.0
    initial_guess = 1.1
    tol = 1e-7

    geo = Geometry(1, 2)
    geo.t_type = [NODE_FREE_INTERFACE, NODE_FREE_INTERFACE]
    geo.t_delta = 0.0
    geo.t_inelastic = 1e-9
    geo.t_t = 0.1
    geo.t_mu = 0
    geo.w_type = [WIRE_TYPE_S]
    geo.w_length = 1
    geo.w_conductance = 1
    geo.w_inelastic = 1e-9
    geo.w_spinflip = 0
    geo.w_ends[0,:] = [0,1]
    geo.omega_D = 600
    geo.coupling_lambda = 1/arccosh(geo.omega_D/delta0)
    geo.w_phase = 0
    geo.w_delta = initial_guess

    def converge(**options):
        # Iterate the Matsubara self-consistency on geo until residual < tol.
        iteration = self_consistent_matsubara_iteration(geo, real_delta=True,
                                                        **options)
        for _k, d, _v in iteration:
            residual = d.residual_norm()
            print("Residual:", residual)
            if residual < tol:
                return
        raise RuntimeError("Iteration did not converge")

    for temperature in (0.01, 0.1, 0.4, 0.5, 0.55, 0.56, 0.57, 0.6):
        geo.t_t = temperature
        expected = bulk_delta(geo.coupling_lambda.mean(), geo.omega_D.mean(),
                              temperature)

        print("T = ", temperature)

        # Warm start: continue from the current wire gap.
        converge()
        assert_allclose(geo.w_delta, expected, rtol=0, atol=0.05)

        # Cold restart without cutoff elimination.
        geo.w_delta = initial_guess
        converge(cutoff_elimination=False, max_ne=1000)
        assert_allclose(geo.w_delta, expected, rtol=0, atol=0.05)


if __name__ == "__main__":
    # Allow running this module directly.  run_tests is presumably the
    # runner exported by testutils (star-imported above).
    run_tests()
