From 692ee6c9fbd03f38aa8ab969077b77cc77357854 Mon Sep 17 00:00:00 2001
From: Daniel Gultsch <daniel@gultsch.de>
Date: Wed, 30 Dec 2020 15:57:42 +0100
Subject: [PATCH] SCRAM remove cache. made digest and hmac non static

DIGEST and HMAC were static variables. Those are initialized by
what ever concrete implementation gets executed first.

(Perform SCRAM-SHA1 first and those variables got initialized with
SHA1 variants)

For subsequent SHA256 executions those variables contained wrong
values.
---
 .../crypto/sasl/ScramMechanism.java           | 72 ++++++++-----------
 .../conversations/crypto/sasl/ScramSha1.java  | 13 +++-
 .../crypto/sasl/ScramSha256.java              | 13 +++-
 3 files changed, 48 insertions(+), 50 deletions(-)

diff --git a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java
index 4d40d2b74..dbc73c78f 100644
--- a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java
+++ b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java
@@ -3,13 +3,11 @@ package eu.siacs.conversations.crypto.sasl;
 import android.annotation.TargetApi;
 import android.os.Build;
 import android.util.Base64;
-import android.util.LruCache;
 
 import org.bouncycastle.crypto.Digest;
 import org.bouncycastle.crypto.macs.HMac;
 import org.bouncycastle.crypto.params.KeyParameter;
 
-import java.math.BigInteger;
 import java.nio.charset.Charset;
 import java.security.InvalidKeyException;
 import java.security.SecureRandom;
@@ -24,30 +22,22 @@ abstract class ScramMechanism extends SaslMechanism {
     private final static String GS2_HEADER = "n,,";
     private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
     private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
-    private static final LruCache<String, KeyPair> CACHE;
-    static HMac HMAC;
-    static Digest DIGEST;
 
-    static {
-        CACHE = new LruCache<String, KeyPair>(10) {
-            protected KeyPair create(final String k) {
-                // Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations,SASL-Mechanism".
-                // Changing any of these values forces a cache miss. `CryptoHelper.bytesToHex()'
-                // is applied to prevent commas in the strings breaking things.
-                final String[] kParts = k.split(",", 5);
-                try {
-                    final byte[] saltedPassword, serverKey, clientKey;
-                    saltedPassword = hi(CryptoHelper.hexToString(kParts[1]).getBytes(),
-                            Base64.decode(CryptoHelper.hexToString(kParts[2]), Base64.DEFAULT), Integer.parseInt(kParts[3]));
-                    serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
-                    clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
+    protected abstract HMac getHMAC();
 
-                    return new KeyPair(clientKey, serverKey);
-                } catch (final InvalidKeyException | NumberFormatException e) {
-                    return null;
-                }
-            }
-        };
+    protected abstract Digest getDigest();
+
+    private KeyPair getKeyPair(final String password, final String salt, final int iterations) {
+        try {
+            final byte[] saltedPassword, serverKey, clientKey;
+            saltedPassword = hi(password.getBytes(), Base64.decode(salt, Base64.DEFAULT), iterations);
+            serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
+            clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
+
+            return new KeyPair(clientKey, serverKey);
+        } catch (final InvalidKeyException | NumberFormatException e) {
+            return null;
+        }
     }
 
     private final String clientNonce;
@@ -63,20 +53,21 @@ abstract class ScramMechanism extends SaslMechanism {
         clientFirstMessageBare = "";
     }
 
-    private static synchronized byte[] hmac(final byte[] key, final byte[] input)
-            throws InvalidKeyException {
-        HMAC.init(new KeyParameter(key));
-        HMAC.update(input, 0, input.length);
-        final byte[] out = new byte[HMAC.getMacSize()];
-        HMAC.doFinal(out, 0);
+    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
+        final HMac hMac = getHMAC();
+        hMac.init(new KeyParameter(key));
+        hMac.update(input, 0, input.length);
+        final byte[] out = new byte[hMac.getMacSize()];
+        hMac.doFinal(out, 0);
         return out;
     }
 
-    public static synchronized byte[] digest(byte[] bytes) {
-        DIGEST.reset();
-        DIGEST.update(bytes, 0, bytes.length);
-        final byte[] out = new byte[DIGEST.getDigestSize()];
-        DIGEST.doFinal(out, 0);
+    public byte[] digest(byte[] bytes) {
+        final Digest digest = getDigest();
+        digest.reset();
+        digest.update(bytes, 0, bytes.length);
+        final byte[] out = new byte[digest.getDigestSize()];
+        digest.doFinal(out, 0);
         return out;
     }
 
@@ -85,7 +76,7 @@ abstract class ScramMechanism extends SaslMechanism {
      * pseudorandom function (PRF) and with dkLen == output length of
      * HMAC() == output length of H().
      */
-    private static synchronized byte[] hi(final byte[] key, final byte[] salt, final int iterations)
+    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
             throws InvalidKeyException {
         byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
         byte[] out = u.clone();
@@ -171,14 +162,7 @@ abstract class ScramMechanism extends SaslMechanism {
                 final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
                         + clientFinalMessageWithoutProof).getBytes();
 
-                // Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations,SASL-Mechanism".
-                final KeyPair keys = CACHE.get(
-                        CryptoHelper.bytesToHex(CryptoHelper.saslPrep(account.getJid().asBareJid().toEscapedString()).getBytes()) + ","
-                                + CryptoHelper.bytesToHex(CryptoHelper.saslPrep(account.getPassword()).getBytes()) + ","
-                                + CryptoHelper.bytesToHex(salt.getBytes()) + ","
-                                + iterationCount + ","
-                                + getMechanism()
-                );
+                final KeyPair keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
                 if (keys == null) {
                     throw new AuthenticationException("Invalid keys generated");
                 }
diff --git a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java
index 13593778d..5558d6a43 100644
--- a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java
+++ b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java
@@ -1,5 +1,6 @@
 package eu.siacs.conversations.crypto.sasl;
 
+import org.bouncycastle.crypto.Digest;
 import org.bouncycastle.crypto.digests.SHA1Digest;
 import org.bouncycastle.crypto.macs.HMac;
 
@@ -9,9 +10,15 @@ import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.xml.TagWriter;
 
 public class ScramSha1 extends ScramMechanism {
-	static {
-		DIGEST = new SHA1Digest();
-		HMAC = new HMac(new SHA1Digest());
+
+	@Override
+	protected HMac getHMAC() {
+		return  new HMac(new SHA1Digest());
+	}
+
+	@Override
+	protected Digest getDigest() {
+		return new SHA1Digest();
 	}
 
 	public ScramSha1(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
diff --git a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha256.java b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha256.java
index 1b7a969d9..866c1ea79 100644
--- a/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha256.java
+++ b/src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha256.java
@@ -1,5 +1,6 @@
 package eu.siacs.conversations.crypto.sasl;
 
+import org.bouncycastle.crypto.Digest;
 import org.bouncycastle.crypto.digests.SHA256Digest;
 import org.bouncycastle.crypto.macs.HMac;
 
@@ -9,9 +10,15 @@ import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.xml.TagWriter;
 
 public class ScramSha256 extends ScramMechanism {
-	static {
-		DIGEST = new SHA256Digest();
-		HMAC = new HMac(new SHA256Digest());
+
+	@Override
+	protected HMac getHMAC() {
+		return new HMac(new SHA256Digest());
+	}
+
+	@Override
+	protected Digest getDigest() {
+		return new SHA256Digest();
 	}
 
 	public ScramSha256(final TagWriter tagWriter, final Account account, final SecureRandom rng) {