package nl.nikhef.slcshttps.trust;

import javax.net.ssl.X509TrustManager;
import java.security.cert.X509Certificate;
import java.security.KeyStore;

import java.util.Enumeration;
import java.util.Hashtable;

import java.security.KeyStoreException;
import java.security.cert.CertificateException;

import java.io.BufferedReader;
import java.io.InputStreamReader;

import java.io.IOException;

import nl.nikhef.slcshttps.util.ConsoleTools;
import nl.nikhef.slcshttps.gui.GraphTools;
import nl.nikhef.slcshttps.gui.TrustPopupComm;

/**
 * This class implements an <CODE>X509TrustManager</CODE> which checks the
 * server certificate chain, including whether it is valid for the hostname
 * being connected to, and asks the user for confirmation when something is
 * wrong. This is non-trivial since the methods which are called during the
 * SSL handshake, {@link
 * X509TrustManager#checkServerTrusted(X509Certificate[], String)
 * checkServerTrusted()} and {@link
 * X509TrustManager#checkClientTrusted(X509Certificate[], String)
 * checkClientTrusted()}, do not receive the hostname/portnumber. This class
 * therefore provides static fields which can be set using
 * {@link #setHostname(String)} and {@link #setPort(int)}. These methods have
 * to be called before setting up an HTTPS connection, which can be done e.g.
 * by using {@link HttxURLConnection}. The user communication is handled by an
 * implementation of {@link TrustCommunicator}. An implementation using just
 * stdio is given by {@link StdioComm}. Note that using static fields for
 * hostname and portnumber makes this class not thread-safe. Doing this in a
 * thread-safe way is difficult since the response of the user on invalid
 * certificates should be kept global.
 * @author Mischa Sall&eacute;
 * @version 0.1
 */
public class TrustManagerImpl implements X509TrustManager	{
    /** Propertyname to set the type of communicator {@value}. */
    private final static String COMMPROP="nl.nikhef.slcshttps.comm";
    /** hostname of the open connection, <CODE>null</CODE> until it is set
     * using {@link #setHostname(String)} or a constructor. */
    private static String host=null;
    /** portnumber of the open connection, note that for a certain combination
     * hostname/portnumber there can only be one certificate chain; the
     * portnumber is initialized to 443, the default for HTTPS. */
    private static int port=443;
    /** global table of known alias - certificate pairs, where alias is
     * hostname:port. It is created exactly once for the class and shared by
     * all instances: creating a new <CODE>TrustManagerImpl</CODE> must not
     * forget chains the user already accepted during this session. */
    private static final Hashtable<String, TrustCert> trustCertsTable=
	new Hashtable<String, TrustCert>();

    /** describes the type of communicator in use, initialized using the value
     * of property {@value COMMPROP} by {@link #setCommunicator(String)}. */
    private static String commString=null;
    /** The {@link TrustCommunicator} to be used, can be set using {@link
     * #setCommunicator(String)}. */
    private static TrustCommunicator comm=null;
    /** Initialize <CODE>commString</CODE> and <CODE>comm</CODE> from the
     * system property {@value COMMPROP} (valid values: stdio, popup). */
    static {
	String commProp=System.getProperty(COMMPROP);
	setCommunicator(commProp);
    }

    /**************************************************************************
     * METHODS
     *************************************************************************/

    /**
     * Constructs a <CODE>TrustManagerImpl</CODE>. Note that the
     * hostname/portnumber have to be set using {@link #setHostname(String)} and
     * {@link #setPort(int)}. All state of this class is static, so
     * constructing a new instance does not reset the table of known
     * certificates.
     * @see #TrustManagerImpl(String,int)
     */
    public TrustManagerImpl() {
	// Nothing to do: hostname, port, certificate table and communicator
	// are all static and shared between instances.
    }

    /**
     * Constructs a <CODE>TrustManagerImpl</CODE> and sets the global
     * <CODE>hostName</CODE>.
     * @param hostName hostname for which this TrustManager is used. Note that
     * the portnumber has to be set using {@link #setPort(int)}.
     * @see #TrustManagerImpl(String,int)
     */
    public TrustManagerImpl(String hostName) {
	this();
	setHostname(hostName);
    }

    /**
     * Constructs a TrustManagerImpl and sets the global <CODE>hostName</CODE>
     * and <CODE>portNumber</CODE>. Note that they can be changed using {@link
     * #setHostname(String)} and {@link #setPort(int)}.
     * @param hostName sets the static hostname for the class
     * @param portNumber sets the static port number for the class
     */
    public TrustManagerImpl(String hostName, int portNumber) {
	this();
	setHostname(hostName);
	setPort(portNumber);
    }

    /**
     * Main checking method, contains all the logic: it checks the Server
     * certificate chain, also against the hostname which can be set either at
     * construction time or using {@link #setHostname(String)}. It uses a
     * table with alias/{@link TrustCert certificate} pairs containing chains it
     * has seen before. There are two reasons for this: if a certificate chain
     * has been accepted once, it will be accepted again within the same session
     * (unless it has expired); secondly the whole chain does not have to be
     * checked again, improving performance. There are roughly 4 different
     * possibilities:<UL>
     * <LI>The alias is known and the certchain is the same: need only a time
     * validity check.
     * <LI>The alias is known but the certchain is different: need a full check
     * and upon confirmation substitute the old.
     * <LI>The alias is unknown but the certchain is: need host/time validity
     * check and possibly confirmation to add the alias.
     * <LI>The alias and the certchain are unknown: need a full check and
     * possibly confirmation.
     * </UL>
     * @param chain peer <CODE>X509Certificate</CODE> chain to be checked.
     * @param authType the key exchange algorithm used (unused).
     * @throws IllegalArgumentException if <CODE>chain</CODE> is
     * <CODE>null</CODE> or empty, as required by the
     * <CODE>X509TrustManager</CODE> contract.
     * @throws CertificateException if the chain doesn't verify (and the user
     * does not accept it), including an unset hostname.
     */
    public void checkServerTrusted(X509Certificate[] chain, String authType)
	throws CertificateException
    {
	// Reject a missing chain up front instead of failing on chain[0].
	if (chain==null || chain.length==0)
	    throw new IllegalArgumentException(
			"Certificate chain is null or empty");

	// Check whether we have a hostname. Default portnumber is set to 443.
	if (host==null)
	    throw new CertificateException(
			"Hostname is not set, "+
			"probably not using the HttxURLConnection");

	boolean valid;
	String alias=host+":"+port;
	// Look for alias:
	TrustCert oldCert=trustCertsTable.get(alias);
	if (oldCert!=null)	{ // alias is known: certificate has to match!
	    if (oldCert.equals(chain[0])) {
		// certificates are equal: only validity left.
		// Backup old status, we need to be able to put it back.
		TrustCert.Status oldStatus=oldCert.status.copy();
		// Accept when it was already expired, when it is valid or
		// if it is not yet valid now (then it also wasn't before).
		if (oldStatus.expired || 
		    oldCert.checkValidity() || oldCert.status.notYet)
		    return;
		// Cert has expired since last time!
		String[] olderrs=oldCert.getOldErrors();
		String[] errs=oldCert.getErrors(host);
		// If the cert was only known while valid, user doesn't know...
		String mesg=(olderrs==null ?
		    "Invalid certificate found while connecting to "+host+"." :
		    "Certificate chain has expired since the last time we saw it.");
		String ques="Do you want to accept it";
		boolean accept=comm.confirm(host,mesg,ques,errs,olderrs);
		if (!accept)	{
		    // revert status, next time it should say 'expired' again.
		    oldCert.status=oldStatus;
		    throw new CertificateException("Certificate chain for "+
						    alias+" has expired");
		}
		// update its alias table
		updateCert(alias,oldCert);
		return;
	    } else {
		// alias is found, but certificates don't match: full check,
		// if OK just replace old, otherwise ask the USER.
		TrustCert trustCert=new TrustCert(chain);
		// Note: need all three, so don't combine in one line!
		valid=trustCert.checkValidity();
		valid=trustCert.checkChain() && valid;
		valid=trustCert.checkHostname(host) && valid;
		if (!valid)  { // something is not right, with the new one
		    String[] errs=trustCert.getErrors(host);
		    String mesg="Certificate chain for "+host+
				" has changed and is invalid.";
		    String ques="Do you want to accept the new one "+
				"and replace the old";
		    boolean accept=comm.confirm(host,mesg,ques,errs,null);
		    if (!accept)
			throw new CertificateException("Certificate for "+
						alias+" failed validation");
		}
		// remove old alias/cert and add new
		removeCert(alias,oldCert);
		addCert(alias,trustCert);
		return;
	    }
	} else { // no known alias, maybe known certificate...
	    oldCert=getOldCert(chain[0]);
	    if (oldCert!=null)	{ // Certchain is known
		// Backup old status, we need to be able to put it back.
		TrustCert.Status oldStatus=oldCert.status.copy();
		// Even if certchain is invalid, we did accept it, so only need
		// confirmation if validity or hostname are invalid!
		// Both checks are called for their effect on oldCert.status,
		// which is inspected right below.
		valid=oldCert.checkValidity();
		valid=oldCert.checkHostname(host) && valid;
		if (!oldCert.status.nameValid ||
		    (oldCert.status.expired && !oldStatus.expired)) {
		    String[] olderrs=oldCert.getOldErrors();
		    String[] errs=oldCert.getErrors(host);
		    String mesg=(olderrs==null ?
			"Invalid certificate found while connecting to "+host+"." :
			"Certificate chain is known but has new problems.");
		    String ques="Do you want to accept it";
		    boolean accept=comm.confirm(host,mesg,ques,errs,olderrs);
		    if (!accept)    {
			// revert status
			oldCert.status=oldStatus;
			throw new CertificateException(
				"Known invalid certificate not accepted for "+
				alias);
		    }
		}
		addCert(alias,oldCert);
		return;
	    } else { // cert fully unknown: do full check
		TrustCert trustCert=new TrustCert(chain);
		// Note: need all three, so don't combine in one line!
		valid=trustCert.checkValidity();
		valid=trustCert.checkHostname(host) && valid;
		valid=trustCert.checkChain() && valid;
		if (!valid)  {
		    String[] errs=trustCert.getErrors(host);
		    String mesg="Invalid certificate found while connecting to "+
				host+".";
		    String ques="Do you want to accept it";
		    boolean accept=comm.confirm(host,mesg,ques,errs,null);
		    if (!accept)
			throw new CertificateException(
			    "Unknown invalid certificate not accepted for "+
			    alias);
		}
		// Add aliases (also to fully acceptable certs: caching...)
		addCert(alias,trustCert);
		return;
	    }
	}
    }

    /**
     * Adds an alias/certificate to the list of known certificates. It adds the
     * alias to <CODE>cert</CODE> and then adds the alias/cert to the list of
     * known certificates.
     * @param alias the alias to add
     * @param cert the certificate to add
     */
    private void addCert(String alias, TrustCert cert) {
	// First add the alias to the certificate then add the certificate to
	// the list.
	cert.addAlias(alias);
	trustCertsTable.put(alias,cert);
    }

    /**
     * Updates the status for a known certificate. It does this by removing and
     * then adding the alias to <CODE>cert</CODE>; the alias table itself is
     * left untouched.
     * @param alias the alias
     * @param cert the certificate for which to update the alias.
     */
    private void updateCert(String alias, TrustCert cert)   {
	cert.removeAlias(alias);
	cert.addAlias(alias);
    }

    /**
     * Removes an alias/certificate from the list of known certificates. It
     * removes the alias from <CODE>cert</CODE> and then removes the alias from
     * the list of known aliases.
     * @param alias the alias to be removed
     * @param cert the certificate that belongs to the alias.
     */
    private void removeCert(String alias, TrustCert cert)   {
	cert.removeAlias(alias);
	trustCertsTable.remove(alias);
    }

    /**
     * Finds a certificate in the list of known alias/certificates.
     * @param x509Cert certificate to look for.
     * @return TrustCert for the given <CODE>X509Certificate</CODE> or
     * <CODE>null</CODE> when it is unknown.
     */
    private TrustCert getOldCert(X509Certificate x509Cert)	{
	// Walk the values directly rather than keys + get(): avoids a second
	// lookup (and a null result if an alias disappears in between).
	Enumeration<TrustCert> knownCerts=trustCertsTable.elements();
	while (knownCerts.hasMoreElements())	{
	    TrustCert candidate=knownCerts.nextElement();
	    if (candidate.equals(x509Cert))
		return candidate;
	}
	return null;
    }

    /**
     * Dummy Client Certificate chain checker: it accepts every chain and never
     * throws. This trust manager is intended for checking the server side of
     * a connection; client chains are not inspected.
     * @param chain <CODE>X509Certificate</CODE> chain (ignored)
     * @param authType the authentication type based on the client certificate
     * (ignored)
     * @throws CertificateException never thrown; declared only because the
     * <CODE>X509TrustManager</CODE> interface requires it.
     */
    public void checkClientTrusted(X509Certificate[] chain, String authType)
	throws CertificateException
    {
    }

    /**
     * Return an array of certificate authority certificates which are trusted
     * for authenticating peers.
     * @return X509Certificate[] a non-null (possibly empty) array of acceptable
     * CA issuer certificates.
     * @see CertChainChecker#getAcceptedIssuers()
     */
    public X509Certificate[] getAcceptedIssuers()   {
	return CertChainChecker.getAcceptedIssuers();
    }

    /**
     * Sets the (static) hostname to be used during checking.
     * @param hostName static hostname to be used during checking.
     * @see #TrustManagerImpl(String)
     * @see #TrustManagerImpl(String,int)
     */
    public static void setHostname(String hostName)	{
	host=hostName;
    }
    
    /**
     * Sets the (static) portnumber to be used during checking.
     * @param portNumber static portnumber to be used during checking.
     * @see #TrustManagerImpl(String,int)
     */
    public static void setPort(int portNumber)	{
	port=portNumber;
    }

    /**
     * Sets the type of {@link TrustCommunicator} based on
     * <CODE>commInput</CODE>. Valid values (matched case-insensitively) are:
     * <UL>
     * <LI><CODE>"stdio"</CODE> - use <CODE>stdin/stdout/stderr</CODE>
     * <LI><CODE>"popup"</CODE> - use (swing) popups
     * <LI><CODE>null</CODE> - use default <CODE>"stdio"</CODE>
     * </UL>
     * Any unrecognized value also selects <CODE>"stdio"</CODE>, as does any
     * value when {@link GraphTools#isGraphic()} reports no graphical
     * environment.
     * @param commInput <CODE>String</CODE> describing the wished type of
     * communicator to be used.
     * @return <CODE>String</CODE> describing the actual type being used,
     * either <CODE>"popup"</CODE> or <CODE>"stdio"</CODE>.
     * @see #getCommunicator()
     */
    public static String setCommunicator(String commInput)  {
	final String defcomm="stdio";

	// Popups are only possible with a GUI; everything else (no GUI,
	// null, "stdio" or an unknown value) falls back to stdio, and
	// commString always records the communicator actually installed.
	if (GraphTools.isGraphic() && "popup".equalsIgnoreCase(commInput))	{
	    commString="popup";
	    comm=new TrustPopupComm();
	} else {
	    commString=defcomm;
	    comm=new StdioComm();
	}
	return commString;
    }
    
    /**
     * Returns the type of {@link TrustCommunicator} used for user interaction.
     * @return String describing the type being used.
     * @see #setCommunicator(String)
     */
    public static String getCommunicator()  {
	return commString;
    }

    /**
     * Interface for {@link TrustManagerImpl} communication with the user.
     * Only one method needs to be implemented, which asks the user for
     * confirmation.
     * @see StdioComm
     * @author Mischa Sall&eacute;
     * @version 0.1
     */
    public interface TrustCommunicator	{
	/**
	 * method to ask the user for confirmation.
	 * @param host <CODE>String</CODE> with the hostname to which an SSL
	 * connection is being set up.
	 * @param mesg <CODE>String</CODE> containing a message describing the
	 * problem.
	 * @param ques <CODE>String</CODE> containing a question asked to the
	 * user just before user input.
	 * @param errs <CODE>String</CODE> array with a list of errors.
	 * @param olderrs <CODE>String</CODE> array with a list of errors when
	 * this cert was seen previously, may be <CODE>null</CODE>.
	 * @return boolean <CODE>true</CODE> when the user accepts it,
	 * <CODE>false</CODE> when the user rejects it.
	 */
	public boolean confirm(String host,String mesg,String ques,
			       String[] errs,String[] olderrs);
    }

    /**
     * This Implementation uses only <CODE>stdio/stderr</CODE> for I/O.
     * @author Mischa Sall&eacute;
     * @version 0.1
     * @see TrustCommunicator
     */
    static class StdioComm implements TrustManagerImpl.TrustCommunicator    {
	/**
	 * method to ask the user for confirmation using stdio/stderr.
	 * @param host <CODE>String</CODE> with the hostname to which an SSL
	 * connection is being set up.
	 * @param mesg <CODE>String</CODE> containing a message describing the
	 * problem.
	 * @param ques <CODE>String</CODE> containing a question asked to the
	 * user just before user input.
	 * @param errs <CODE>String</CODE> array with a list of errors, may be
	 * <CODE>null</CODE>.
	 * @param olderrs <CODE>String</CODE> array with a list of errors when
	 * this cert was seen previously, may be <CODE>null</CODE>.
	 * @return boolean <CODE>true</CODE> when the user accepts it,
	 * <CODE>false</CODE> when the user rejects it or reading the answer
	 * fails.
	 */
	public boolean confirm(String host,String mesg,String ques,
			       String[] errs,String[] olderrs)   {
	    System.err.println("SECURITY PROBLEM WHILE CONNECTING TO "+host+
			       "!\n "+mesg);
	    // if there were errors before, print them
	    if (olderrs!=null)	{
		System.err.println("You accepted this chain before for the "+
				   "following hosts/errors:");
		for (int i=0; i<olderrs.length; i++)
		    System.err.println(" * "+olderrs[i]);
	    }
	    System.err.println("It has the following problem(s):");
	    // Guard like olderrs: a missing list must not abort the handshake
	    // with a NullPointerException before the user is asked.
	    if (errs!=null)	{
		for (int i=0; i<errs.length; i++)
		    System.err.println(" * "+errs[i]);
	    }
	    try {
		return ConsoleTools.getConfirm(ques);
	    } catch(IOException e)	{
		System.err.println("Caught exception: "+e.getMessage());
		System.err.println("Not confirming validation");
		return false;
	    }
	}
    }
}
