From 217f6603c018ca0d72925f8d8f41f605e9edf4db Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Mon, 11 Jan 2016 17:25:16 -0500 Subject: Implement XEP-0368: SRV records for XMPP over TLS --- .../eu/siacs/conversations/utils/DNSHelper.java | 121 +++++++----- .../siacs/conversations/xmpp/XmppConnection.java | 203 +++++++++++++++------ 2 files changed, 220 insertions(+), 104 deletions(-) (limited to 'src/main/java') diff --git a/src/main/java/eu/siacs/conversations/utils/DNSHelper.java b/src/main/java/eu/siacs/conversations/utils/DNSHelper.java index e07df627..87790d64 100644 --- a/src/main/java/eu/siacs/conversations/utils/DNSHelper.java +++ b/src/main/java/eu/siacs/conversations/utils/DNSHelper.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.List; import java.util.Random; import java.util.TreeMap; +import java.util.Map; import java.util.regex.Pattern; import de.measite.minidns.Client; @@ -57,7 +58,7 @@ public class DNSHelper { if (!b.containsKey("values")) { Log.d(Config.LOGTAG,"all dns queries failed. provide fallback A record"); ArrayList values = new ArrayList<>(); - values.add(createNamePortBundle(host,5222)); + values.add(createNamePortBundle(host, 5222, false)); b.putParcelableArrayList("values",values); } return b; @@ -96,57 +97,73 @@ public class DNSHelper { return servers; } - public static Bundle queryDNS(String host, InetAddress dnsServer) { - Bundle bundle = new Bundle(); - try { - client.setTimeout(Config.PING_TIMEOUT * 1000); - String qname = "_xmpp-client._tcp." + host; - Log.d(Config.LOGTAG, "using dns server: " + dnsServer.getHostAddress() + " to look up " + host); - DNSMessage message = client.query(qname, TYPE.SRV, CLASS.IN, dnsServer.getHostAddress()); - - TreeMap> priorities = new TreeMap<>(); - TreeMap> ips4 = new TreeMap<>(); - TreeMap> ips6 = new TreeMap<>(); - - for (Record[] rrset : new Record[][] { message.getAnswers(), message.getAdditionalResourceRecords() }) { - for (Record rr : rrset) { - Data d = rr.getPayload(); - if (d instanceof SRV && NameUtil.idnEquals(qname, rr.getName())) { - SRV srv = (SRV) d; - if (!priorities.containsKey(srv.getPriority())) { - priorities.put(srv.getPriority(),new ArrayList()); - } - priorities.get(srv.getPriority()).add(srv); + private static class TlsSrv { + private final SRV srv; + private final boolean tls; + + public TlsSrv(SRV srv, boolean tls) { + this.srv = srv; + this.tls = tls; + } + } + + private static void fillSrvMaps(final String qname, final InetAddress dnsServer, final Map> priorities, final Map> ips4, final Map> ips6, final boolean tls) throws IOException { + final DNSMessage message = client.query(qname, TYPE.SRV, CLASS.IN, dnsServer.getHostAddress()); + for (Record[] rrset : new Record[][] { message.getAnswers(), message.getAdditionalResourceRecords() }) { + for (Record rr : rrset) { + Data d = rr.getPayload(); + if (d instanceof SRV && NameUtil.idnEquals(qname, rr.getName())) { + SRV srv = (SRV) d; + if (!priorities.containsKey(srv.getPriority())) { + priorities.put(srv.getPriority(),new ArrayList()); } - if (d instanceof A) { - A a = (A) d; - if (!ips4.containsKey(rr.getName())) { - ips4.put(rr.getName(), new ArrayList()); - } - ips4.get(rr.getName()).add(a.toString()); + priorities.get(srv.getPriority()).add(new TlsSrv(srv, tls)); + } + if (d instanceof A) { + A a = (A) d; + if (!ips4.containsKey(rr.getName())) { + ips4.put(rr.getName(), new ArrayList()); } - if (d instanceof AAAA) { - AAAA aaaa = (AAAA) d; - if (!ips6.containsKey(rr.getName())) { - ips6.put(rr.getName(), new ArrayList()); - } - ips6.get(rr.getName()).add("[" + aaaa.toString() + "]"); + ips4.get(rr.getName()).add(a.toString()); + } + if (d instanceof AAAA) { + AAAA aaaa = (AAAA) d; + if (!ips6.containsKey(rr.getName())) { + ips6.put(rr.getName(), new ArrayList()); } + ips6.get(rr.getName()).add("[" + aaaa.toString() + "]"); } } + } + } + + public static Bundle queryDNS(String host, InetAddress dnsServer) { + Bundle bundle = new Bundle(); + try { + client.setTimeout(Config.PING_TIMEOUT * 1000); + final String qname = "_xmpp-client._tcp." + host; + final String tlsQname = "_xmpps-client._tcp." + host; + Log.d(Config.LOGTAG, "using dns server: " + dnsServer.getHostAddress() + " to look up " + host); + + final Map> priorities = new TreeMap<>(); + final Map> ips4 = new TreeMap<>(); + final Map> ips6 = new TreeMap<>(); + + fillSrvMaps(qname, dnsServer, priorities, ips4, ips6, false); + fillSrvMaps(tlsQname, dnsServer, priorities, ips4, ips6, true); - ArrayList result = new ArrayList<>(); - for (ArrayList s : priorities.values()) { + final List result = new ArrayList<>(); + for (final List s : priorities.values()) { result.addAll(s); } - ArrayList values = new ArrayList<>(); + final ArrayList values = new ArrayList<>(); if (result.size() == 0) { DNSMessage response; try { response = client.query(host, TYPE.A, CLASS.IN, dnsServer.getHostAddress()); for (int i = 0; i < response.getAnswers().length; ++i) { - values.add(createNamePortBundle(host, 5222, response.getAnswers()[i].getPayload())); + values.add(createNamePortBundle(host, 5222, response.getAnswers()[i].getPayload(), false)); } } catch (SocketTimeoutException e) { Log.d(Config.LOGTAG,"ignoring timeout exception when querying A record on "+dnsServer.getHostAddress()); @@ -154,37 +171,38 @@ public class DNSHelper { try { response = client.query(host, TYPE.AAAA, CLASS.IN, dnsServer.getHostAddress()); for (int i = 0; i < response.getAnswers().length; ++i) { - values.add(createNamePortBundle(host, 5222, response.getAnswers()[i].getPayload())); + values.add(createNamePortBundle(host, 5222, response.getAnswers()[i].getPayload(), false)); } } catch (SocketTimeoutException e) { Log.d(Config.LOGTAG,"ignoring timeout exception when querying AAAA record on "+dnsServer.getHostAddress()); } - values.add(createNamePortBundle(host,5222)); + values.add(createNamePortBundle(host, 5222, false)); bundle.putParcelableArrayList("values", values); return bundle; } - for (SRV srv : result) { + for (final TlsSrv tlsSrv : result) { + final SRV srv = tlsSrv.srv; if (ips6.containsKey(srv.getName())) { - values.add(createNamePortBundle(srv.getName(),srv.getPort(),ips6)); + values.add(createNamePortBundle(srv.getName(),srv.getPort(),ips6, tlsSrv.tls)); } else { try { DNSMessage response = client.query(srv.getName(), TYPE.AAAA, CLASS.IN, dnsServer.getHostAddress()); for (int i = 0; i < response.getAnswers().length; ++i) { - values.add(createNamePortBundle(srv.getName(), srv.getPort(), response.getAnswers()[i].getPayload())); + values.add(createNamePortBundle(srv.getName(), srv.getPort(), response.getAnswers()[i].getPayload(), tlsSrv.tls)); } } catch (SocketTimeoutException e) { Log.d(Config.LOGTAG,"ignoring timeout exception when querying AAAA record on "+dnsServer.getHostAddress()); } } if (ips4.containsKey(srv.getName())) { - values.add(createNamePortBundle(srv.getName(),srv.getPort(),ips4)); + values.add(createNamePortBundle(srv.getName(),srv.getPort(),ips4, tlsSrv.tls)); } else { DNSMessage response = client.query(srv.getName(), TYPE.A, CLASS.IN, dnsServer.getHostAddress()); for(int i = 0; i < response.getAnswers().length; ++i) { - values.add(createNamePortBundle(srv.getName(),srv.getPort(),response.getAnswers()[i].getPayload())); + values.add(createNamePortBundle(srv.getName(),srv.getPort(),response.getAnswers()[i].getPayload(), tlsSrv.tls)); } } - values.add(createNamePortBundle(srv.getName(), srv.getPort())); + values.add(createNamePortBundle(srv.getName(), srv.getPort(), tlsSrv.tls)); } bundle.putParcelableArrayList("values", values); } catch (SocketTimeoutException e) { @@ -195,28 +213,31 @@ public class DNSHelper { return bundle; } - private static Bundle createNamePortBundle(String name, int port) { + private static Bundle createNamePortBundle(String name, int port, final boolean tls) { Bundle namePort = new Bundle(); namePort.putString("name", name); + namePort.putBoolean("tls", tls); namePort.putInt("port", port); return namePort; } - private static Bundle createNamePortBundle(String name, int port, TreeMap> ips) { + private static Bundle createNamePortBundle(String name, int port, Map> ips, final boolean tls) { Bundle namePort = new Bundle(); namePort.putString("name", name); + namePort.putBoolean("tls", tls); namePort.putInt("port", port); if (ips!=null) { - ArrayList ip = ips.get(name); + List ip = ips.get(name); Collections.shuffle(ip, new Random()); namePort.putString("ip", ip.get(0)); } return namePort; } - private static Bundle createNamePortBundle(String name, int port, Data data) { + private static Bundle createNamePortBundle(String name, int port, Data data, final boolean tls) { Bundle namePort = new Bundle(); namePort.putString("name", name); + namePort.putBoolean("tls", tls); namePort.putInt("port", port); if (data instanceof A) { namePort.putString("ip", data.toString()); diff --git a/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java b/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java index 4bc44ce9..be001a7f 100644 --- a/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java +++ b/src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java @@ -20,7 +20,7 @@ import org.xmlpull.v1.XmlPullParserException; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.io.OutputStream; +import java.lang.reflect.Method; import java.math.BigInteger; import java.net.ConnectException; import java.net.IDN; @@ -247,6 +247,7 @@ public class XmppConnection implements Runnable { } Log.d(Config.LOGTAG,account.getJid().toBareJid()+": connect to "+destination+" via TOR"); socket = SocksSocketFactory.createSocketOverTor(destination,account.getPort()); + startXmpp(); } else if (DNSHelper.isIp(account.getServer().toString())) { socket = new Socket(); try { @@ -254,6 +255,7 @@ public class XmppConnection implements Runnable { } catch (IOException e) { throw new UnknownHostException(); } + startXmpp(); } else { final Bundle result = DNSHelper.getSRVRecord(account.getServer(), mXmppConnectionService); final ArrayListvalues = result.getParcelableArrayList("values"); @@ -269,24 +271,46 @@ public class XmppConnection implements Runnable { } final int srvRecordPort = namePort.getInt("port"); final String srvIpServer = namePort.getString("ip"); + // if tls is true, encryption is implied and must not be started + features.encryptionEnabled = namePort.getBoolean("tls"); final InetSocketAddress addr; if (srvIpServer != null) { addr = new InetSocketAddress(srvIpServer, srvRecordPort); Log.d(Config.LOGTAG, account.getJid().toBareJid().toString() + ": using values from dns " + srvRecordServer - + "[" + srvIpServer + "]:" + srvRecordPort); + + "[" + srvIpServer + "]:" + srvRecordPort + " tls: " + features.encryptionEnabled); } else { addr = new InetSocketAddress(srvRecordServer, srvRecordPort); Log.d(Config.LOGTAG, account.getJid().toBareJid().toString() + ": using values from dns " - + srvRecordServer + ":" + srvRecordPort); + + srvRecordServer + ":" + srvRecordPort + " tls: " + features.encryptionEnabled); } - socket = new Socket(); - socket.connect(addr, Config.SOCKET_TIMEOUT * 1000); - tagWriter.setOutputStream(socket.getOutputStream()); - tagReader.setInputStream(socket.getInputStream()); - tagWriter.beginDocument(); - sendStartStream(); + + if (!features.encryptionEnabled) { + socket = new Socket(); + socket.connect(addr, Config.SOCKET_TIMEOUT * 1000); + } else { + final TlsFactoryVerifier tlsFactoryVerifier = getTlsFactoryVerifier(); + socket = tlsFactoryVerifier.factory.createSocket(); + + if (socket == null) { + throw new IOException("could not initialize ssl socket"); + } + + setSSLSocketSecurity((SSLSocket) socket); + this.setSNIHost(tlsFactoryVerifier.factory, (SSLSocket) socket, account.getServer().getDomainpart()); + this.setAlpnProtocol(tlsFactoryVerifier.factory, (SSLSocket) socket, "xmpp-client"); + + socket.connect(addr, Config.SOCKET_TIMEOUT * 1000); + + if (!tlsFactoryVerifier.verifier.verify(account.getServer().getDomainpart(), ((SSLSocket) socket).getSession())) { + Log.d(Config.LOGTAG, account.getJid().toBareJid() + ": TLS certificate verification failed"); + throw new SecurityException(); + } + } + + if(startXmpp()) + break; // successfully connected to server that speaks xmpp } catch (final Throwable e) { Log.d(Config.LOGTAG, account.getJid().toBareJid().toString() + ": " + e.getMessage() +"("+e.getClass().getName()+")"); if (!iterator.hasNext()) { @@ -295,18 +319,7 @@ public class XmppConnection implements Runnable { } } } - Tag nextTag; - while ((nextTag = tagReader.readTag()) != null) { - if (nextTag.isStart("stream")) { - processStream(); - break; - } else { - throw new IOException("unknown tag on connect"); - } - } - if (socket.isConnected()) { - socket.close(); - } + processStream(); } catch (final IncompatibleServerException e) { this.changeStatus(Account.State.INCOMPATIBLE_SERVER); } catch (final SecurityException e) { @@ -338,6 +351,99 @@ public class XmppConnection implements Runnable { } } + /** + * Starts xmpp protocol, call after connecting to socket + * @return true if server returns with valid xmpp, false otherwise + * @throws IOException Unknown tag on connect + * @throws XmlPullParserException Bad Xml + * @throws NoSuchAlgorithmException Other error + */ + private boolean startXmpp() throws IOException, XmlPullParserException, NoSuchAlgorithmException { + tagWriter.setOutputStream(socket.getOutputStream()); + tagReader.setInputStream(socket.getInputStream()); + tagWriter.beginDocument(); + sendStartStream(); + Tag nextTag; + while ((nextTag = tagReader.readTag()) != null) { + if (nextTag.isStart("stream")) { + return true; + } else { + throw new IOException("unknown tag on connect"); + } + } + if (socket.isConnected()) { + socket.close(); + } + return false; + } + + private void setSNIHost(final SSLSocketFactory factory, final SSLSocket socket, final String hostname) { + if (factory instanceof android.net.SSLCertificateSocketFactory && android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.JELLY_BEAN_MR1) { + ((android.net.SSLCertificateSocketFactory) factory).setHostname(socket, hostname); + } else { + try { + socket.getClass().getMethod("setHostname", String.class).invoke(socket, hostname); + } catch (Throwable e) { + // ignore any error, we just can't set the hostname... + } + } + } + + private void setAlpnProtocol(final SSLSocketFactory factory, final SSLSocket socket, final String protocol) { + try { + if (factory instanceof android.net.SSLCertificateSocketFactory && android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.KITKAT) { + // can't call directly because of @hide? + //((android.net.SSLCertificateSocketFactory)factory).setAlpnProtocols(new byte[][]{protocol.getBytes("UTF-8")}); + android.net.SSLCertificateSocketFactory.class.getMethod("setAlpnProtocols", byte[][].class).invoke(socket, new Object[]{new byte[][]{protocol.getBytes("UTF-8")}}); + } else { + final Method method = socket.getClass().getMethod("setAlpnProtocols", byte[].class); + // the concatenation of 8-bit, length prefixed protocol names, just one in our case... + // http://tools.ietf.org/html/draft-agl-tls-nextprotoneg-04#page-4 + final byte[] protocolUTF8Bytes = protocol.getBytes("UTF-8"); + final byte[] lengthPrefixedProtocols = new byte[protocolUTF8Bytes.length + 1]; + lengthPrefixedProtocols[0] = (byte) protocol.length(); // cannot be over 255 anyhow + System.arraycopy(protocolUTF8Bytes, 0, lengthPrefixedProtocols, 1, protocolUTF8Bytes.length); + method.invoke(socket, new Object[]{lengthPrefixedProtocols}); + } + } catch (Throwable e) { + // ignore any error, we just can't set the alpn protocol... + } + } + + private static class TlsFactoryVerifier { + private final SSLSocketFactory factory; + private final HostnameVerifier verifier; + + public TlsFactoryVerifier(final SSLSocketFactory factory, final HostnameVerifier verifier) throws IOException { + this.factory = factory; + this.verifier = verifier; + if (factory == null || verifier == null) { + throw new IOException("could not setup ssl"); + } + } + } + + private TlsFactoryVerifier getTlsFactoryVerifier() throws NoSuchAlgorithmException, KeyManagementException, IOException { + final SSLContext sc = SSLContext.getInstance("TLS"); + MemorizingTrustManager trustManager = this.mXmppConnectionService.getMemorizingTrustManager(); + KeyManager[] keyManager; + if (account.getPrivateKeyAlias() != null && account.getPassword().isEmpty()) { + keyManager = new KeyManager[]{mKeyManager}; + } else { + keyManager = null; + } + sc.init(keyManager, new X509TrustManager[]{mInteractive ? trustManager : trustManager.getNonInteractive()}, mXmppConnectionService.getRNG()); + final SSLSocketFactory factory = sc.getSocketFactory(); + final HostnameVerifier verifier; + if (mInteractive) { + verifier = trustManager.wrapHostnameVerifier(new XmppDomainVerifier()); + } else { + verifier = trustManager.wrapHostnameVerifierNonInteractive(new XmppDomainVerifier()); + } + + return new TlsFactoryVerifier(factory, verifier); + } + @Override public void run() { try { @@ -599,53 +705,42 @@ public class XmppConnection implements Runnable { tagWriter.writeTag(startTLS); } + private void setSSLSocketSecurity(final SSLSocket sslSocket) throws NoSuchAlgorithmException { + final String[] supportProtocols; + final Collection supportedProtocols = new LinkedList<>( + Arrays.asList(sslSocket.getSupportedProtocols())); + supportedProtocols.remove("SSLv3"); + supportProtocols = supportedProtocols.toArray(new String[supportedProtocols.size()]); + + sslSocket.setEnabledProtocols(supportProtocols); + + final String[] cipherSuites = CryptoHelper.getOrderedCipherSuites( + sslSocket.getSupportedCipherSuites()); + //Log.d(Config.LOGTAG, "Using ciphers: " + Arrays.toString(cipherSuites)); + if (cipherSuites.length > 0) { + sslSocket.setEnabledCipherSuites(cipherSuites); + } + } + private void switchOverToTls(final Tag currentTag) throws XmlPullParserException, IOException { tagReader.readTag(); try { - final SSLContext sc = SSLContext.getInstance("TLS"); - MemorizingTrustManager trustManager = this.mXmppConnectionService.getMemorizingTrustManager(); - KeyManager[] keyManager; - if (account.getPrivateKeyAlias() != null && account.getPassword().isEmpty()) { - keyManager = new KeyManager[]{ mKeyManager }; - } else { - keyManager = null; - } - sc.init(keyManager,new X509TrustManager[]{mInteractive ? trustManager : trustManager.getNonInteractive()},mXmppConnectionService.getRNG()); - final SSLSocketFactory factory = sc.getSocketFactory(); - final HostnameVerifier verifier; - if (mInteractive) { - verifier = trustManager.wrapHostnameVerifier(new XmppDomainVerifier()); - } else { - verifier = trustManager.wrapHostnameVerifierNonInteractive(new XmppDomainVerifier()); - } + final TlsFactoryVerifier tlsFactoryVerifier = getTlsFactoryVerifier(); final InetAddress address = socket == null ? null : socket.getInetAddress(); - if (factory == null || address == null || verifier == null) { + if (address == null) { throw new IOException("could not setup ssl"); } - final SSLSocket sslSocket = (SSLSocket) factory.createSocket(socket,address.getHostAddress(), socket.getPort(),true); + final SSLSocket sslSocket = (SSLSocket) tlsFactoryVerifier.factory.createSocket(socket, address.getHostAddress(), socket.getPort(), true); if (sslSocket == null) { throw new IOException("could not initialize ssl socket"); } - final String[] supportProtocols; - final Collection supportedProtocols = new LinkedList<>( - Arrays.asList(sslSocket.getSupportedProtocols())); - supportedProtocols.remove("SSLv3"); - supportProtocols = supportedProtocols.toArray(new String[supportedProtocols.size()]); - - sslSocket.setEnabledProtocols(supportProtocols); - - final String[] cipherSuites = CryptoHelper.getOrderedCipherSuites( - sslSocket.getSupportedCipherSuites()); - //Log.d(Config.LOGTAG, "Using ciphers: " + Arrays.toString(cipherSuites)); - if (cipherSuites.length > 0) { - sslSocket.setEnabledCipherSuites(cipherSuites); - } + setSSLSocketSecurity(sslSocket); - if (!verifier.verify(account.getServer().getDomainpart(),sslSocket.getSession())) { + if (!tlsFactoryVerifier.verifier.verify(account.getServer().getDomainpart(), sslSocket.getSession())) { Log.d(Config.LOGTAG,account.getJid().toBareJid()+": TLS certificate verification failed"); throw new SecurityException(); } -- cgit v1.2.3