#include "testsimulator2d.h"
#include "simulation.h"
#include <time.h>
#include <cmath>

#include "constants.h" // TODO: for debugging
#include <iostream>

/**
 * Creates a simulator that is not yet bound to a simulation.
 *
 * The members that init() assigns are given defined values here, in case the
 * class declaration does not already provide default initializers, so the
 * object does not carry an indeterminate pointer before init() is called.
 */
TestSimulator2D::TestSimulator2D()
{
	m_pSim = 0;
	m_dtStart = 1e-15; // same starting time step that init() installs
}

/** Destroys the simulator, deleting every snapshot it still owns via clear(). */
TestSimulator2D::~TestSimulator2D()
{
	this->clear();
}

/**
 * Binds this object to a simulation and records its current state as the
 * first snapshot.
 *
 * Any snapshots from a previous run are deleted, the starting time step is
 * reset to 1e-15, and the electron density, hole density and potential
 * currently stored in the simulation are copied into a snapshot tagged with
 * step 0.
 *
 * @param pSim Simulation to drive; the pointer is stored, not copied, so it
 *             must stay valid for as long as simulate() may be called.
 * @return true on success; false (with the error string set) if pSim is null,
 *         in which case the existing state is left untouched.
 */
bool TestSimulator2D::init(Simulation *pSim)
{
	if (pSim == 0)
	{
		setErrorString("No simulation specified");
		return false;
	}

	clear();

	m_pSim = pSim;
	m_dtStart = 1e-15;

	// One value per grid pixel for each of the three fields.
	const int pixels = pSim->getNumXPixels() * pSim->getNumYPixels();

	std::vector<double> n(pixels), p(pixels), V(pixels);

	pSim->getElectronNumberDensity(n);
	pSim->getHoleNumberDensity(p);
	pSim->getPotential(V);

	// The initial state acts as the fallback of last resort for simulate().
	SnapShot *pSnap = new SnapShot(n, p, V, m_dtStart, 0);
	m_snapShots.push_back(pSnap);

	return true;
}

/**
 * Runs the bound simulation with an adaptive time step until the per-run
 * update averages stop changing, or until the time budget is used up.
 *
 * Each pass of the loop runs the simulation for a batch of steps with the
 * current dt. If the result is healthy (all values finite and the update
 * averages did not grow more than tenfold) the state is stored as a snapshot
 * and dt is multiplied by the dt multiplier (initially 2). Otherwise the most
 * recent snapshot taken with a smaller dt is restored and its step multiplier
 * is increased; if later snapshots had to be discarded, dt is frozen at the
 * restored value and a very large step multiplier is used from then on.
 *
 * @param maxSeconds Wall-clock budget in seconds (one extra second is allowed).
 * @return true if convergence was detected; false otherwise, with the error
 *         string describing the failure (not initialized, simulation error,
 *         no usable snapshot, or time-out).
 */
bool TestSimulator2D::simulate(int maxSeconds)
{
	// init() pushes the initial snapshot, so an empty list means that there
	// is no simulation to run and no state to fall back to.
	if (m_snapShots.empty())
	{
		setErrorString("Simulator has not been initialized");
		return false;
	}

	time_t startTime = time(0);
	time_t endTime = startTime+maxSeconds+1;
	time_t curTime;

	// Same buffer size as used in init(): one entry per pixel of the 2D grid.
	int pixels = m_pSim->getNumXPixels() * m_pSim->getNumYPixels();
	int stepMultiplier = 1;
	double dtMultiplier = 2.0;
	double dt = m_dtStart;
	int curStep = 0;

	bool converged = false;
	bool hadNaN = false;

	// Huge initial values so that the first run can never look like a blow-up.
	double prevUpdateN = 1e100;
	double prevUpdateP = 1e100;

	while ((curTime = time(0)) < endTime && !converged)
	{
		time_t timeLeft = endTime-curTime;
		int steps = 1000 * stepMultiplier;
		double stopCriterion = 1e-10;

		if (!hadNaN)
			stopCriterion = 1e-100; // only allow the algorithm to stop when a large enough dt has been reached, but this will make sure a stop occurs when a NaN is detected

		//std::cout << "Running with dt = " << dt << " for " << steps << " steps" << std::endl;
		if (!m_pSim->start(timeLeft, dt, steps, stopCriterion))
		{
			setErrorString("Error running simulation: " + m_pSim->getErrorString());
			return false;
		}

		curStep += steps;

		/*
		// for debugging
		{
			double left, right, center, overall, Pavg;

			m_pSim->calculateXCurrent(left, right, overall, center);
			Pavg = m_pSim->calculatePavg();

			left *= CHARGE_ELECTRON;
			right *= CHARGE_ELECTRON;
			overall *= CHARGE_ELECTRON;
			center *= CHARGE_ELECTRON;

			left = left;
			right = right;
			overall = overall;
			center = center;

			std::cerr << "step: " << curStep << " Left: " << left << " Right: " << right << " Center: " << center << " Avg: " << overall << " Pavg: " << Pavg << std::endl;
			std::cerr << "      " << m_pSim->getUpdateNAvg() << " " << m_pSim->getUpdatePAvg() << std::endl;
		}*/

		std::vector<double> n(pixels), p(pixels), V(pixels);

		m_pSim->getElectronNumberDensity(n);
		m_pSim->getHoleNumberDensity(p);
		m_pSim->getPotential(V);

		// Update averages reported by the simulation for the run that just finished.
		const double updateN = m_pSim->getUpdateNAvg();
		const double updateP = m_pSim->getUpdatePAvg();

		if (hasNaN(n) || hasNaN(p) || hasNaN(V) || updateN > prevUpdateN*10.0 || updateP > prevUpdateP*10.0)
		{
			// The run blew up: non-finite values, or updates that grew by more than 10x.
			hadNaN = true;

			if (m_snapShots.size() == 1)
			{
				// Only the initial state exists, so even the smallest dt failed.
				setErrorString("Can't simulate using miminal time step");
				return false;
			}
			else
			{
				const int lastIdx = static_cast<int>(m_snapShots.size()) - 1;
				int idx = lastIdx;
				SnapShot *pSnap = 0;
				bool found = false;

				// Walk backwards to the most recent snapshot taken with a smaller dt.
				do
				{
					pSnap = m_snapShots[idx];

					if (pSnap->getDt() < dt)
						found = true;
					else
						idx--;
				} while (!found && idx >= 0);

				if (!found)
				{
					setErrorString("Detected NaN using a previously successful step size, but no useful starting point can be found");
					return false;
				}

				// Roll the simulation back to that snapshot.
				m_pSim->setElectronNumberDensity(pSnap->getN());
				m_pSim->setHoleNumberDensity(pSnap->getP());
				m_pSim->setPotential(pSnap->getV());

				curStep = pSnap->getStep();
				dt = pSnap->getDt();
				pSnap->increaseStepMultiplier();
				stepMultiplier = pSnap->getStepMultiplier();

				if (idx < lastIdx)
				{
					// Snapshots newer than the restored one are no longer valid.
					for (int i = idx+1 ; i <= lastIdx ; i++)
						delete m_snapShots[i];
					m_snapShots.resize(idx+1);
					//std::cerr << "Resized snapshot list" << std::endl;

					dtMultiplier = 1.0; // make sure this is the final step size
					stepMultiplier = 1000000;
				}
			}
		}
		else // simulation is still ok, save snapshot and try larger step
		{
			//std::cout << updateN << " " << updateP << " " << std::endl;

			// Either both updates dropped below the stop criterion (only honoured
			// once a failure has been seen), or both updates changed by a relative
			// amount below 1e-5 with respect to the previous and the current value.
			const bool belowCriterion = hadNaN && updateN < stopCriterion && updateP < stopCriterion;

			if (belowCriterion ||
			    (std::abs(prevUpdateN - updateN)/updateN < 1e-5 &&
			     std::abs(prevUpdateP - updateP)/updateP < 1e-5 &&
			     std::abs(prevUpdateN - updateN)/prevUpdateN < 1e-5 &&
			     std::abs(prevUpdateP - updateP)/prevUpdateP < 1e-5 ) )
			{
				converged = true;
				std::cout << prevUpdateN << std::endl;
				std::cout << prevUpdateP << std::endl;
				std::cout << updateN << std::endl;
				std::cout << updateP << std::endl;
			}
			else
			{
				prevUpdateP = updateP;
				prevUpdateN = updateN;

				SnapShot *pSnap = new SnapShot(n, p, V, dt, curStep);
				SnapShot *pLastSnapShot = m_snapShots[m_snapShots.size()-1];

				// Staying on the same dt: carry over the step multiplier reached so far.
				if (pLastSnapShot->getDt() == dt)
				{
					int mult = pLastSnapShot->getStepMultiplier();
					pSnap->setStepMultiplier(mult);
				}
				m_snapShots.push_back(pSnap);

				dt *= dtMultiplier;

				/*
				if (dt > 1e-10)
				{
					dt = 1e-10;
					hadNaN = true;
				}
				*/
				stepMultiplier = 1; // we'll briefly try an increased step size
			}
		}
	}

	if (!converged)
	{
		setErrorString("Failed to converge in specified time");
		return false;
	}

	return true;
}

/** Deletes every stored snapshot and leaves the snapshot list empty. */
void TestSimulator2D::clear()
{
	int remaining = m_snapShots.size();

	// Order of deletion does not matter; walk from the back.
	while (remaining > 0)
	{
		remaining--;
		delete m_snapShots[remaining];
	}

	m_snapShots.clear();
}

/**
 * Reports whether the vector contains a value that is not finite.
 *
 * Despite the name, infinities are flagged as well as NaNs, since the check
 * is based on std::isfinite.
 *
 * @param v Values to inspect.
 * @return true if at least one element is NaN or infinite, false otherwise
 *         (including for an empty vector).
 */
bool TestSimulator2D::hasNaN(const std::vector<double> &v)
{
	std::vector<double>::const_iterator it = v.begin();
	const std::vector<double>::const_iterator last = v.end();

	for ( ; it != last ; ++it)
	{
		if (std::isfinite(*it))
			continue;

		//std::cout << *it << std::endl;
		return true;
	}

	return false;
}

/**
 * Reports whether the vector contains a strictly negative value.
 *
 * @param v Values to inspect.
 * @return true if at least one element compares less than zero, false
 *         otherwise (including for an empty vector).
 */
bool TestSimulator2D::hasNeg(const std::vector<double> &v)
{
	std::vector<double>::const_iterator it = v.begin();

	while (it != v.end())
	{
		const double value = *it;

		if (value < 0)
		{
			//std::cout << value << std::endl;
			return true;
		}

		++it;
	}

	return false;
}


