aboutsummaryrefslogtreecommitdiffstats
path: root/src/main/java/de/pixart/messenger/crypto/XmppDomainVerifier.java
blob: 26aa268e4152bfa535e0e436426baccd09a0fc3d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package de.pixart.messenger.crypto;

import android.util.Log;
import android.util.Pair;

import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.DERIA5String;
import org.bouncycastle.asn1.DERTaggedObject;
import org.bouncycastle.asn1.DERUTF8String;
import org.bouncycastle.asn1.DLSequence;
import org.bouncycastle.asn1.x500.RDN;
import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.asn1.x500.style.BCStyle;
import org.bouncycastle.asn1.x500.style.IETFUtils;
import org.bouncycastle.cert.jcajce.JcaX509CertificateHolder;

import java.io.IOException;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import javax.net.ssl.SSLSession;

import de.duenndns.ssl.DomainHostnameVerifier;

/**
 * Checks whether the leaf certificate of a TLS session is valid for an XMPP domain.
 *
 * <p>The certificate's subjectAltName entries are searched for:
 * <ul>
 *   <li>an xmppAddr otherName equal to the domain,</li>
 *   <li>an SRVName otherName equal to {@code "_xmpp-client." + domain}, or</li>
 *   <li>a dNSName matching the domain (or the optional hostname).</li>
 * </ul>
 * A dNSName of the form {@code "*.suffix"} matches a name whose first label is non-empty and
 * whose remainder equals {@code ".suffix"}. If the certificate carries none of these
 * subjectAltName types, the subject's CN values are used as dNSName candidates instead.
 */
public class XmppDomainVerifier implements DomainHostnameVerifier {

    private static final String LOGTAG = "XmppDomainVerifier";

    /** otherName type OID for SRVName entries (RFC 4985). */
    private static final String SRV_NAME = "1.3.6.1.5.5.7.8.7";

    /** otherName type OID for xmppAddr entries. */
    private static final String XMPP_ADDR = "1.3.6.1.5.5.7.8.5";

    /** GeneralName tag reported by getSubjectAlternativeNames() for otherName entries. */
    private static final int SAN_TYPE_OTHER_NAME = 0;

    /** GeneralName tag reported by getSubjectAlternativeNames() for dNSName entries. */
    private static final int SAN_TYPE_DNS_NAME = 2;

    /**
     * Verifies the peer's leaf certificate against {@code domain} and, optionally, {@code hostname}.
     *
     * @param domain the XMPP domain the connection is meant for
     * @param hostname an additional host name accepted via dNSName/CN matching; may be {@code null}
     * @param sslSession the TLS session whose first peer certificate is inspected
     * @return {@code true} if a matching identifier was found; {@code false} otherwise, including
     *     when the certificate cannot be obtained or parsed
     */
    @Override
    public boolean verify(String domain, String hostname, SSLSession sslSession) {
        try {
            Certificate[] chain = sslSession.getPeerCertificates();
            if (chain.length == 0 || !(chain[0] instanceof X509Certificate)) {
                return false;
            }
            X509Certificate certificate = (X509Certificate) chain[0];
            Collection<List<?>> alternativeNames = certificate.getSubjectAlternativeNames();
            List<String> xmppAddrs = new ArrayList<>();
            List<String> srvNames = new ArrayList<>();
            List<String> domains = new ArrayList<>();
            if (alternativeNames != null) {
                for (List<?> san : alternativeNames) {
                    // Skip entries of an unexpected shape instead of letting a single odd entry
                    // throw and abort verification of all the others.
                    if (san == null || san.size() < 2 || !(san.get(0) instanceof Integer)) {
                        continue;
                    }
                    int type = (Integer) san.get(0);
                    Object value = san.get(1);
                    if (type == SAN_TYPE_OTHER_NAME && value instanceof byte[]) {
                        Pair<String, String> otherName = parseOtherName((byte[]) value);
                        if (otherName != null) {
                            switch (otherName.first) {
                                case SRV_NAME:
                                    srvNames.add(otherName.second);
                                    break;
                                case XMPP_ADDR:
                                    xmppAddrs.add(otherName.second);
                                    break;
                                default:
                                    Log.d(LOGTAG, "oid: " + otherName.first + " value: " + otherName.second);
                            }
                        }
                    } else if (type == SAN_TYPE_DNS_NAME && value instanceof String) {
                        domains.add((String) value);
                    }
                }
            }
            if (srvNames.isEmpty() && xmppAddrs.isEmpty() && domains.isEmpty()) {
                // No usable subjectAltName: fall back to the subject's common name(s).
                X500Name x500name = new JcaX509CertificateHolder(certificate).getSubject();
                for (RDN rdn : x500name.getRDNs(BCStyle.CN)) {
                    if (rdn.getFirst() != null) {
                        domains.add(IETFUtils.valueToString(rdn.getFirst().getValue()));
                    }
                }
            }
            Log.d(LOGTAG, "searching for " + domain + " in srvNames: " + srvNames + " xmppAddrs: " + xmppAddrs + " domains:" + domains);
            if (hostname != null) {
                Log.d(LOGTAG, "also trying to verify hostname " + hostname);
            }
            return xmppAddrs.contains(domain)
                    || srvNames.contains("_xmpp-client." + domain)
                    || matchDomain(domain, domains)
                    || (hostname != null && matchDomain(hostname, domains));
        } catch (Exception e) {
            // Any failure while reading or parsing the certificate means it cannot be verified;
            // log it so the cause is not lost.
            Log.d(LOGTAG, "unable to verify certificate for " + domain, e);
            return false;
        }
    }

    /**
     * Decodes a DER-encoded otherName subjectAltName value into its type OID and string value.
     *
     * <p>The encoding is expected to be a tagged object wrapping a sequence whose first element is
     * the OID and whose second element is a tagged UTF8String or IA5String.
     *
     * @param otherName the DER bytes reported by {@code getSubjectAlternativeNames()}
     * @return a pair of (OID string, value), or {@code null} if the structure or value type is
     *     not recognised or the bytes cannot be parsed
     */
    private static Pair<String, String> parseOtherName(byte[] otherName) {
        try {
            ASN1Primitive asn1Primitive = ASN1Primitive.fromByteArray(otherName);
            if (asn1Primitive instanceof DERTaggedObject) {
                ASN1Primitive inner = ((DERTaggedObject) asn1Primitive).getObject();
                if (inner instanceof DLSequence) {
                    DLSequence sequence = (DLSequence) inner;
                    if (sequence.size() >= 2 && sequence.getObjectAt(1) instanceof DERTaggedObject) {
                        String oid = sequence.getObjectAt(0).toString();
                        ASN1Primitive value = ((DERTaggedObject) sequence.getObjectAt(1)).getObject();
                        if (value instanceof DERUTF8String) {
                            return new Pair<>(oid, ((DERUTF8String) value).getString());
                        } else if (value instanceof DERIA5String) {
                            return new Pair<>(oid, ((DERIA5String) value).getString());
                        }
                    }
                }
            }
            return null;
        } catch (IOException e) {
            return null;
        }
    }

    /**
     * Returns whether {@code needle} matches any entry of {@code haystack}.
     *
     * <p>An entry starting with {@code "*."} matches when the needle has a non-empty first label
     * and everything from the needle's first dot onward equals the entry without its asterisk.
     * Any other entry must equal the needle exactly (case-sensitive).
     *
     * @param needle the name being verified
     * @param haystack candidate names taken from the certificate
     * @return {@code true} on the first match, {@code false} if nothing matches
     */
    private static boolean matchDomain(String needle, List<String> haystack) {
        for (String entry : haystack) {
            if (entry.startsWith("*.")) {
                int dot = needle.indexOf('.');
                // Check before calling substring: substring(-1) would throw for a dot-less
                // needle, and a leading dot means an empty label that a wildcard must not match.
                if (dot <= 0) {
                    continue;
                }
                String needleSuffix = needle.substring(dot);
                String entrySuffix = entry.substring(1);
                Log.d(LOGTAG, "comparing " + needleSuffix + " and " + entrySuffix);
                if (needleSuffix.equals(entrySuffix)) {
                    Log.d(LOGTAG, "domain " + needle + " matched " + entry);
                    return true;
                }
            } else if (entry.equals(needle)) {
                Log.d(LOGTAG, "domain " + needle + " matched " + entry);
                return true;
            }
        }
        return false;
    }

    /**
     * Verifies the peer's leaf certificate against {@code domain} only.
     *
     * @param domain the XMPP domain the connection is meant for
     * @param sslSession the TLS session whose first peer certificate is inspected
     * @return the result of {@link #verify(String, String, SSLSession)} with a {@code null} hostname
     */
    @Override
    public boolean verify(String domain, SSLSession sslSession) {
        return verify(domain, null, sslSession);
    }
}