/* Overlays a float and a uint on the same storage, presumably so the bit
 * pattern of a float can be reinterpreted as an unsigned integer (or vice
 * versa).
 * NOTE(review): not referenced by any kernel or helper in this section;
 * confirm it is still needed before relying on or removing it. */
union floatUnion
{
	float m_float;
	uint m_int;
};

/* Threshold on |x| = |v*delta/D| at or below which getSGCurrent uses a
 * series expansion instead of the exponential form (which becomes 0/0 as
 * x -> 0). Written as a float literal so the comparisons stay in single
 * precision; OpenCL devices are not required to support double. */
#define LIM 1e-5f

/*
 * Scharfetter-Gummel current between two neighbouring grid points with
 * densities n1 and n2, drift velocity v, diffusion constant D and grid
 * spacing delta:
 *
 *     j = v * (n1*exp(x) - n2) / (exp(x) - 1),   x = v*delta/D
 *
 * Limits: j -> D/delta*(n1 - n2) as v -> 0 (pure diffusion),
 *         j -> v*n1 as x -> +inf and j -> v*n2 as x -> -inf.
 *
 * Each branch only ever evaluates exp() of a non-positive argument, so the
 * result stays finite for arbitrarily large |x|, and expm1() is used for the
 * denominator so it does not suffer cancellation for small |x|.
 */
float getSGCurrent(float v, float D, float delta, float n1, float n2)
{
	float x = v*delta/D;

	if (x > LIM)
	{
		/* Numerator and denominator multiplied by exp(-x):
		 * j = v * (n1 - n2*exp(-x)) / (1 - exp(-x)). */
		float expNegX = exp(-x);
		float oneMinusExpNegX = -expm1(-x);

		return v*((n1 - n2*expNegX)/oneMinusExpNegX);
	}

	if (x < -LIM)
	{
		/* exp(x) < 1 here, so the original form is the stable one. */
		float expX = exp(x);
		float expXMinusOne = expm1(x);

		return v*((n1*expX - n2)/expXMinusOne);
	}

	/* |x| <= LIM: with y = x/2,
	 *   j = (2D/delta) * [ (n1-n2)/2 + (n1+n2)/2 * y + (n1-n2)*y^2/6 + O(y^4) ].
	 * The y^2 term is below float resolution for |y| <= LIM/2 and is omitted. */
	float y = x/2.0f;
	float factor = 2.0f*D/delta;

	return factor*(0.5f*(n1 - n2) + 0.5f*(n1 + n2)*y);
}

/*
 * Electron (.x) and hole (.y) currents across the edge from grid point idx
 * to its neighbour in the +x direction. The x index wraps around at the
 * grid edge and the spacing passed to getSGCurrent is 1.
 *
 * Material parameters are averaged over the two end points. The drift
 * velocity combines the base field and the field from the potential drop;
 * the potential-drop part is additionally applied to the mean base density.
 * `height` is accepted for signature symmetry but not used.
 */
float2 getXRelCurrents(int idx,
			  __global const float *pV, 
			  __global const float *pExBaseN, __global const float *pExBaseP,
                          __global const float *pDe, __global const float *pDh,
			  __global const float *pEMob, __global const float *pHMob,
			  __global const float *pN, __global const float *pP,
			  __global const float *pNBase, __global const float *pPBase,
			  int width, int height)
{
	int col = idx % width;
	int rowStart = (idx / width)*width;
	int right = rowStart + (col + 1) % width;   /* periodic in x */

	/* Field from the potential drop along the edge (unit spacing). */
	float field = -(pV[right] - pV[idx]);

	float meanDe   = (pDe[idx] + pDe[right])/2.0f;
	float meanDh   = (pDh[idx] + pDh[right])/2.0f;
	float meanEMob = (pEMob[idx] + pEMob[right])/2.0f;
	float meanHMob = (pHMob[idx] + pHMob[right])/2.0f;

	/* Electrons drift against the field, holes along it. */
	float veField = -meanEMob*field;
	float vhField = +meanHMob*field;
	float veTotal = -meanEMob*pExBaseN[idx] + veField;
	float vhTotal = +meanHMob*pExBaseP[idx] + vhField;

	float jE = getSGCurrent(veTotal, meanDe, 1.0f, pN[idx], pN[right]) + 0.5f*veField*(pNBase[idx] + pNBase[right]);
	float jH = getSGCurrent(vhTotal, meanDh, 1.0f, pP[idx], pP[right]) + 0.5f*vhField*(pPBase[idx] + pPBase[right]);

	return (float2)(jE, jH);
}

/*
 * Electron (.x) and hole (.y) currents across the edge from grid point idx
 * to the point in the next row (y + 1).
 *
 * The potential drop is scaled by pixelFrac, and pixelFracInv is passed to
 * getSGCurrent as the grid spacing. Presumably pixelFrac is the x/y pixel
 * aspect ratio and pixelFracInv == 1/pixelFrac -- TODO confirm with host.
 * Material parameters are averaged over the two end points, as in the x
 * direction.
 */
float2 getYRelCurrents(int idx,
			  __global const float *pV, 
			  __global const float *pEyBaseN, __global const float *pEyBaseP,
                          __global const float *pDe, __global const float *pDh,
			  __global const float *pEMob, __global const float *pHMob,
			  __global const float *pN, __global const float *pP,
			  __global const float *pNBase, __global const float *pPBase,
			  int width, int height, float pixelFrac, float pixelFracInv)
{
	int x = idx % width;
	int y = idx / width;
	/* On the last row there is no next row, so the point itself is used;
	 * per the original author's note the value computed there is not used.
	 * A comparison replaces the former y + 1 - y/(height-1) trick, which
	 * divided by zero when height == 1 and cost an integer division. */
	int nextY = (y + 1 < height) ? y + 1 : y;
	int nextIdx = x + nextY*width;

	float Ey1 = -(pV[nextIdx]-pV[idx])*pixelFrac;
	float De = (pDe[idx]+pDe[nextIdx])/2.0f;
	float Dh = (pDh[idx]+pDh[nextIdx])/2.0f;
	float eMob = (pEMob[idx]+pEMob[nextIdx])/2.0f;
	float hMob = (pHMob[idx]+pHMob[nextIdx])/2.0f;

	/* Base-field and potential-drop parts of the drift velocities. */
	float ve0 = -eMob*pEyBaseN[idx];
	float vh0 = +hMob*pEyBaseP[idx];
	float ve1 = -eMob*Ey1;
	float vh1 = +hMob*Ey1;

	float2 currents;

	currents.x = getSGCurrent(ve0+ve1, De, pixelFracInv, pN[idx], pN[nextIdx]) + 0.5f*ve1*(pNBase[idx]+pNBase[nextIdx]);
	currents.y = getSGCurrent(vh0+vh1, Dh, pixelFracInv, pP[idx], pP[nextIdx]) + 0.5f*vh1*(pPBase[idx]+pPBase[nextIdx]);

	return currents;
}

/*
 * For grid point get_global_id(0), computes the electron and hole currents
 * on the edge to the +x neighbour (stored to pECurX / pHCurX) and on the
 * edge to the next row (stored to pECurY / pHCurY).
 */
__kernel void currentsKernel(__global const float *pV, 
			  __global const float *pExBaseN, __global const float *pEyBaseN,
			  __global const float *pExBaseP, __global const float *pEyBaseP,
                          __global const float *pDe, __global const float *pDh,
			  __global const float *pEMob, __global const float *pHMob,
			  __global const float *pN, __global const float *pP,
			  __global const float *pNBase, __global const float *pPBase,
			  __global float *pECurX, __global float *pECurY,
			  __global float *pHCurX, __global float *pHCurY,
			  int width, int height, float pixelFrac, float pixelFracInv)
{
	int i = get_global_id(0);

	float2 jx = getXRelCurrents(i, pV, pExBaseN, pExBaseP, pDe, pDh, pEMob, pHMob,
				    pN, pP, pNBase, pPBase, width, height);
	float2 jy = getYRelCurrents(i, pV, pEyBaseN, pEyBaseP, pDe, pDh, pEMob, pHMob,
				    pN, pP, pNBase, pPBase, width, height,
				    pixelFrac, pixelFracInv);

	/* Component .x holds the electron current, .y the hole current. */
	pECurX[i] = jx.x;
	pHCurX[i] = jx.y;
	pECurY[i] = jy.x;
	pHCurY[i] = jy.y;
}

/*
 * One explicit (forward Euler) step of the electron and hole continuity
 * equations at grid point idx. Shared by updateDensitiesKernel and
 * updateDensitiesKernelSwapped, whose bodies were identical.
 *
 * Reads densities from pN / pP and writes
 *     pNdst[idx] = n + (GRminJDivE - R - divJe) * scaledDt
 *     pPdst[idx] = p + (GRminJDivH - R - divJh) * scaledDt
 * with R = RN*p + RP*n + RecFactor*p*n. The divergence is the current on
 * the edge leaving idx minus the current on the edge entering it from the
 * left (periodic in x) and from the previous row; the y difference is
 * scaled by pixelFrac.
 *
 * NOTE(review): for y == 0 the previous-row index is negative and would be
 * read out of bounds, so this presumably relies on the host only launching
 * interior rows (e.g. via a global offset) -- TODO confirm.
 */
void updateDensitiesAt(int idx,
		       __global const float *pV,
		       __global const float *pExBaseN, __global const float *pEyBaseN,
		       __global const float *pExBaseP, __global const float *pEyBaseP,
		       __global const float *pDe, __global const float *pDh,
		       __global const float *pEMob, __global const float *pHMob,
		       __global const float *pGRminJDivE,
		       __global const float *pGRminJDivH,
		       __global const float *pRN,
		       __global const float *pRP, __global const float *pRecFactor,
		       __global const float *pN, __global const float *pP,
		       __global const float *pNBase, __global const float *pPBase,
		       __global float *pNdst, __global float *pPdst,
		       int width, int height,
		       float scaledDt, float pixelFrac, float pixelFracInv)
{
	int x = idx % width;
	int y = idx / width;
	int prevX = (x-1+width)%width;   /* x wraps around */
	int prevY = y-1;
	int leftIdx = prevX + y*width;
	int belowIdx = x + prevY*width;

	float2 xCurrentsIdx = getXRelCurrents(idx, pV, pExBaseN, pExBaseP, pDe, pDh, pEMob, pHMob, pN, pP, pNBase, pPBase, width, height);
	float2 yCurrentsIdx = getYRelCurrents(idx, pV, pEyBaseN, pEyBaseP, pDe, pDh, pEMob, pHMob, pN, pP, pNBase, pPBase, width, height, pixelFrac, pixelFracInv);
	float2 xCurrentsLeftIdx = getXRelCurrents(leftIdx, pV, pExBaseN, pExBaseP, pDe, pDh, pEMob, pHMob, pN, pP, pNBase, pPBase, width, height);
	float2 yCurrentsBelowIdx = getYRelCurrents(belowIdx, pV, pEyBaseN, pEyBaseP, pDe, pDh, pEMob, pHMob, pN, pP, pNBase, pPBase, width, height, pixelFrac, pixelFracInv);

	/* Component .x is the electron current, .y the hole current. */
	float Jexx = (xCurrentsIdx.x - xCurrentsLeftIdx.x);
	float Jeyy = (yCurrentsIdx.x - yCurrentsBelowIdx.x)*pixelFrac;
	float Jhxx = (xCurrentsIdx.y - xCurrentsLeftIdx.y);
	float Jhyy = (yCurrentsIdx.y - yCurrentsBelowIdx.y)*pixelFrac;

	/* Net generation minus recombination, then the continuity equation. */
	float n = pN[idx];
	float p = pP[idx];
	float recPart = pRN[idx]*p + pRP[idx]*n + pRecFactor[idx]*p*n;
	float netGenE = pGRminJDivE[idx] - recPart;
	float netGenH = pGRminJDivH[idx] - recPart;

	float dndt = netGenE - ( Jexx+Jeyy );
	float dpdt = netGenH - ( Jhxx+Jhyy );

	pNdst[idx] = n + dndt*scaledDt;
	pPdst[idx] = p + dpdt*scaledDt;
}

/* Advances the densities at grid point get_global_id(0) by one time step;
 * see updateDensitiesAt. */
__kernel void updateDensitiesKernel(__global const float *pV, 
			  __global const float *pExBaseN, __global const float *pEyBaseN,
			  __global const float *pExBaseP, __global const float *pEyBaseP,
                          __global const float *pDe, __global const float *pDh,
			  __global const float *pEMob, __global const float *pHMob,
			  __global const float *pGRminJDivE,
			  __global const float *pGRminJDivH,
			  __global const float *pRN,
			  __global const float *pRP, __global const float *pRecFactor,
			  __global const float *pN, __global const float *pP,
			  __global const float *pNBase, __global const float *pPBase,
			  __global float *pNdst, __global float *pPdst,
			  int width, int height,
			  float scaledDt, float pixelFrac, float pixelFracInv)
{
	int idx = get_global_id(0);

	updateDensitiesAt(idx, pV, pExBaseN, pEyBaseN, pExBaseP, pEyBaseP,
			  pDe, pDh, pEMob, pHMob, pGRminJDivE, pGRminJDivH,
			  pRN, pRP, pRecFactor, pN, pP, pNBase, pPBase,
			  pNdst, pPdst, width, height,
			  scaledDt, pixelFrac, pixelFracInv);
}

/* Same computation as updateDensitiesKernel; only the argument order of the
 * density / destination buffers differs (pNdst, pPdst come before pNBase,
 * pPBase and pN, pP come last). Presumably this lets the host alternate
 * buffer roles between steps -- TODO confirm. */
__kernel void updateDensitiesKernelSwapped(__global const float *pV, 
			  __global const float *pExBaseN, __global const float *pEyBaseN,
			  __global const float *pExBaseP, __global const float *pEyBaseP,
                          __global const float *pDe, __global const float *pDh,
			  __global const float *pEMob, __global const float *pHMob,
			  __global const float *pGRminJDivE,
			  __global const float *pGRminJDivH,
			  __global const float *pRN,
			  __global const float *pRP, __global const float *pRecFactor,
			  __global float *pNdst, __global float *pPdst,
			  __global const float *pNBase, __global const float *pPBase,
			  __global const float *pN, __global const float *pP,
			  int width, int height,
			  float scaledDt, float pixelFrac, float pixelFracInv)
{
	int idx = get_global_id(0);

	updateDensitiesAt(idx, pV, pExBaseN, pEyBaseN, pExBaseP, pEyBaseP,
			  pDe, pDh, pEMob, pHMob, pGRminJDivE, pGRminJDivH,
			  pRN, pRP, pRecFactor, pN, pP, pNBase, pPBase,
			  pNdst, pPdst, width, height,
			  scaledDt, pixelFrac, pixelFracInv);
}

/*
 * One red-black relaxation pass for the potential pV at grid point
 * get_global_id(0).
 *
 * Only points whose parity of (x + y) matches the parity of blackOrRed are
 * updated: blackOrRed == 0 updates points with x + y even, and
 * blackOrRed == 1 updates points with x + y odd. For those points:
 *     prediction = (A1*V_right + A2*V_left + A3*V_up + A4*V_down
 *                   + (p - n)*chargeMultiplier) / A0
 *     V += w * (VPredSub + prediction - V)
 * x wraps around; `height` is not used.
 *
 * Points of the other colour return immediately. Previously they read all
 * neighbours and stored V + 0*diff back to pV. That store raced with
 * updating neighbours reading pV[idx], and it wrote NaN when diff was
 * non-finite (0*inf).
 *
 * NOTE(review): assumes blackOrRed >= 0. Also, on the first and last rows
 * the down/up indices fall outside the grid, so the host presumably only
 * launches interior rows -- TODO confirm.
 */
__kernel void blackRedKernel(float chargeMultiplier, 
                          __global const float *pN, __global const float *pP,
			  __global const float *pA0, __global const float *pA1, __global const float *pA2,
			  __global const float *pA3, __global const float *pA4, 
			  __global const float *pVPredSub,
			  __global float *pV, 
			  float w, int blackOrRed, int width, int height)
{
	int idx = get_global_id(0);
	int x = idx % width;
	int y = idx / width;

	int update = (x+y+blackOrRed+1)%2;
	if (update == 0)
	{
		return;
	}

	int xPrev = (x-1+width)%width;
	int xNext = (x+1)%width;
	int yNext = y+1;
	int yPrev = y-1;

	int leftIndex = xPrev + y*width;
	int rightIndex = xNext + y*width;
	int upIndex = x + yNext*width;
	int downIndex = x + yPrev*width;

	float charge = (pP[idx]-pN[idx])*chargeMultiplier;

	float prediction = ( pA1[idx]*pV[rightIndex] + pA2[idx]*pV[leftIndex] + pA3[idx]*pV[upIndex] + pA4[idx]*pV[downIndex] 
                             + charge ) / pA0[idx];

	float curValue = pV[idx];
	float diff = pVPredSub[idx] + prediction-curValue;

	pV[idx] = curValue + diff*w;
}

