package nl.nikhef.slcshttps.trust;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.StringTokenizer;
import java.util.Vector;

import java.security.cert.X509Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateParsingException;

import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
import javax.security.auth.x500.X500Principal;

/**
 * Class to check whether a certain certificate is valid for a certain hostname,
 * either using the TLS or the LDAP scheme.
 * This class is roughly a combination of the JDK1.6 internal
 * <CODE>sun.security.util.HostnameChecker</CODE> and
 * <CODE>sun.net.util.IPAddressUtil</CODE>. Instead of
 * <CODE>getSubjectX500Name().findMostSpecificAttribute(X500Name.commonName_oid)</CODE>
 * it obtains the common names through a method <CODE>getCNs()</CODE>, which
 * was originally adapted from the
 * <A HREF="http://juliusdavies.ca/commons-ssl/">not-yet-commons-ssl-0.3.10</A>
 * package. That method now parses the subject DN with
 * {@link javax.naming.ldap.LdapName}. As a result, text such as
 * <CODE>CN=</CODE> embedded in another attribute value (e.g.
 * <CODE>O="foo, CN=host"</CODE>) is no longer mistaken for a common name
 * (cf. CVE-2014-3577).
 * <P>
 * Instances are obtained through {@link #getInstance(byte)}; they have no
 * mutable state.
 */

public class HostnameChecker {
    /** Number of bytes for a IPv4 address. */
    private final static int INADDR4SZ = 4;
    /** Number of bytes for a IPv6 address. */
    private final static int INADDR16SZ = 16;
    /** Number of bytes in one 16-bit group of an IPv6 address. */
    private final static int INT16SZ = 2;

    /** Constant for a HostnameChecker for TLS. */
    public final static byte TYPE_TLS = 1;
    /** Contains the HostnameChecker for type TLS. */
    private final static HostnameChecker INSTANCE_TLS =
                                        new HostnameChecker(TYPE_TLS);

    /** Constant for a HostnameChecker for LDAP. */
    public final static byte TYPE_LDAP = 2;
    /** Contains the HostnameChecker for type LDAP. */
    private final static HostnameChecker INSTANCE_LDAP =
                                        new HostnameChecker(TYPE_LDAP);

    /** Constant for subject alt names of type DNS (GeneralName tag 2). */
    private final static int ALTNAME_DNS = 2;
    /** Constant for subject alt names of type IP (GeneralName tag 7). */
    private final static int ALTNAME_IP  = 7;

    /** The algorithm to follow to perform the check. */
    private final byte checkType;

    /**
     * Constructs a <CODE>HostnameChecker</CODE> for type
     * <CODE>checkType</CODE>.
     * @param checkType specifies which type to use, {@link #TYPE_TLS} or {@link
     * #TYPE_LDAP}
     */
    private HostnameChecker(byte checkType) {
        this.checkType = checkType;
    }

    /**
     * Returns a <CODE>HostnameChecker</CODE> instance of the right type. Note
     * that no new instance is created: the same shared instance is returned
     * on every call.
     * @param checkType specifies which type to return, {@link #TYPE_TLS} or {@link
     * #TYPE_LDAP}
     * @return HostnameChecker of the requested type
     * @throws IllegalArgumentException if <CODE>checkType</CODE> is not one of
     * the TYPE_* constants defined in this class
     */
    public static HostnameChecker getInstance(byte checkType) {
        if (checkType == TYPE_TLS) {
            return INSTANCE_TLS;
        } else if (checkType == TYPE_LDAP) {
            return INSTANCE_LDAP;
        }
        throw new IllegalArgumentException("Unknown check type: " + checkType);
    }

    /**
     * Tries to match the {@link X509Certificate} against the given
     * <CODE>expectedName</CODE>. Literal IPv4/IPv6 addresses are checked
     * against the iPAddress subject alternative names only. Anything else
     * is treated as a DNS name.
     * @param expectedName <CODE>String</CODE> containing the hostname or IP to
     * check; must not be <CODE>null</CODE>
     * @param cert <CODE>X509Certificate</CODE>
     * @throws CertificateException if the name does not match any of the names
     * specified in the certificate
     */
    public void match(String expectedName, X509Certificate cert)
            throws CertificateException {
        if (isIpAddress(expectedName)) {
            matchIP(expectedName, cert);
        } else {
            matchDNS(expectedName, cert);
        }
    }

    /**
     * Test whether the given hostname looks like a literal IPv4 or IPv6
     * address. The hostname does not need to be a fully qualified name.
     * This is not a strict check that performs full input validation.
     * If the method returns <CODE>true</CODE>, <CODE>name</CODE> need not be
     * a correct IP address; it only means that it does not represent a valid
     * DNS hostname. Likewise, a <CODE>false</CODE> result does not guarantee
     * a valid DNS hostname.
     * @param name <CODE>String</CODE> to check
     * @return <CODE>true</CODE> if <CODE>name</CODE> looks like an IP
     * address.
     */
    private static boolean isIpAddress(String name) {
        return isIPv4LiteralAddress(name) || isIPv6LiteralAddress(name);
    }

    /**
     * Check if the certificate allows use of the given IP address.
     *
     * From RFC2818:
     * In some cases, the URI is specified as an IP address rather than a
     * hostname. In this case, the iPAddress subjectAltName must be present
     * in the certificate and must exactly match the IP in the URI.
     * <P>
     * "Exactly match" means the same address. Besides a case-insensitive
     * textual comparison, the binary forms are compared as well, so that
     * e.g. <CODE>::1</CODE> matches an alt name written as
     * <CODE>0:0:0:0:0:0:0:1</CODE>.
     * @param expectedIP <CODE>String</CODE> containing the IP to check.
     * @param cert <CODE>X509Certificate</CODE>.
     * @throws CertificateException if the certificate is not valid for the
     * given IP address.
     * @see #match(String,X509Certificate)
     * @see #matchDNS(String,X509Certificate)
     */
    private static void matchIP(String expectedIP, X509Certificate cert)
            throws CertificateException {
        Collection<List<?>> subjAltNames = cert.getSubjectAlternativeNames();
        if (subjAltNames == null) {
            throw new CertificateException
                                ("No subject alternative names present");
        }
        // Parsed once; used to compare addresses independently of notation.
        byte[] expectedAddr = textToNumericFormat(expectedIP);
        for (List<?> next : subjAltNames) {
            if (((Integer)next.get(0)).intValue() != ALTNAME_IP) {
                continue;
            }
            String ipAddress = (String)next.get(1);
            if (expectedIP.equalsIgnoreCase(ipAddress)) {
                return;
            }
            // The textual form of the alt name need not use the caller's
            // notation (abbreviated vs. full IPv6, IPv4-mapped, ...).
            if (expectedAddr != null) {
                byte[] altAddr = textToNumericFormat(ipAddress);
                if (altAddr != null && Arrays.equals(expectedAddr, altAddr)) {
                    return;
                }
            }
        }
        throw new CertificateException("No subject alternative " +
                        "names matching " + "IP address " +
                        expectedIP + " found");
    }

    /**
     * Check if the certificate allows use of the given DNS name.
     * Note: default Sun Java behaviour is to check only the first (most
     * specific) CN; only IE(6) checks all. We use the Java default.
     * @param expectedName <CODE>String</CODE> containing the DNS name to check.
     * @param cert <CODE>X509Certificate</CODE>.
     * @throws CertificateException if the certificate is not valid for the
     * given DNS name.
     * @see #match(String,X509Certificate)
     * @see #matchDNS(String,X509Certificate,boolean)
     */
    private void matchDNS(String expectedName, X509Certificate cert)
            throws CertificateException {
        matchDNS(expectedName, cert, false);
    }

    /**
     * Check if the certificate allows use of the given DNS name.
     *
     * Below are comments from the JDK 1.6 source. Note that RFC2459 is
     * superseded by RFC3280 and later by RFC5280.
     *
     * From RFC2818:
     * If a subjectAltName extension of type dNSName is present, that MUST
     * be used as the identity. Otherwise, the (most specific) Common Name
     * field in the Subject field of the certificate MUST be used. Although
     * the use of the Common Name is existing practice, it is deprecated and
     * Certification Authorities are encouraged to use the dNSName instead.
     *
     * Matching is performed using the matching rules specified by
     * [RFC2459].  If more than one identity of a given type is present in
     * the certificate (e.g., more than one dNSName name), a match in any one
     * of the set is considered acceptable.
     *
     * @param expectedName <CODE>String</CODE> containing the DNS name to check.
     * @param cert <CODE>X509Certificate</CODE>.
     * @param allCN <CODE>boolean</CODE> whether to check all CNs or just the
     * first (most specific) one. Only IE uses all.
     * @throws CertificateException if the certificate is not valid for the
     * given DNS name.
     * @see #match(String,X509Certificate)
     * @see #matchDNS(String,X509Certificate)
     */
    private void matchDNS(String expectedName, X509Certificate cert, boolean allCN)
            throws CertificateException {
        Collection<List<?>> subjAltNames = cert.getSubjectAlternativeNames();
        if (subjAltNames != null) {
            boolean foundDNS = false;
            for (List<?> next : subjAltNames) {
                if (((Integer)next.get(0)).intValue() == ALTNAME_DNS) {
                    foundDNS = true;
                    String dnsName = (String)next.get(1);
                    if (isMatched(expectedName, dnsName)) {
                        return;
                    }
                }
            }
            if (foundDNS) {
                // if certificate contains any subject alt names of type DNS
                // but none match, reject (no fallback to the CN)
                throw new CertificateException("No subject alternative DNS "
                        + "name matching " + expectedName + " found.");
            }
        }
        // No DNS alt names: fall back to the common name(s) of the subject.
        String[] commonNames = getCNs(cert);
        if (commonNames != null) {
            int max = allCN ? commonNames.length : 1;
            for (int i = 0; i < max; i++) {
                if (isMatched(expectedName, commonNames[i])) {
                    return;
                }
            }
        }
        String msg = "No name matching " + expectedName + " found";
        throw new CertificateException(msg);
    }

    /**
     * Obtains all the CNs of the subject of a certificate, most specific
     * first. This method replaces
     * getSubjectX500Name().findMostSpecificAttribute(X500Name.commonName_oid).
     * @param cert X509Certificate to use
     * @return String[] array of CNs, or <CODE>null</CODE> if there are none
     * @see #extractCNs(X500Principal)
     */
    private static String[] getCNs(X509Certificate cert) {
        return extractCNs(cert.getSubjectX500Principal());
    }

    /**
     * Extracts the string values of all <CODE>cn</CODE> attributes from a
     * distinguished name, most specific (leftmost) first.
     * <P>
     * The DN is parsed as an RFC 2253 name instead of being split on commas.
     * Escaped or quoted separators inside attribute values (e.g.
     * <CODE>O=foo\, CN=host</CODE>) therefore cannot inject fake CNs.
     * @param subject the distinguished name, may be <CODE>null</CODE>
     * @return String[] array of CNs, or <CODE>null</CODE> if there are none
     * or the name cannot be parsed (which makes any CN match fail)
     */
    private static String[] extractCNs(X500Principal subject) {
        if (subject == null) {
            return null;
        }
        List<Rdn> rdns;
        try {
            rdns = new LdapName(subject.getName(X500Principal.RFC2253)).getRdns();
        } catch (NamingException e) {
            // Unparseable DN: report no CNs so that matching fails closed.
            return null;
        }
        List<String> cnList = new ArrayList<String>();
        // LdapName stores the rightmost (least specific) RDN at index 0, so
        // walk backwards to return the most specific CN first.
        for (int i = rdns.size() - 1; i >= 0; i--) {
            Attribute cn = rdns.get(i).toAttributes().get("cn");
            if (cn == null) {
                continue;
            }
            for (int k = 0; k < cn.size(); k++) {
                Object value;
                try {
                    value = cn.get(k);
                } catch (NamingException e) {
                    continue;
                }
                // Non-string (e.g. hex-encoded) values cannot be host names.
                if (value instanceof String) {
                    cnList.add((String) value);
                }
            }
        }
        if (cnList.isEmpty()) {
            return null;
        }
        return cnList.toArray(new String[cnList.size()]);
    }

    /**
     * Returns true if name matches against template.
     *
     * The matching is performed as per RFC 2818 rules for TLS and
     * RFC 2830 rules for LDAP.
     *
     * @param name should represent a DNS name.
     * @param template may contain the wildcard character *
     * @return boolean whether name matches template
     * @see #matchAllWildcards(String,String)
     * @see #matchLeftmostWildcard(String,String)
     * @see #matchWildCards(String,String)
     */
    private boolean isMatched(String name, String template) {
        if (checkType == TYPE_TLS) {
            return matchAllWildcards(name, template);
        } else if (checkType == TYPE_LDAP) {
            return matchLeftmostWildcard(name, template);
        } else {
            return false;
        }
    }


    /**
     * Returns true if name matches against template.
     *
     * According to RFC 2818, section 3.1 -
     * Names may contain the wildcard character * which is
     * considered to match any single domain name component
     * or component fragment.
     * E.g., *.a.com matches foo.a.com but not
     * bar.foo.a.com. f*.com matches foo.com but not bar.com.
     * <P>
     * Comparison is case-insensitive and independent of the default locale.
     * @param name should represent a DNS name.
     * @param template may contain the wildcard character *
     * @return boolean whether name matches template
     * @see #isMatched(String,String)
     * @see #matchLeftmostWildcard(String,String)
     * @see #matchWildCards(String,String)
     */
    private static boolean matchAllWildcards(String name, String template) {
        // Locale.ROOT: e.g. a Turkish default locale would lowercase 'I'
        // to a dotless i and break ASCII hostname comparison.
        name = name.toLowerCase(Locale.ROOT);
        template = template.toLowerCase(Locale.ROOT);
        StringTokenizer nameSt = new StringTokenizer(name, ".");
        StringTokenizer templateSt = new StringTokenizer(template, ".");

        if (nameSt.countTokens() != templateSt.countTokens()) {
            return false;
        }

        while (nameSt.hasMoreTokens()) {
            if (!matchWildCards(nameSt.nextToken(),
                        templateSt.nextToken())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if name matches against template.
     *
     * As per RFC 2830, section 3.6 -
     * The "*" wildcard character is allowed.  If present, it applies only
     * to the left-most name component.
     * E.g. *.bar.com would match a.bar.com, b.bar.com, etc. but not
     * bar.com.
     * <P>
     * Comparison is case-insensitive and independent of the default locale.
     * @param name should represent a DNS name.
     * @param template may contain the wildcard character *
     * @return boolean whether name matches template
     * @see #isMatched(String,String)
     * @see #matchAllWildcards(String,String)
     * @see #matchWildCards(String,String)
     */
    private static boolean matchLeftmostWildcard(String name, String template) {
        name = name.toLowerCase(Locale.ROOT);
        template = template.toLowerCase(Locale.ROOT);

        // Retrieve leftmost component
        int templateIdx = template.indexOf(".");
        int nameIdx = name.indexOf(".");

        if (templateIdx == -1)
            templateIdx = template.length();
        if (nameIdx == -1)
            nameIdx = name.length();

        if (matchWildCards(name.substring(0, nameIdx),
            template.substring(0, templateIdx))) {

            // match rest of the name
            return template.substring(templateIdx).equals(
                        name.substring(nameIdx));
        } else {
            return false;
        }
    }

    /**
     * Returns true if the name matches against the template that may
     * contain wildcard char *. Each * matches any (possibly empty) run of
     * characters; the comparison is case-sensitive, so callers lowercase
     * both arguments first.
     * @param name should represent a DNS name component.
     * @param template may contain the wildcard character *
     * @return boolean whether name matches template
     * @see #isMatched(String,String)
     * @see #matchAllWildcards(String,String)
     * @see #matchLeftmostWildcard(String,String)
     */
    private static boolean matchWildCards(String name, String template) {

        int wildcardIdx = template.indexOf("*");
        if (wildcardIdx == -1)
            return name.equals(template);

        boolean isBeginning = true;
        String beforeWildcard = "";
        String afterWildcard = template;

        while (wildcardIdx != -1) {

            // match in sequence the non-wildcard chars in the template.
            beforeWildcard = afterWildcard.substring(0, wildcardIdx);
            afterWildcard = afterWildcard.substring(wildcardIdx + 1);

            int beforeStartIdx = name.indexOf(beforeWildcard);
            if ((beforeStartIdx == -1) ||
                        (isBeginning && beforeStartIdx != 0)) {
                return false;
            }
            isBeginning = false;

            // update the match scope
            name = name.substring(beforeStartIdx + beforeWildcard.length());
            wildcardIdx = afterWildcard.indexOf("*");
        }
        return name.endsWith(afterWildcard);
    }

    /**
     * Converts an IPv4 or IPv6 address in textual form into its numeric
     * binary form. IPv4 is tried first, then IPv6.
     * @param src textual address, may be <CODE>null</CODE>
     * @return the address bytes, or <CODE>null</CODE> if <CODE>src</CODE>
     * is <CODE>null</CODE> or is not a literal IP address
     */
    private static byte[] textToNumericFormat(String src) {
        if (src == null) {
            return null;
        }
        byte[] addr = textToNumericFormatV4(src);
        return (addr != null) ? addr : textToNumericFormatV6(src);
    }

    /**
     * Converts IPv4 address in its textual presentation form
     * into its numeric binary form. The one-, two- and three-part forms
     * (e.g. <CODE>127.1</CODE>) are accepted as well as dotted quads.
     *
     * @param src a String representing an IPv4 address in standard format
     * @return byte[] representing the IPv4 numeric address, or
     * <CODE>null</CODE> if <CODE>src</CODE> is not a valid IPv4 literal
     */
    private static byte[] textToNumericFormatV4(String src)
    {
        if (src.length() == 0) {
            return null;
        }

        byte[] res = new byte[INADDR4SZ];
        String[] s = src.split("\\.", -1);
        long val;
        try {
            switch(s.length) {
            case 1:
                /*
                 * When only one part is given, the value is stored directly in
                 * the network address without any byte rearrangement.
                 */

                val = Long.parseLong(s[0]);
                if (val < 0 || val > 0xffffffffL)
                    return null;
                res[0] = (byte) ((val >> 24) & 0xff);
                res[1] = (byte) (((val & 0xffffff) >> 16) & 0xff);
                res[2] = (byte) (((val & 0xffff) >> 8) & 0xff);
                res[3] = (byte) (val & 0xff);
                break;
            case 2:
                /*
                 * When a two part address is supplied, the last part is
                 * interpreted as a 24-bit quantity and placed in the right
                 * most three bytes of the network address. This makes the
                 * two part address format convenient for specifying Class A
                 * network addresses as net.host.
                 */

                val = Integer.parseInt(s[0]);
                if (val < 0 || val > 0xff)
                    return null;
                res[0] = (byte) (val & 0xff);
                val = Integer.parseInt(s[1]);
                if (val < 0 || val > 0xffffff)
                    return null;
                res[1] = (byte) ((val >> 16) & 0xff);
                res[2] = (byte) (((val & 0xffff) >> 8) &0xff);
                res[3] = (byte) (val & 0xff);
                break;
            case 3:
                /*
                 * When a three part address is specified, the last part is
                 * interpreted as a 16-bit quantity and placed in the right
                 * most two bytes of the network address. This makes the
                 * three part address format convenient for specifying
                 * Class B network addresses as 128.net.host.
                 */
                for (int i = 0; i < 2; i++) {
                    val = Integer.parseInt(s[i]);
                    if (val < 0 || val > 0xff)
                        return null;
                    res[i] = (byte) (val & 0xff);
                }
                val = Integer.parseInt(s[2]);
                if (val < 0 || val > 0xffff)
                    return null;
                res[2] = (byte) ((val >> 8) & 0xff);
                res[3] = (byte) (val & 0xff);
                break;
            case 4:
                /*
                 * When four parts are specified, each is interpreted as a
                 * byte of data and assigned, from left to right, to the
                 * four bytes of an IPv4 address.
                 */
                for (int i = 0; i < 4; i++) {
                    val = Integer.parseInt(s[i]);
                    if (val < 0 || val > 0xff)
                        return null;
                    res[i] = (byte) (val & 0xff);
                }
                break;
            default:
                return null;
            }
        } catch(NumberFormatException e) {
            return null;
        }
        return res;
    }

    /**
     * Convert IPv6 presentation level address to network order binary form.
     * credit:
     *  Converted from C code from Solaris 8 (inet_pton)
     *
     * Any component of the string following a per-cent % is ignored.
     *
     * @param src a String representing an IPv6 address in textual format
     * @return byte[] representing the IPv6 numeric address (or the 4-byte
     * IPv4 address for an IPv4-mapped address), or <CODE>null</CODE> if
     * <CODE>src</CODE> is not a valid IPv6 literal
     */
    private static byte[] textToNumericFormatV6(String src)
    {
        // Shortest valid string is "::", hence at least 2 chars
        if (src.length() < 2) {
            return null;
        }

        int colonp;
        char ch;
        boolean saw_xdigit;
        int val;
        char[] srcb = src.toCharArray();
        byte[] dst = new byte[INADDR16SZ];

        int srcb_length = srcb.length;
        int pc = src.indexOf ("%");
        if (pc == srcb_length -1) {
            return null;
        }

        if (pc != -1) {
            srcb_length = pc;
        }

        colonp = -1;
        int i = 0, j = 0;
        /* Leading :: requires some special handling. */
        if (srcb[i] == ':')
            if (srcb[++i] != ':')
                return null;
        int curtok = i;
        saw_xdigit = false;
        val = 0;
        while (i < srcb_length) {
            ch = srcb[i++];
            int chval = Character.digit(ch, 16);
            if (chval != -1) {
                val <<= 4;
                val |= chval;
                if (val > 0xffff)
                    return null;
                saw_xdigit = true;
                continue;
            }
            if (ch == ':') {
                curtok = i;
                if (!saw_xdigit) {
                    if (colonp != -1)
                        return null;
                    colonp = j;
                    continue;
                } else if (i == srcb_length) {
                    return null;
                }
                if (j + INT16SZ > INADDR16SZ)
                    return null;
                dst[j++] = (byte) ((val >> 8) & 0xff);
                dst[j++] = (byte) (val & 0xff);
                saw_xdigit = false;
                val = 0;
                continue;
            }
            if (ch == '.' && ((j + INADDR4SZ) <= INADDR16SZ)) {
                String ia4 = src.substring(curtok, srcb_length);
                /* check this IPv4 address has 3 dots, ie. A.B.C.D */
                int dot_count = 0, index=0;
                while ((index = ia4.indexOf ('.', index)) != -1) {
                    dot_count ++;
                    index ++;
                }
                if (dot_count != 3) {
                    return null;
                }
                byte[] v4addr = textToNumericFormatV4(ia4);
                if (v4addr == null) {
                    return null;
                }
                for (int k = 0; k < INADDR4SZ; k++) {
                    dst[j++] = v4addr[k];
                }
                saw_xdigit = false;
                break;  /* '\0' was seen by inet_pton4(). */
            }
            return null;
        }
        if (saw_xdigit) {
            if (j + INT16SZ > INADDR16SZ)
                return null;
            dst[j++] = (byte) ((val >> 8) & 0xff);
            dst[j++] = (byte) (val & 0xff);
        }

        if (colonp != -1) {
            int n = j - colonp;

            if (j == INADDR16SZ)
                return null;
            for (i = 1; i <= n; i++) {
                dst[INADDR16SZ - i] = dst[colonp + n - i];
                dst[colonp + n - i] = 0;
            }
            j = INADDR16SZ;
        }
        if (j != INADDR16SZ)
            return null;
        byte[] newdst = convertFromIPv4MappedAddress(dst);
        if (newdst != null) {
            return newdst;
        } else {
            return dst;
        }
    }

    /**
     * Checks whether <CODE>src</CODE> is an IPv4 address.
     * @param src <CODE>String</CODE> representing an IPv4 address in textual format.
     * @return boolean indicating whether <CODE>src</CODE> is an IPv4 literal address
     */
    private static boolean isIPv4LiteralAddress(String src) {
        return textToNumericFormatV4(src) != null;
    }

    /**
     * Checks whether <CODE>src</CODE> is an IPv6 address.
     * @param src <CODE>String</CODE> representing an IPv6 address in textual format.
     * @return boolean indicating whether <CODE>src</CODE> is an IPv6 literal address.
     */
    private static boolean isIPv6LiteralAddress(String src) {
        return textToNumericFormatV6(src) != null;
    }

    /**
     * Converts IPv4-Mapped address to IPv4 address. Both input and
     * returned value are in network order binary form.
     *
     * @param addr <CODE>byte[]</CODE> representing an IPv4-Mapped address
     * @return byte[] representing the IPv4 numeric address, or
     * <CODE>null</CODE> if <CODE>addr</CODE> is not IPv4-mapped
     */
    private static byte[] convertFromIPv4MappedAddress(byte[] addr) {
        if (isIPv4MappedAddress(addr)) {
            byte[] newAddr = new byte[INADDR4SZ];
            System.arraycopy(addr, 12, newAddr, 0, INADDR4SZ);
            return newAddr;
        }
        return null;
    }

    /**
     * Utility routine to check if the address is an
     * IPv4 mapped IPv6 address (<CODE>::ffff:a.b.c.d</CODE>).
     *
     * @param addr <CODE>byte[]</CODE> describing the address.
     * @return <CODE>true</CODE> if the first 10 bytes are zero and bytes 10
     * and 11 are 0xff; <CODE>false</CODE> otherwise, including for arrays
     * shorter than 16 bytes.
     */
    private static boolean isIPv4MappedAddress(byte[] addr) {
        if (addr.length < INADDR16SZ) {
            return false;
        }
        if ((addr[0] == 0x00) && (addr[1] == 0x00) &&
            (addr[2] == 0x00) && (addr[3] == 0x00) &&
            (addr[4] == 0x00) && (addr[5] == 0x00) &&
            (addr[6] == 0x00) && (addr[7] == 0x00) &&
            (addr[8] == 0x00) && (addr[9] == 0x00) &&
            (addr[10] == (byte)0xff) &&
            (addr[11] == (byte)0xff))  {
            return true;
        }
        return false;
    }
}
