#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#include <string>
#include <fstream>
#include <iostream>
#include "infer.h"

// Sigmoid gain: every *_bias activation in this file computes
// 1 / (1 + exp(-(x * coeff))), so a larger value makes the sigmoid steeper.
#define coeff (10.0f)

/* Arithmetic mean of four floats; the pooling layers use it on each 2x2 window. */
float avg(float input1, float input2, float input3, float input4){
	return (input1 + input2 + input3 + input4) / 4.0f;
}

/*
 * Forward pass of a LeNet-style network (C1 -> S1 -> C2 -> S2 -> C3 -> F1 -> F2)
 * on test image(s) from Input_Test_Img. Reports through *judge whether the
 * highest-scoring class matches the label.
 *
 * C1_weight .. F2_bias : trained layer parameters, shapes as declared.
 * Input_Test_Img       : test images, 28x28 each.
 * Test_Label           : per-image label rows. A predicted class counts as
 *                        correct when its entry is exactly 1.0f (presumably
 *                        one-hot labels -- TODO confirm).
 * judge                : out-parameter. Set to true on a correct prediction,
 *                        false otherwise.
 *
 * NOTE(review): the loop evaluates only image index 3. If the range is
 * widened, *judge holds the result for the last image processed.
 */
void infer_only(float C1_weight[6][5][5], float C1_bias[6], float C2_weight[16][6][5][5], float C2_bias[16], float C3_weight[120][16][4][4], float C3_bias[120], float F1_weight[84][120], float F1_bias[84], float F2_weight[10][84], float F2_bias[10], float Input_Test_Img[10000][28][28], float Test_Label[10000][10], bool *judge){
	int i, j, k, m;
	for(i = 3; i < 4; i++){
		// Local copy of the current image (the label row is read directly below).
		float tmp_input[28][28];
		for(j = 0; j < 28; j++){
			for(k = 0; k < 28; k++){
				tmp_input[j][k] = Input_Test_Img[i][j][k];
			}
		}

		// The fp_*_preact kernels accumulate with +=, so every *_preact buffer
		// below must start zeroed. Arrays are passed without casts so the
		// compiler can check them against the kernel signatures.

		//======================= C1: 28x28 -> 6x24x24 =======================
		float C1_preact[6][24][24] = {0.0f};
		float C1_output[6][24][24] = {0.0f};
		fp_C1_preact(tmp_input, C1_preact, C1_weight);
		fp_C1_bias(C1_preact, C1_output, C1_bias);

		//======================= S1: 6x24x24 -> 6x12x12 =====================
		float S1_output[6][12][12] = {0.0f};
		fp_S1_preact(C1_output, S1_output);

		//======================= C2: 6x12x12 -> 16x8x8 ======================
		float C2_preact[16][8][8] = {0.0f};
		float C2_output[16][8][8] = {0.0f};
		fp_C2_preact(S1_output, C2_preact, C2_weight);
		fp_C2_bias(C2_preact, C2_output, C2_bias);

		//======================= S2: 16x8x8 -> 16x4x4 =======================
		float S2_output[16][4][4] = {0.0f};
		fp_S2_preact(C2_output, S2_output);

		//======================= C3: 16x4x4 -> 120 ==========================
		float C3_preact[120] = {0.0f};
		float C3_output[120] = {0.0f};
		fp_C3_preact(S2_output, C3_preact, C3_weight);
		fp_C3_bias(C3_preact, C3_output, C3_bias);

		//======================= F1: 120 -> 84 ==============================
		float F1_preact[84] = {0.0f};
		float F1_output[84] = {0.0f};
		fp_F1_preact(C3_output, F1_preact, F1_weight);
		fp_F1_bias(F1_preact, F1_output, F1_bias);

		//======================= F2: 84 -> 10 ===============================
		float F2_preact[10] = {0.0f};
		float F2_output[10] = {0.0f};
		fp_F2_preact(F1_output, F2_preact, F2_weight);
		fp_F2_bias(F2_preact, F2_output, F2_bias);

		// Arg-max over the 10 class scores, seeded with element 0 rather than
		// a magic sentinel (ties keep the lowest index).
		int max_index = 0;
		float max_number = F2_output[0];
		for(m = 1; m < 10; m++){
			if(F2_output[m] > max_number){
				max_number = F2_output[m];
				max_index = m;
			}
		}

		// Write both outcomes, so a misclassification is reported as false
		// instead of leaving whatever *judge held before the call.
		*judge = (Test_Label[i][max_index] == 1.0f);
	}
}

//================================== fp C1 START====================================
/*
 * C1 convolution: each of the 6 feature maps slides its 5x5 kernel over the
 * 28x28 input with stride 1 and no padding, producing a 24x24 map.
 * Products are ADDED to C1_preact, so the caller must zero it beforehand.
 */
void fp_C1_preact(float input[28][28], float C1_preact[6][24][24], float C1_weight[6][5][5]){
	for (int map = 0; map < 6; ++map) {
		for (int row = 0; row < 24; ++row) {
			for (int col = 0; col < 24; ++col) {
				float *dst = &C1_preact[map][row][col];
				for (int kr = 0; kr < 5; ++kr) {
					for (int kc = 0; kc < 5; ++kc) {
						*dst += C1_weight[map][kr][kc] * input[row + kr][col + kc];
					}
				}
			}
		}
	}
}

/*
 * C1 activation: adds the per-map bias (shared by all 24x24 positions of the
 * map) and applies the sigmoid 1 / (1 + exp(-(z * coeff))).
 */
void fp_C1_bias(float C1_preact[6][24][24], float C1_output[6][24][24], float C1_bias[6]){
	for (int map = 0; map < 6; ++map) {
		for (int row = 0; row < 24; ++row) {
			for (int col = 0; col < 24; ++col) {
				float z = C1_preact[map][row][col] + C1_bias[map];
				C1_output[map][row][col] = 1/(1 + exp(-(z*coeff)));
			}
		}
	}
}
//================================fp C1 end====================================================
/*
 * S1 pooling: 2x2 average pooling with stride 2, 6x24x24 -> 6x12x12.
 * Each window is passed to avg() as top-left, bottom-left, top-right,
 * bottom-right.
 */
void fp_S1_preact(float input[6][24][24], float S1_output[6][12][12]){
	for (int map = 0; map < 6; ++map) {
		for (int row = 0; row < 12; ++row) {
			const int top = 2 * row;
			for (int col = 0; col < 12; ++col) {
				const int left = 2 * col;
				S1_output[map][row][col] = avg(input[map][top][left],
				                               input[map][top + 1][left],
				                               input[map][top][left + 1],
				                               input[map][top + 1][left + 1]);
			}
		}
	}
}
//==============================fp C2 start=====================================================

/*
 * C2 convolution: each of the 16 output maps sums, over all 6 input maps,
 * a 5x5 stride-1 convolution with no padding: 6x12x12 -> 16x8x8.
 * Products are ADDED to C2_preact, so the caller must zero it beforehand.
 */
void fp_C2_preact(float input[6][12][12], float C2_preact[16][8][8], float C2_weight[16][6][5][5]){
	
	int ch_c, ch_p, i, j, m, n;
	for(ch_c = 0; ch_c < 16; ch_c++){
		for(i = 0; i < 8; i++){
			for(j = 0; j < 8; j++){
				for(ch_p = 0; ch_p < 6; ch_p++){
					for(m = 0; m < 5; m++){
						for(n = 0; n < 5; n++){
							C2_preact[ch_c][i][j] += C2_weight[ch_c][ch_p][m][n] * input[ch_p][i+m][j+n];
						}
					}
				}
			}
		}
	}
}
/*
 * C2 activation: adds the per-map bias (one per output map, shared across the
 * 8x8 positions) and applies the sigmoid 1 / (1 + exp(-(z * coeff))).
 */
void fp_C2_bias(float C2_preact[16][8][8], float C2_output[16][8][8], float C2_bias[16]){
	
	int i, j, k;
	
	for(i = 0; i < 16; i++){
		for(j = 0; j < 8; j++){
			for(k = 0; k < 8; k++){
				C2_output[i][j][k] = C2_preact[i][j][k] + C2_bias[i];
				C2_output[i][j][k] = 1/(1 + exp(-(C2_output[i][j][k]*coeff)));
			}
		}
	}

}
//=============================================================================================
/*
 * S2 pooling: 2x2 average pooling with stride 2, 16x8x8 -> 16x4x4,
 * applying avg() to each non-overlapping window.
 */
void fp_S2_preact(float input[16][8][8], float S2_output[16][4][4]){

	int i, j, k;
	for(i = 0; i < 16; i++){
		for(j = 0; j < 4; j++){
			for(k = 0; k < 4; k++){
				S2_output[i][j][k] = avg(input[i][j*2][k*2], input[i][j*2+1][k*2], input[i][j*2][k*2+1], input[i][j*2+1][k*2+1]);
			}
		}
	}
}
//==============================fp C3 start=====================================================
/*
 * C3 convolution: the 4x4 kernel covers the whole 4x4 input, so each of the
 * 120 outputs is the full dot product of its 16x4x4 kernel with the 16x4x4
 * input.
 * The sum is ADDED to C3_preact[ch_c], so the caller must zero it beforehand.
 */
void fp_C3_preact(float input[16][4][4], float C3_preact[120], float C3_weight[120][16][4][4]){
	
	int ch_c, ch_p, i, j;
	for(ch_c = 0; ch_c < 120; ch_c++){
		float acc = C3_preact[ch_c];
		// Every one of the 16 input channels contributes (matches the weight shape).
		for(ch_p = 0; ch_p < 16; ch_p++){
			for(i = 0; i < 4; i++){
				for(j = 0; j < 4; j++){
					acc += C3_weight[ch_c][ch_p][i][j] * input[ch_p][i][j];
				}
			}
		}
		C3_preact[ch_c] = acc;
	}

}
/*
 * C3 activation: C3_output[i] = 1 / (1 + exp(-((C3_preact[i] + C3_bias[i]) * coeff)))
 * for each of the 120 outputs.
 */
void fp_C3_bias(float C3_preact[120], float C3_output[120], float C3_bias[120]){
	
	int i;
	
	for(i = 0; i < 120; i++){
		C3_output[i] = C3_preact[i] + C3_bias[i];
		C3_output[i] = 1/(1 + exp(-(C3_output[i]*coeff)));
	}
}
//=============================================================================================
//==============================fp fully connect 1 start=======================================
/*
 * F1 fully connected layer (120 -> 84): for each output neuron ch_c, adds the
 * dot product of row F1_weight[ch_c] with input to F1_preact[ch_c].
 * Accumulates, so the caller must zero F1_preact first.
 */
void fp_F1_preact(float input[120], float F1_preact[84], float F1_weight[84][120]){
	
	int ch_c, i;
	for(ch_c = 0; ch_c < 84; ch_c++){
		for(i = 0; i < 120; i++){
			F1_preact[ch_c] += F1_weight[ch_c][i] * input[i];
		}
	}

}
/*
 * F1 activation: for each of the 84 neurons, add its bias and apply the
 * sigmoid 1 / (1 + exp(-(z * coeff))).
 */
void fp_F1_bias(float F1_preact[84], float F1_output[84], float F1_bias[84]){
	for (int n = 0; n < 84; ++n) {
		float z = F1_preact[n] + F1_bias[n];
		F1_output[n] = 1/(1 + exp(-(z*coeff)));
	}
}
//=============================================================================================
//==============================fp fully connect 2 start=======================================
/*
 * F2 fully connected layer (84 -> 10): adds dot(F2_weight[out], input) to
 * F2_preact[out] for each of the 10 outputs.
 * Accumulates, so the caller must zero F2_preact first.
 */
void fp_F2_preact(float input[84], float F2_preact[10], float F2_weight[10][84]){
	for (int out = 0; out < 10; ++out) {
		const float *w_row = F2_weight[out];
		for (int idx = 0; idx < 84; ++idx) {
			F2_preact[out] += w_row[idx] * input[idx];
		}
	}
}
/*
 * F2 activation (network output): for each of the 10 class scores, add its
 * bias and apply the sigmoid 1 / (1 + exp(-(z * coeff))).
 */
void fp_F2_bias(float F2_preact[10], float F2_output[10], float F2_bias[10]){
	for (int cls = 0; cls < 10; ++cls) {
		float z = F2_preact[cls] + F2_bias[cls];
		F2_output[cls] = 1/(1 + exp(-(z*coeff)));
	}
}

//=============================================================================================
