#include "nr.h"
#include <iostream>
#include <fstream>
#include <math.h>

using namespace std;
using namespace NR;

void mult(Vec_DP &in1, Vec_DP &in2, Vec_DP &out) {
// return A times B
  for (int i=0; i<in1.size()/2; i++) {
     out[2*i]=in1[2*i]*in2[2*i]-in1[2*i+1]*in2[2*i+1];
     out[2*i+1]=in1[2*i+1]*in2[2*i]+in1[2*i]*in2[2*i+1];
  }
}

void divide(Vec_DP &in1, Vec_DP &in2, Vec_DP &out) {
// return A / B
  for (int i=0; i<in1.size()/2; i++) {
     double denom=in2[2*i]*in2[2*i]+in2[2*i+1]*in2[2*i+1];
     out[2*i]=in1[2*i]*in2[2*i]+in1[2*i+1]*in2[2*i+1];
     out[2*i]/=denom;
     out[2*i+1]=in1[2*i+1]*in2[2*i]-in1[2*i]*in2[2*i+1];
     out[2*i+1]/=denom;
  }
}
     
int main() {
   const int NX=512;
   const int NY=512;
   const int SIZE=2*NX*NY;
   const double PI=4*atan(1);

   Vec_DP orig(SIZE); // sum of signal and noise in time domain
   Vec_DP four(SIZE); // sum of signal and noise in time domain
   Vec_DP smear(SIZE); // noise in time domain
   Vec_DP prod(SIZE); // for the correlation
   Vec_INT sizes(2);
   sizes[0]=NX;
   sizes[1]=NY;

   ofstream origfil("./orig");
   ofstream smearfil("./smear");
   ofstream deconv("./deconv");

// produce data
// I make a sharp plot, with 1 and 0 in bins, and put a circle and some squares
   for (int i=0; i<NX; i++) for (int j=0; j<NY; j++) {
      int k = (i/64-j/64);
      if (k<0) k=-k;
      if (k==0) orig[j*NY+i]=1.0;
      if (k==1) orig[j*NY+i]=0;
      if (k==2) {
         int l = (i/32-j/32);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (k==2) {
         int l = (i/32-j/32);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (k==3) {
         int l = (i/16-j/16);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (k==4) {
         int l = (i/8-j/8);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (k==5) {
         int l = (i/4-j/4);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (k==7) {
         int l = (i-j);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (k==6) {
         int l = (i/2-j/2);
         if (l%2) orig[j*NY+i]=1;
         else orig[j*NY+i]=0;
      }
      if (fabs(sqrt((i-110)*(i-110)+(j-140)*(j-140))-100)<0.8) orig[j*NY+i]=1-orig[j*NY+i];
      if (fabs(sqrt((i-210)*(i-210)+(j-340)*(j-340))-200)<1.8) orig[j*NY+i]=1-orig[j*NY+i];
      origfil << i << " " << j << " " << orig[j*NY+i] << endl;
    }
    for (int i=0; i<SIZE/2; i++) {
       four[2*i]=orig[i];
       four[2*i+1]=0;
    }

    fourn(four,sizes,1);
    for (int i=0; i<SIZE; i++) smear[i]=0;
    for (int j=-5; j<6; j++) for (int i=-10; i<10; i++)  {
        double weight = exp(-j*j-i*i*0.3);	
        int index=((i+NX)%NX + ((j+NY)%NY)*NY);
        smear[2*index]=weight;
    }
    fourn(smear,sizes,1);
    mult(four,smear,prod);
    fourn(prod,sizes,-1);
      
   for (int j=0; j<NY; j++) for (int i=0; i<NX; i++)  {
      smearfil << i << " " << j << " " << prod[2*(j*NY+i)] << endl;
    }

    // deconvolve
    fourn(prod,sizes,1);
    divide(prod,smear,four);
    fourn(four,sizes,-1);
   for (int j=0; j<NY; j++) for (int i=0; i<NX; i++)  {
      deconv << i << " " << j << " " << four[2*(j*NY+i)] << endl;
    }
}

