#include "inversematrixpoissonsolver2d.h"
#include "constants.h"
#include <string.h>
#include <iostream>
#include <algorithm>

// Creates a solver in the "not initialized" state: all handles that are
// allocated later start out as null pointers.
InverseMatrixPoissonSolver2D::InverseMatrixPoissonSolver2D()
{
	// The other member functions test this pointer to find out whether init()
	// has already been called.
	m_pPotential = 0;

#ifdef USEGSLFORTESTING
	// Dense GSL backend: only the permutation is allocated by this class
	m_pPerm = 0;
#else
	// Sparse KLU backend: numeric and symbolic factorization handles
	m_pNumeric = 0;
	m_pSymbolic = 0;
#endif // USEGSLFORTESTING
}

// Frees whatever factorization resources the active backend still holds.
InverseMatrixPoissonSolver2D::~InverseMatrixPoissonSolver2D()
{
#ifdef USEGSLFORTESTING
	if (m_pPerm)
	{
		gsl_permutation_free((gsl_permutation *)m_pPerm);
	}
#else
	// The numeric factorization is released before the symbolic analysis
	if (m_pNumeric)
	{
		klu_free_numeric(&m_pNumeric, &m_common);
	}
	if (m_pSymbolic)
	{
		klu_free_symbolic(&m_pSymbolic, &m_common);
	}
#endif // USEGSLFORTESTING
}

// Prepares the solver for a grid of numX by numY pixels that covers an area of
// width by height.
//
// Rows 1 .. numY-2 hold the unknowns of the linear system; row 0 and row numY-1
// are boundary rows. In the x direction neighbours wrap around (see the modulo
// in the neighbour computation below).
//
// Parameters:
//   numX, numY    - grid size in pixels, numX >= 1 and numY >= 3
//   width, height - physical dimensions, both must be positive
//   pPotential    - optional buffer, indexed as x + y*numX, in which the
//                   potential is stored. If null, internal storage is allocated
//                   and zeroed. NOTE(review): a caller supplied buffer is not
//                   cleared here and presumably has to stay valid for the
//                   lifetime of the solver - confirm against the callers.
//
// Returns true on success. On failure false is returned and the error string
// describes the problem. Calling init a second time is an error.
bool InverseMatrixPoissonSolver2D::init(int numX, int numY, double width, double height, double *pPotential)
{
	if (m_pPotential != 0)
	{
		setErrorString("Already initialized");
		return false;
	}

	if (numX < 1)
	{
		setErrorString("Number of x pixels must be at least 1");
		return false;
	}
	if (numY < 3)
	{
		setErrorString("Number of y pixels must be at least 3");
		return false;
	}

	if (width <= 0 || height <= 0)
	{
		setErrorString("Dimensions must be positive");
		return false;
	}

	m_numX = numX;
	m_numY = numY;
	m_width = width;
	m_height = height;

	// NOTE(review): dx divides by numX while dy divides by numY-1, presumably
	// because x wraps around and y runs from boundary row to boundary row.
	const double dx = m_width/(double)m_numX;
	const double dy = m_height/(double)(m_numY-1);

	// Factor that turns a number charge per pixel into a right hand side entry
	m_chargeFactor = -dx*dy*(CHARGE_ELECTRON/CONST_EPSILON0);
	m_dxOverDy = dx/dy;
	m_dyOverDx = dy/dx;

	m_potentialDifference = 0;

	// Number of unknowns: all pixels except those in the two boundary rows
	const int totalVariablePixels = m_numX*(m_numY-2);

	if (pPotential == 0)
	{
		m_potential.resize(m_numX*m_numY);
		memset(&(m_potential[0]), 0, sizeof(double)*m_potential.size());
		m_pPotential = &(m_potential[0]);
	}
	else
		m_pPotential = pPotential;

#ifdef USEGSLFORTESTING
	m_charge.resize(totalVariablePixels);
#endif
	m_topRowExtra.resize(m_numX);
	// No extra term is stored for the bottom row, as the potential there is zero

#if defined(USEGSLFORTESTING) || defined(DUMPMATRIX)
	// The element count is computed in size_t: the square of the number of
	// unknowns easily exceeds the range of an int.
	m_epsMatrixLU.resize(static_cast<size_t>(totalVariablePixels)*static_cast<size_t>(totalVariablePixels)); // TODO: is actually sparse
#endif // USEGSLFORTESTING || DUMPMATRIX
#ifndef USEGSLFORTESTING
	m_epsSparseElements.resize(totalVariablePixels*5);
#endif // !USEGSLFORTESTING
	m_epsMatrix.resize(m_numX*m_numY);

#ifdef USEGSLFORTESTING
	m_pPerm = gsl_permutation_alloc(totalVariablePixels);
	m_epsMatrixLUView = gsl_matrix_view_array(&(m_epsMatrixLU[0]), totalVariablePixels, totalVariablePixels);
	m_chargeView = gsl_vector_view_array(&(m_charge[0]), totalVariablePixels);
	// The solution is written directly into the variable rows of the potential
	m_potentialView = gsl_vector_view_array(&(m_pPotential[m_numX]), totalVariablePixels);
#else
	// Build the compressed-column sparsity pattern of the five point stencil.
	// At most five entries per column are needed.
	m_nonZeroRows.resize(totalVariablePixels*5);
	m_colPointers.resize(totalVariablePixels+1);

	const int xDiff[5] = { -1, +1, 0,  0,  0 };
	const int yDiff[5] = {  0,  0, 0, -1, +1 };

	int count = 0;

	for (int col = 0 ; col < totalVariablePixels ; col++)
	{
		m_colPointers[col] = count;

		const int colX = col%m_numX;
		const int colY = col/m_numX + 1; // +1 skips the bottom boundary row

		int indices[5];
		int idxCount = 0;

		for (int k = 0 ; k < 5 ; k++)
		{
			const int rowX = (colX + xDiff[k] + m_numX)%m_numX; // wrap around in x
			const int rowY = colY + yDiff[k];

			// Neighbours in a boundary row are not unknowns of the system
			if (rowY > 0 && rowY < m_numY-1)
				indices[idxCount++] = rowX + (rowY-1)*m_numX;
		}

		std::sort(indices, indices + idxCount);

		// For numX < 3 the wrap around maps several neighbours onto the same
		// pixel. Duplicate row indices within a column are rejected by KLU, so
		// each row is stored only once. For numX >= 3 this changes nothing.
		// NOTE(review): the coefficient of such a merged entry is the one the
		// dense test matrix uses, the two couplings are not added up.
		idxCount = (int)(std::unique(indices, indices + idxCount) - indices);

		for (int k = 0 ; k < idxCount ; k++)
		{
			m_nonZeroRows[count] = indices[k];
			count++;
		}
	}

	// Final entry marks the end of the last column
	m_colPointers[totalVariablePixels] = count;

	klu_defaults (&m_common);

#endif // USEGSLFORTESTING

	// A factorization only becomes available in setRelativePermittivity
	m_hasLU = false;

	return true;
}

bool InverseMatrixPoissonSolver2D::setRelativePermittivity(const double *pEps)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	int idx = 0;

	memcpy(&(m_epsMatrix[0]), pEps, sizeof(double)*m_numX*m_numY);

	int variablePixels = (m_numY-2)*m_numX;

#ifdef USEGSLFORTESTING
	for (int i = 0 ; i < variablePixels ; i++)
	{
		int dstX = i%m_numX;
		int dstY = i/m_numX + 1;

		int j2 = vX + vY*m_numX;

		for (int j = 0 ; j < variablePixels ; j++)
		{
			int vX = j%m_numX;
			int vY = j/m_numX + 1;

			int i2 = dstX + dstY*m_numX;

			double value = 0;

			if (dstX == vX && dstY == vY)
				value = -2.0*(m_dyOverDx + m_dxOverDy)*m_epsMatrix[i2];
			else if ((dstX+1)%m_numX == vX && dstY == vY)
				value = m_dyOverDx*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);
			else if ((dstX-1+m_numX)%m_numX == vX && dstY == vY)
				value = m_dyOverDx*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);
			else if (dstX == vX && dstY+1 == vY)
				value = m_dxOverDy*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);
			else if (dstX == vX && dstY-1 == vY)
				value = m_dxOverDy*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);

			m_epsMatrixLU[i*variablePixels+j] = value;
		}
	}
#else
	// Fill in values

#ifdef DUMPMATRIX
	memset(&(m_epsMatrixLU[0]), 0, sizeof(double)*m_epsMatrixLU.size());
#endif // DUMPMATRIX

	for (int j = 0 ; j < variablePixels ; j++)
	{
		int startPos = m_colPointers[j];
		int endPos = m_colPointers[j+1];

		int vX = j%m_numX;
		int vY = j/m_numX + 1;

		int j2 = vX + vY*m_numX;

		for (int p = startPos ; p < endPos ; p++)
		{
			int i = m_nonZeroRows[p];

			int dstX = i%m_numX;
			int dstY = i/m_numX + 1;

			int i2 = dstX + dstY*m_numX;

			double value = 0;

			if (dstX == vX && dstY == vY)
				value = -2.0*(m_dyOverDx + m_dxOverDy)*m_epsMatrix[i2];
			else if ((dstX+1)%m_numX == vX && dstY == vY)
				value = m_dyOverDx*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);
			else if ((dstX-1+m_numX)%m_numX == vX && dstY == vY)
				value = m_dyOverDx*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);
			else if (dstX == vX && dstY+1 == vY)
				value = m_dxOverDy*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);
			else if (dstX == vX && dstY-1 == vY)
				value = m_dxOverDy*0.5*(m_epsMatrix[i2] + m_epsMatrix[j2]);

			m_epsSparseElements[p] = value;
#ifdef DUMPMATRIX
			m_epsMatrixLU[i*variablePixels+j] = value;
#endif // DUMPMATRIX
		}
	}
#endif // USEGSLFORTESTING

#ifdef DUMPMATRIX
	for (int i = 0 ; i < variablePixels ; i++)
	{
		for (int j = 0 ; j < variablePixels ; j++)
		{
			std::cout << "\t" << m_epsMatrixLU[i*variablePixels+j];
		}
		std::cout << std::endl;
	}
	std::cout << std::endl << std::endl;
#endif // DUMPMATRIX

	int offset1 = (m_numY-2)*m_numX;
	int offset2 = (m_numY-1)*m_numX;

	for (int x = 0 ; x < m_numX ; x++)
		m_topRowExtra[x] = - m_potentialDifference * 0.5 * m_dxOverDy * (m_epsMatrix[x + offset1] + m_epsMatrix[x + offset2]);

#ifdef USEGSLFORTESTING
	int status;
	char str[1024];

	status = gsl_linalg_LU_decomp(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &m_permSign);
	if (status != GSL_SUCCESS)
	{
		sprintf(str, "Error in LU decomposition (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	if (m_pNumeric != 0)
	{
		klu_free_numeric(&m_pNumeric, &m_common);
		m_pNumeric = 0;
	}
	if (m_pSymbolic != 0)
	{
		klu_free_symbolic(&m_pSymbolic, &m_common);
		m_pSymbolic = 0;
	}
	m_pSymbolic = klu_analyze(variablePixels, &(m_colPointers[0]), &(m_nonZeroRows[0]), &m_common);
	m_pNumeric = klu_factor(&(m_colPointers[0]), &(m_nonZeroRows[0]), &(m_epsSparseElements[0]), m_pSymbolic, &m_common);
#endif // USEGSLFORTESTING

	m_hasLU = true;

	return true;
}

// Sets the potential of the top row to vDiff while the bottom row is kept at
// zero, and updates the right hand side contribution of the top row to match.
// Returns false (with the error string set) if init has not been called yet.
bool InverseMatrixPoissonSolver2D::setPotentialDifference(double vDiff)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	m_potentialDifference = vDiff;

	// Start indices of the last variable row and of the top boundary row
	const int lastVariableRow = (m_numY-2)*m_numX;
	const int topRow = (m_numY-1)*m_numX;

	// Common prefix of the extra term, the same for every column
	const double scale = - m_potentialDifference * 0.5 * m_dxOverDy;

	for (int col = 0 ; col < m_numX ; col++)
	{
		m_pPotential[col] = 0;
		m_pPotential[col+topRow] = m_potentialDifference;
		m_topRowExtra[col] = scale * (m_epsMatrix[col + lastVariableRow] + m_epsMatrix[col + topRow]);
	}

	return true;
}

bool InverseMatrixPoissonSolver2D::findPotential(const double *pNetNumberCharge)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	if (!m_hasLU)
	{
		setErrorString("Nu LU decomposition is available yet, please set the relative permittivity first");
		return false;
	}

	int idx = 0;

	for (int y = 1 ; y < m_numY-1 ; y++)
	{
		for (int x = 0 ; x < m_numX ; x++, idx++)
		{
			int pos = x+y*m_numX;

#ifdef USEGSLFORTESTING
			m_charge[idx] = (pNetNumberCharge[pos]) * m_chargeFactor;
#else
			m_pPotential[pos] = (pNetNumberCharge[pos]) * m_chargeFactor;
#endif //USEGSLFORTESTING
		}
	}

#ifdef USEGSLFORTESTING
	int offset = m_numX * (m_numY-3); // grid with variable pixels is (m_numY-2)*m_numX pixels large

	for (int x = 0 ; x < m_numX ; x++)
		m_charge[x+offset] += m_topRowExtra[x];
#else
	int offset = m_numX * (m_numY-2);

	for (int x = 0 ; x < m_numX ; x++)
		m_pPotential[x+offset] += m_topRowExtra[x];
#endif //USEGSLFORTESTING

#ifdef USEGSLFORTESTING
	int status;
	char str[1024];
	
	status = gsl_linalg_LU_solve(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &(m_chargeView.vector), &(m_potentialView.vector));
	if (status != GSL_SUCCESS)
	{
		sprintf(str, "Error solving LU system (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	int variablePixels = (m_numY-2)*m_numX;

	klu_solve(m_pSymbolic, m_pNumeric, variablePixels, 1, &(m_pPotential[m_numX]), &m_common);
#endif // USEGSLFORTESTING
		
	return true;
}

bool InverseMatrixPoissonSolver2D::findPotential(const double *pPosNumberCharge, const double *pNegNumberCharge)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	if (!m_hasLU)
	{
		setErrorString("Nu LU decomposition is available yet, please set the relative permittivity first");
		return false;
	}

	int idx = 0;

	for (int y = 1 ; y < m_numY-1 ; y++)
	{
		for (int x = 0 ; x < m_numX ; x++, idx++)
		{
			int pos = x+y*m_numX;

#ifdef USEGSLFORTESTING
			m_charge[idx] = (pPosNumberCharge[pos] - pNegNumberCharge[pos]) * m_chargeFactor;
#else
			m_pPotential[pos] = (pPosNumberCharge[pos] - pNegNumberCharge[pos]) * m_chargeFactor;
#endif //USEGSLFORTESTING
		}
	}

#ifdef USEGSLFORTESTING
	int offset = m_numX * (m_numY-3); // grid with variable pixels is (m_numY-2)*m_numX pixels large

	for (int x = 0 ; x < m_numX ; x++)
		m_charge[x+offset] += m_topRowExtra[x];
#else
	int offset = m_numX * (m_numY-2);

	for (int x = 0 ; x < m_numX ; x++)
		m_pPotential[x+offset] += m_topRowExtra[x];
#endif //USEGSLFORTESTING

#ifdef USEGSLFORTESTING
	int status;
	char str[1024];
	
	status = gsl_linalg_LU_solve(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &(m_chargeView.vector), &(m_potentialView.vector));
	if (status != GSL_SUCCESS)
	{
		sprintf(str, "Error solving LU system (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	int variablePixels = (m_numY-2)*m_numX;

	klu_solve(m_pSymbolic, m_pNumeric, variablePixels, 1, &(m_pPotential[m_numX]), &m_common);
#endif // USEGSLFORTESTING
		
	return true;
}

bool InverseMatrixPoissonSolver2D::findPotential(const double *pPosNumberCharge, const double *pNegNumberCharge, const double *pBgNumberCharge)
{
	if (m_pPotential == 0)
	{
		setErrorString("Not initialized");
		return false;
	}

	if (!m_hasLU)
	{
		setErrorString("Nu LU decomposition is available yet, please set the relative permittivity first");
		return false;
	}

	int idx = 0;

	for (int y = 1 ; y < m_numY-1 ; y++)
	{
		for (int x = 0 ; x < m_numX ; x++, idx++)
		{
			int pos = x+y*m_numX;

#ifdef USEGSLFORTESTING
			m_charge[idx] = (pPosNumberCharge[pos] - pNegNumberCharge[pos] + pBgNumberCharge[pos]) * m_chargeFactor;
#else
			m_pPotential[pos] = (pPosNumberCharge[pos] - pNegNumberCharge[pos] + pBgNumberCharge[pos]) * m_chargeFactor;
#endif //USEGSLFORTESTING
		}
	}

#ifdef USEGSLFORTESTING
	int offset = m_numX * (m_numY-3); // grid with variable pixels is (m_numY-2)*m_numX pixels large

	for (int x = 0 ; x < m_numX ; x++)
		m_charge[x+offset] += m_topRowExtra[x];
#else
	int offset = m_numX * (m_numY-2);

	for (int x = 0 ; x < m_numX ; x++)
		m_pPotential[x+offset] += m_topRowExtra[x];
#endif //USEGSLFORTESTING

#ifdef USEGSLFORTESTING
	int status;
	char str[1024];
	
	status = gsl_linalg_LU_solve(&(m_epsMatrixLUView.matrix), (gsl_permutation *)m_pPerm, &(m_chargeView.vector), &(m_potentialView.vector));
	if (status != GSL_SUCCESS)
	{
		sprintf(str, "Error solving LU system (GSL error %d)", status);
		setErrorString(str);
		return false;
	}
#else
	int variablePixels = (m_numY-2)*m_numX;

	klu_solve(m_pSymbolic, m_pNumeric, variablePixels, 1, &(m_pPotential[m_numX]), &m_common);
#endif // USEGSLFORTESTING
		
	return true;
}

