aboutsummaryrefslogtreecommitdiffstats
path: root/src/main/java/de/pixart/messenger/crypto/XmppDomainVerifier.java
blob: a894dbc1b087fd20ae2d992d79442bd9bb89b19c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package de.pixart.messenger.crypto;

import android.util.Log;
import android.util.Pair;

import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.DERIA5String;
import org.bouncycastle.asn1.DERTaggedObject;
import org.bouncycastle.asn1.DERUTF8String;
import org.bouncycastle.asn1.DLSequence;
import org.bouncycastle.asn1.x500.RDN;
import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.asn1.x500.style.BCStyle;
import org.bouncycastle.asn1.x500.style.IETFUtils;
import org.bouncycastle.cert.jcajce.JcaX509CertificateHolder;

import java.io.IOException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;

import javax.net.ssl.SSLSession;

import de.duenndns.ssl.DomainHostnameVerifier;

/**
 * Verifies that the leaf certificate presented in a TLS session identifies a given XMPP domain.
 *
 * <p>The following identities are extracted from the certificate and compared case-insensitively:
 * <ul>
 *   <li>XmppAddr otherName SANs;</li>
 *   <li>SRVName otherName SANs, expected as {@code _xmpp-client.<domain>};</li>
 *   <li>dNSName SANs, with single-label wildcard support.</li>
 * </ul>
 * The subject common name (CN) is consulted only when the certificate carries none of those SAN
 * identifiers, or via a special case for self-signed certificates with exactly one CN.
 */
public class XmppDomainVerifier implements DomainHostnameVerifier {

    private static final String LOGTAG = "XmppDomainVerifier";

    /** OID of the SRVName otherName SAN type (id-on-dnsSRV, RFC 4985). */
    private static final String SRV_NAME = "1.3.6.1.5.5.7.8.7";
    /** OID of the XmppAddr otherName SAN type (id-on-xmppAddr, RFC 6120). */
    private static final String XMPP_ADDR = "1.3.6.1.5.5.7.8.5";

    /**
     * Checks whether the peer certificate of {@code sslSession} matches {@code domain}, or
     * alternatively {@code hostname}.
     *
     * @param domain     the XMPP domain the client wants to talk to
     * @param hostname   an optional host name (may be {@code null}); it is matched against
     *                   dNSName identities only
     * @param sslSession the established TLS session whose leaf certificate is checked
     * @return {@code true} if a certificate identity matches; {@code false} otherwise, including
     *         on any error while inspecting the certificate
     */
    @Override
    public boolean verify(String domain, String hostname, SSLSession sslSession) {
        try {
            final Certificate[] chain = sslSession.getPeerCertificates();
            if (chain.length == 0 || !(chain[0] instanceof X509Certificate)) {
                return false;
            }
            final X509Certificate certificate = (X509Certificate) chain[0];
            final String needle = normalize(domain);
            final List<String> commonNames = getCommonNames(certificate);

            // Self-signed certificates are commonly issued with only a CN; accept an exact match.
            if (isSelfSigned(certificate)
                    && commonNames.size() == 1
                    && normalize(commonNames.get(0)).equals(needle)) {
                Log.d(LOGTAG, "accepted CN in cert self signed cert for " + domain);
                return true;
            }

            final List<String> xmppAddrs = new ArrayList<>();
            final List<String> srvNames = new ArrayList<>();
            final List<String> domains = new ArrayList<>();
            final Collection<List<?>> alternativeNames = certificate.getSubjectAlternativeNames();
            if (alternativeNames != null) {
                for (List<?> san : alternativeNames) {
                    // Per X509Certificate#getSubjectAlternativeNames: entry 0 is the GeneralName
                    // tag, entry 1 is a String (dNSName = 2) or the DER bytes (otherName = 0).
                    final int type = (Integer) san.get(0);
                    if (type == 0) {
                        final Pair<String, String> otherName = parseOtherName((byte[]) san.get(1));
                        if (otherName == null) {
                            continue;
                        }
                        switch (otherName.first) {
                            case SRV_NAME:
                                srvNames.add(normalize(otherName.second));
                                break;
                            case XMPP_ADDR:
                                xmppAddrs.add(normalize(otherName.second));
                                break;
                            default:
                                Log.d(LOGTAG, "oid: " + otherName.first + " value: " + otherName.second);
                        }
                    } else if (type == 2) {
                        final Object value = san.get(1);
                        if (value instanceof String) {
                            domains.add(normalize((String) value));
                        }
                    }
                }
            }

            // RFC 6125 section 6.4.4: fall back to the CN only if no SAN identifiers are present.
            if (srvNames.isEmpty() && xmppAddrs.isEmpty() && domains.isEmpty()) {
                for (String commonName : commonNames) {
                    domains.add(normalize(commonName));
                }
            }

            Log.d(LOGTAG, "searching for " + domain + " in srvNames: " + srvNames + " xmppAddrs: " + xmppAddrs + " domains:" + domains);
            if (hostname != null) {
                Log.d(LOGTAG, "also trying to verify hostname " + hostname);
            }
            return (needle != null
                    && (xmppAddrs.contains(needle) || srvNames.contains("_xmpp-client." + needle)))
                    || matchDomain(needle, domains)
                    || matchDomain(normalize(hostname), domains);
        } catch (Exception e) {
            // Any failure to inspect the certificate is treated as a verification failure.
            Log.d(LOGTAG, "unable to verify certificate for " + domain, e);
            return false;
        }
    }

    /**
     * Returns the values of all CN attributes of the certificate subject, in order.
     *
     * @return the common names; empty if there are none or the certificate cannot be encoded
     */
    private static List<String> getCommonNames(X509Certificate certificate) {
        final List<String> commonNames = new ArrayList<>();
        try {
            final X500Name subject = new JcaX509CertificateHolder(certificate).getSubject();
            for (RDN rdn : subject.getRDNs(BCStyle.CN)) {
                commonNames.add(IETFUtils.valueToString(rdn.getFirst().getValue()));
            }
        } catch (CertificateEncodingException e) {
            Log.d(LOGTAG, "unable to read subject of certificate", e);
        }
        return commonNames;
    }

    /**
     * Decodes a DER-encoded otherName SAN into an (OID, string value) pair.
     *
     * <p>Only the structure {@code [tag] SEQUENCE { oid, [tag] UTF8String | IA5String }} is
     * understood.
     *
     * @param otherName the DER bytes as returned by {@code getSubjectAlternativeNames()}
     * @return the dotted OID and its string value, or {@code null} if the encoding is not
     *         understood
     */
    private static Pair<String, String> parseOtherName(byte[] otherName) {
        try {
            final ASN1Primitive asn1Primitive = ASN1Primitive.fromByteArray(otherName);
            if (!(asn1Primitive instanceof DERTaggedObject)) {
                return null;
            }
            final ASN1Primitive inner = ((DERTaggedObject) asn1Primitive).getObject();
            if (!(inner instanceof DLSequence)) {
                return null;
            }
            final DLSequence sequence = (DLSequence) inner;
            if (sequence.size() < 2 || !(sequence.getObjectAt(1) instanceof DERTaggedObject)) {
                return null;
            }
            final String oid = sequence.getObjectAt(0).toString();
            final ASN1Primitive value = ((DERTaggedObject) sequence.getObjectAt(1)).getObject();
            if (value instanceof DERUTF8String) {
                return new Pair<>(oid, ((DERUTF8String) value).getString());
            } else if (value instanceof DERIA5String) {
                return new Pair<>(oid, ((DERIA5String) value).getString());
            }
            return null;
        } catch (IOException e) {
            return null;
        }
    }

    /**
     * Checks whether {@code needle} equals an entry of {@code haystack}, or matches a wildcard entry.
     *
     * <p>A wildcard entry {@code *.example.com} matches exactly one non-empty leftmost label. It
     * matches {@code a.example.com}, but not {@code example.com}, {@code .example.com} or
     * {@code a.b.example.com}. Callers pass already-normalized (lower-cased) names.
     *
     * @param needle   the name to look for; {@code null} never matches
     * @param haystack candidate names from the certificate
     * @return {@code true} if any entry matches
     */
    private static boolean matchDomain(String needle, List<String> haystack) {
        if (needle == null) {
            return false;
        }
        for (String entry : haystack) {
            final boolean matched;
            if (entry.startsWith("*.")) {
                final int dot = needle.indexOf('.');
                // Check the index before substring(): a dot-less needle must not throw.
                matched = dot > 0 && needle.substring(dot).equals(entry.substring(1));
            } else {
                matched = entry.equals(needle);
            }
            if (matched) {
                Log.d(LOGTAG, "domain " + needle + " matched " + entry);
                return true;
            }
        }
        return false;
    }

    /**
     * Lower-cases a DNS-style name with a locale-independent mapping.
     *
     * @return the lower-cased name, or {@code null} if {@code name} is {@code null}
     */
    private static String normalize(String name) {
        return name == null ? null : name.toLowerCase(Locale.ROOT);
    }

    /** Returns whether the certificate's signature verifies against its own public key. */
    private static boolean isSelfSigned(X509Certificate certificate) {
        try {
            certificate.verify(certificate.getPublicKey());
            return true;
        } catch (Exception e) {
            return false;
        }
    }

    /** Verifies {@code sslSession} against {@code domain} without an additional host name. */
    @Override
    public boolean verify(String domain, SSLSession sslSession) {
        return verify(domain, null, sslSession);
    }
}