#include "inversematrixpoissonsolver1d.h"
#include "constants.h"
#include <string.h>
#include <stdio.h>
#include <iostream>
#include <algorithm>

// Creates an uninitialized solver; init() must be called before any other
// method. All solver handles start out null so the destructor only releases
// what was actually allocated, and m_pPotential == 0 marks "not initialized".
InverseMatrixPoissonSolver1D::InverseMatrixPoissonSolver1D()
{
#ifdef USEGSLFORTESTING
	m_pPerm = 0;
#else
	m_pSymbolic = 0;
	m_pNumeric = 0;
#endif // USEGSLFORTESTING
	m_pPotential = 0;

	// Give the scalar state defined values as well, so nothing is ever read
	// indeterminate, even if a code path bypasses the m_pPotential guard.
	m_numX = 0;
	m_width = 0;
	m_chargeFactor = 0;
	m_potentialDifference = 0;
	m_topPixelExtra = 0;
	m_hasLU = false;
}

// Releases the factorization resources allocated by init() (GSL permutation)
// or setRelativePermittivity() (KLU symbolic/numeric objects). Handles that
// were never allocated are still null and are skipped.
// NOTE(review): the potential buffer is not freed here (it is either the
// internal vector or caller-owned). If this class is copyable (the header is
// not visible here), a copy would free these handles twice; consider deleting
// the copy operations.
InverseMatrixPoissonSolver1D::~InverseMatrixPoissonSolver1D()
{
#ifdef USEGSLFORTESTING
	gsl_permutation *pPermutation = static_cast<gsl_permutation *>(m_pPerm);

	if (pPermutation)
		gsl_permutation_free(pPermutation);
#else
	// Numeric depends on Symbolic, so drop it first.
	if (m_pNumeric)
		klu_free_numeric(&m_pNumeric, &m_common);

	if (m_pSymbolic)
		klu_free_symbolic(&m_pSymbolic, &m_common);
#endif // USEGSLFORTESTING
}

// Prepares the solver for a 1D grid of numX pixels spanning 'width'.
//
// numX       : number of grid points, including the two boundary pixels
//              (must be at least 3).
// width      : physical length of the grid; the pixel spacing is
//              width/(numX-1). Must be positive (NaN is rejected too).
// pPotential : optional caller-owned buffer of numX doubles that will hold
//              the potential. If null, an internal zero-filled buffer is used.
//              NOTE(review): only the pointer is stored, so an external
//              buffer must outlive this solver.
//
// Only the numX-2 interior pixels are unknowns; pixel 0 and pixel numX-1 are
// fixed boundary values. May be called only once. Returns false and sets an
// error string on invalid arguments or repeated initialization.
bool InverseMatrixPoissonSolver1D::init(int numX, double width, double *pPotential)
{
	if (m_pPotential != 0)
	{
		setErrorString("Already initialized");
		return false;
	}

	if (numX < 3)
	{
		setErrorString("Number of x pixels must be at least 3");
		return false;
	}

	// Written as !(width > 0) so that a NaN width is rejected as well:
	// 'width <= 0' evaluates to false for NaN and would let it through.
	if (!(width > 0))
	{
		setErrorString("Dimension must be positive");
		return false;
	}

	m_numX = numX;
	m_width = width;

	double dx = m_width/(double)(m_numX-1);

	// The matrix entries carry no 1/dx^2, so the dx^2 is folded into the
	// factor that converts a number charge into the right-hand side.
	m_chargeFactor = -dx*dx*(CHARGE_ELECTRON/CONST_EPSILON0);

	m_potentialDifference = 0;

	int totalVariablePixels = m_numX-2;

	if (pPotential == 0)
	{
		m_potential.resize(m_numX);
		std::fill(m_potential.begin(), m_potential.end(), 0.0);
		m_pPotential = &(m_potential[0]);
	}
	else
		m_pPotential = pPotential;

#ifdef USEGSLFORTESTING
	m_charge.resize(totalVariablePixels);
#endif
	m_topPixelExtra = 0;
#ifdef USEGSLFORTESTING
	m_epsMatrixLU.resize(totalVariablePixels*totalVariablePixels); // TODO: is actually sparse
#else
#ifdef DUMPMATRIX
	m_epsMatrixLU.resize(totalVariablePixels*totalVariablePixels); // TODO: is actually sparse
#endif // DUMPMATRIX
	m_epsSparseElements.resize(totalVariablePixels*3);
#endif // USEGSLFORTESTING
	m_epsMatrix.resize(m_numX);

#ifdef USEGSLFORTESTING
	m_pPerm = gsl_permutation_alloc(totalVariablePixels);
	m_epsMatrixLUView = gsl_matrix_view_array(&(m_epsMatrixLU[0]), totalVariablePixels, totalVariablePixels);
	m_chargeView = gsl_vector_view_array(&(m_charge[0]), totalVariablePixels);
	// The solution vector aliases the interior pixels 1..numX-2 of the potential.
	m_potentialView = gsl_vector_view_array(&(m_pPotential[1]), totalVariablePixels);
#else
	// Tridiagonal sparsity pattern in compressed-column form (0-based):
	// column c has nonzeros in rows c-1, c and c+1 (clipped to the matrix),
	// listed in ascending row order. 3 entries per column is an upper bound;
	// only the first 'count' entries of m_nonZeroRows are used.
	m_nonZeroRows.resize(totalVariablePixels*3);
	m_colPointers.resize(totalVariablePixels+1);

	int count = 0;

	for (int col = 0 ; col < totalVariablePixels ; col++)
	{
		m_colPointers[col] = count;

		for (int row = col-1 ; row <= col+1 ; row++)
		{
			if (row >= 0 && row < totalVariablePixels)
				m_nonZeroRows[count++] = row;
		}
	}

	m_colPointers[totalVariablePixels] = count;

	klu_defaults (&m_common);

#endif // USEGSLFORTESTING

	// No factorization exists until setRelativePermittivity() succeeds.
	m_hasLU = false;

	return true;
}

bool InverseMatrixPoissonSolver1D::setRelativePermittivity(const double *pEps)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	int idx = 0;

	memcpy(&(m_epsMatrix[0]), pEps, sizeof(double)*m_numX);

	int variablePixels = m_numX-2;

#ifdef USEGSLFORTESTING
	for (int i = 0 ; i < variablePixels ; i++)
	{
		int dstX = i + 1;

		for (int j = 0 ; j < variablePixels ; j++)
		{
			int vX = j + 1;

			double value = 0;

			if (dstX == vX)
				value = -2.0*m_epsMatrix[dstX];
			else if (dstX+1 == vX)
				value = 0.5*(m_epsMatrix[dstX] + m_epsMatrix[vX]);
			else if (dstX-1 == vX)
				value = 0.5*(m_epsMatrix[dstX] + m_epsMatrix[vX]);

			m_epsMatrixLU[i*variablePixels+j] = value;
		}
	}
#else
	// Fill in values

#ifdef DUMPMATRIX
	memset(&(m_epsMatrixLU[0]), 0, sizeof(double)*m_epsMatrixLU.size());
#endif // DUMPMATRIX

	for (int j = 0 ; j < variablePixels ; j++)
	{
		int startPos = m_colPointers[j];
		int endPos = m_colPointers[j+1];

		int vX = j + 1;

		for (int p = startPos ; p < endPos ; p++)
		{
			int i = m_nonZeroRows[p];

			int dstX = i + 1;

			double value = 0;

			if (dstX == vX)
				value = -2.0*m_epsMatrix[dstX];
			else if (dstX+1 == vX)
				value = 0.5*(m_epsMatrix[dstX] + m_epsMatrix[vX]);
			else if (dstX-1 == vX)
				value = 0.5*(m_epsMatrix[dstX] + m_epsMatrix[vX]);

			m_epsSparseElements[p] = value;
#ifdef DUMPMATRIX
			m_epsMatrixLU[i*variablePixels+j] = value;
#endif // DUMPMATRIX
		}
	}
#endif // USEGSLFORTESTING

#ifdef DUMPMATRIX
	for (int i = 0 ; i < variablePixels ; i++)
	{
		for (int j = 0 ; j < variablePixels ; j++)
		{
			std::cout << "\t" << m_epsMatrixLU[i*variablePixels+j];
		}
		std::cout << std::endl;
	}
	std::cout << std::endl << std::endl;
#endif // DUMPMATRIX

	m_topPixelExtra = - m_potentialDifference * 0.5 * (m_epsMatrix[m_numX-2] + m_epsMatrix[m_numX-1]);

#ifdef USEGSLFORTESTING
	int status;
	char str[1024];

	status = gsl_linalg_LU_decomp(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &m_permSign);
	if (status != GSL_SUCCESS)
	{
		sprintf(str, "Error in LU decomposition (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	if (m_pNumeric != 0)
	{
		klu_free_numeric(&m_pNumeric, &m_common);
		m_pNumeric = 0;
	}
	if (m_pSymbolic != 0)
	{
		klu_free_symbolic(&m_pSymbolic, &m_common);
		m_pSymbolic = 0;
	}
	m_pSymbolic = klu_analyze(variablePixels, &(m_colPointers[0]), &(m_nonZeroRows[0]), &m_common);
	m_pNumeric = klu_factor(&(m_colPointers[0]), &(m_nonZeroRows[0]), &(m_epsSparseElements[0]), m_pSymbolic, &m_common);
#endif // USEGSLFORTESTING

	m_hasLU = true;

	return true;
}

// Sets the potential of the last pixel to vDiff and that of the first pixel
// to 0, writing both into the potential buffer. It also refreshes the
// right-hand-side term through which the top boundary value enters the last
// interior equation, using the permittivities of the last two pixels.
// Returns false with an error string if init() has not been called.
bool InverseMatrixPoissonSolver1D::setPotentialDifference(double vDiff)
{
	if (!m_pPotential)
	{
		setErrorString("Not initialized");
		return false;
	}

	const int lastPixel = m_numX - 1;

	m_potentialDifference = vDiff;

	// Fixed boundary values: the first pixel is the zero reference.
	m_pPotential[0] = 0;
	m_pPotential[lastPixel] = m_potentialDifference;

	// The known top potential, weighted by the average permittivity of the
	// last two pixels, is moved to the right-hand side of the last equation.
	m_topPixelExtra = -m_potentialDifference * 0.5 * (m_epsMatrix[lastPixel - 1] + m_epsMatrix[lastPixel]);

	return true;
}

// Solves for the potential of the interior pixels given a net number charge
// per pixel.
//
// pNetNumberCharge : array indexed by pixel (0..numX-1); only the interior
//                    entries 1..numX-2 are read.
//
// The result is written to the interior of the potential buffer; the two
// boundary values are not modified here (see setPotentialDifference()).
// Returns false with an error string if the solver is not initialized, no
// factorization is available, or the linear solve fails.
bool InverseMatrixPoissonSolver1D::findPotential(const double *pNetNumberCharge)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	if (!m_hasLU)
	{
		setErrorString("No LU decomposition is available yet, please set the relative permittivity first");
		return false;
	}

	// Build the right-hand side. In KLU mode it is written directly into the
	// interior of the potential buffer, which klu_solve overwrites in place.
	for (int x = 1 ; x < m_numX-1 ; x++)
	{
#ifdef USEGSLFORTESTING
		m_charge[x-1] = pNetNumberCharge[x] * m_chargeFactor;
#else
		m_pPotential[x] = pNetNumberCharge[x] * m_chargeFactor;
#endif //USEGSLFORTESTING
	}

	// Known contribution of the fixed top-boundary potential.
#ifdef USEGSLFORTESTING
	m_charge[m_numX-3] += m_topPixelExtra;
#else
	m_pPotential[m_numX-2] += m_topPixelExtra;
#endif //USEGSLFORTESTING

#ifdef USEGSLFORTESTING
	int status = gsl_linalg_LU_solve(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &(m_chargeView.vector), &(m_potentialView.vector));
	if (status != GSL_SUCCESS)
	{
		char str[1024];

		snprintf(str, sizeof(str), "Error solving LU system (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	int variablePixels = (m_numX-2);

	// On failure the buffer still holds the right-hand side, not a potential,
	// so the caller must be told.
	if (!klu_solve(m_pSymbolic, m_pNumeric, variablePixels, 1, &(m_pPotential[1]), &m_common))
	{
		char str[1024];

		snprintf(str, sizeof(str), "Error solving LU system (KLU status %d)", m_common.status);
		setErrorString(str);
		return false;
	}
#endif // USEGSLFORTESTING

	return true;
}

// Solves for the potential of the interior pixels given separate positive and
// negative number charges per pixel (the net charge is pos - neg).
//
// pPosNumberCharge, pNegNumberCharge : arrays indexed by pixel (0..numX-1);
//                                      only entries 1..numX-2 are read.
//
// The result is written to the interior of the potential buffer; the two
// boundary values are not modified here (see setPotentialDifference()).
// Returns false with an error string if the solver is not initialized, no
// factorization is available, or the linear solve fails.
bool InverseMatrixPoissonSolver1D::findPotential(const double *pPosNumberCharge, const double *pNegNumberCharge)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	if (!m_hasLU)
	{
		setErrorString("No LU decomposition is available yet, please set the relative permittivity first");
		return false;
	}

	// Build the right-hand side. In KLU mode it is written directly into the
	// interior of the potential buffer, which klu_solve overwrites in place.
	for (int x = 1 ; x < m_numX-1 ; x++)
	{
#ifdef USEGSLFORTESTING
		m_charge[x-1] = (pPosNumberCharge[x] - pNegNumberCharge[x]) * m_chargeFactor;
#else
		m_pPotential[x] = (pPosNumberCharge[x] - pNegNumberCharge[x]) * m_chargeFactor;
#endif //USEGSLFORTESTING
	}

	// Known contribution of the fixed top-boundary potential.
#ifdef USEGSLFORTESTING
	m_charge[m_numX-3] += m_topPixelExtra;
#else
	m_pPotential[m_numX-2] += m_topPixelExtra;
#endif //USEGSLFORTESTING

#ifdef USEGSLFORTESTING
	int status = gsl_linalg_LU_solve(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &(m_chargeView.vector), &(m_potentialView.vector));
	if (status != GSL_SUCCESS)
	{
		char str[1024];

		snprintf(str, sizeof(str), "Error solving LU system (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	int variablePixels = (m_numX-2);

	// On failure the buffer still holds the right-hand side, not a potential,
	// so the caller must be told.
	if (!klu_solve(m_pSymbolic, m_pNumeric, variablePixels, 1, &(m_pPotential[1]), &m_common))
	{
		char str[1024];

		snprintf(str, sizeof(str), "Error solving LU system (KLU status %d)", m_common.status);
		setErrorString(str);
		return false;
	}
#endif // USEGSLFORTESTING

	return true;
}

// Solves for the potential of the interior pixels given positive, negative and
// background number charges per pixel (the net charge is pos - neg + bg).
//
// pPosNumberCharge, pNegNumberCharge, pBgNumberCharge : arrays indexed by pixel
//     (0..numX-1); only entries 1..numX-2 are read.
//
// The result is written to the interior of the potential buffer; the two
// boundary values are not modified here (see setPotentialDifference()).
// Returns false with an error string if the solver is not initialized, no
// factorization is available, or the linear solve fails.
bool InverseMatrixPoissonSolver1D::findPotential(const double *pPosNumberCharge, const double *pNegNumberCharge, const double *pBgNumberCharge)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	if (!m_hasLU)
	{
		setErrorString("No LU decomposition is available yet, please set the relative permittivity first");
		return false;
	}

	// Build the right-hand side. In KLU mode it is written directly into the
	// interior of the potential buffer, which klu_solve overwrites in place.
	for (int x = 1 ; x < m_numX-1 ; x++)
	{
#ifdef USEGSLFORTESTING
		m_charge[x-1] = (pPosNumberCharge[x] - pNegNumberCharge[x] + pBgNumberCharge[x]) * m_chargeFactor;
#else
		m_pPotential[x] = (pPosNumberCharge[x] - pNegNumberCharge[x] + pBgNumberCharge[x]) * m_chargeFactor;
#endif //USEGSLFORTESTING
	}

	// Known contribution of the fixed top-boundary potential.
#ifdef USEGSLFORTESTING
	m_charge[m_numX-3] += m_topPixelExtra;
#else
	m_pPotential[m_numX-2] += m_topPixelExtra;
#endif //USEGSLFORTESTING

#ifdef USEGSLFORTESTING
	int status = gsl_linalg_LU_solve(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &(m_chargeView.vector), &(m_potentialView.vector));
	if (status != GSL_SUCCESS)
	{
		char str[1024];

		snprintf(str, sizeof(str), "Error solving LU system (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	int variablePixels = (m_numX-2);

	// On failure the buffer still holds the right-hand side, not a potential,
	// so the caller must be told.
	if (!klu_solve(m_pSymbolic, m_pNumeric, variablePixels, 1, &(m_pPotential[1]), &m_common))
	{
		char str[1024];

		snprintf(str, sizeof(str), "Error solving LU system (KLU status %d)", m_common.status);
		setErrorString(str);
		return false;
	}
#endif // USEGSLFORTESTING

	return true;
}

