aboutsummaryrefslogtreecommitdiffstats
path: root/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/org/whispersystems/libaxolotl/SessionCipher.java')
-rw-r--r--src/main/java/org/whispersystems/libaxolotl/SessionCipher.java469
1 files changed, 469 insertions, 0 deletions
diff --git a/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java b/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java
new file mode 100644
index 00000000..381dedb8
--- /dev/null
+++ b/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java
@@ -0,0 +1,469 @@
+/**
+ * Copyright (C) 2013 Open Whisper Systems
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.whispersystems.libaxolotl;
+
+import org.whispersystems.libaxolotl.ecc.Curve;
+import org.whispersystems.libaxolotl.ecc.ECKeyPair;
+import org.whispersystems.libaxolotl.ecc.ECPublicKey;
+import org.whispersystems.libaxolotl.protocol.CiphertextMessage;
+import org.whispersystems.libaxolotl.protocol.PreKeyWhisperMessage;
+import org.whispersystems.libaxolotl.protocol.WhisperMessage;
+import org.whispersystems.libaxolotl.ratchet.ChainKey;
+import org.whispersystems.libaxolotl.ratchet.MessageKeys;
+import org.whispersystems.libaxolotl.ratchet.RootKey;
+import org.whispersystems.libaxolotl.state.AxolotlStore;
+import org.whispersystems.libaxolotl.state.IdentityKeyStore;
+import org.whispersystems.libaxolotl.state.PreKeyStore;
+import org.whispersystems.libaxolotl.state.SessionRecord;
+import org.whispersystems.libaxolotl.state.SessionState;
+import org.whispersystems.libaxolotl.state.SessionStore;
+import org.whispersystems.libaxolotl.state.SignedPreKeyStore;
+import org.whispersystems.libaxolotl.util.ByteUtil;
+import org.whispersystems.libaxolotl.util.Pair;
+import org.whispersystems.libaxolotl.util.guava.Optional;
+
+import java.security.InvalidAlgorithmParameterException;
+import java.security.NoSuchAlgorithmException;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import javax.crypto.BadPaddingException;
+import javax.crypto.Cipher;
+import javax.crypto.IllegalBlockSizeException;
+import javax.crypto.NoSuchPaddingException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+
+import static org.whispersystems.libaxolotl.state.SessionState.UnacknowledgedPreKeyMessageItems;
+
+/**
+ * The main entry point for Axolotl encrypt/decrypt operations.
+ *
+ * Once a session has been established with {@link SessionBuilder},
+ * this class can be used for all encrypt/decrypt operations within
+ * that session.
+ *
+ * @author Moxie Marlinspike
+ */
+public class SessionCipher {
+
+ public static final Object SESSION_LOCK = new Object();
+
+ private final SessionStore sessionStore;
+ private final SessionBuilder sessionBuilder;
+ private final PreKeyStore preKeyStore;
+ private final long recipientId;
+ private final int deviceId;
+
+ /**
+ * Construct a SessionCipher for encrypt/decrypt operations on a session.
+ * In order to use SessionCipher, a session must have already been created
+ * and stored using {@link SessionBuilder}.
+ *
+ * @param sessionStore The {@link SessionStore} that contains a session for this recipient.
+ * @param recipientId The remote ID that messages will be encrypted to or decrypted from.
+ * @param deviceId The device corresponding to the recipientId.
+ */
+ public SessionCipher(SessionStore sessionStore, PreKeyStore preKeyStore,
+ SignedPreKeyStore signedPreKeyStore, IdentityKeyStore identityKeyStore,
+ long recipientId, int deviceId)
+ {
+ this.sessionStore = sessionStore;
+ this.recipientId = recipientId;
+ this.deviceId = deviceId;
+ this.preKeyStore = preKeyStore;
+ this.sessionBuilder = new SessionBuilder(sessionStore, preKeyStore, signedPreKeyStore,
+ identityKeyStore, recipientId, deviceId);
+ }
+
+ public SessionCipher(AxolotlStore store, long recipientId, int deviceId) {
+ this(store, store, store, store, recipientId, deviceId);
+ }
+
+ /**
+ * Encrypt a message.
+ *
+ * @param paddedMessage The plaintext message bytes, optionally padded to a constant multiple.
+ * @return A ciphertext message encrypted to the recipient+device tuple.
+ */
+ public CiphertextMessage encrypt(byte[] paddedMessage) {
+ synchronized (SESSION_LOCK) {
+ SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
+ SessionState sessionState = sessionRecord.getSessionState();
+ ChainKey chainKey = sessionState.getSenderChainKey();
+ MessageKeys messageKeys = chainKey.getMessageKeys();
+ ECPublicKey senderEphemeral = sessionState.getSenderRatchetKey();
+ int previousCounter = sessionState.getPreviousCounter();
+ int sessionVersion = sessionState.getSessionVersion();
+
+ byte[] ciphertextBody = getCiphertext(sessionVersion, messageKeys, paddedMessage);
+ CiphertextMessage ciphertextMessage = new WhisperMessage(sessionVersion, messageKeys.getMacKey(),
+ senderEphemeral, chainKey.getIndex(),
+ previousCounter, ciphertextBody,
+ sessionState.getLocalIdentityKey(),
+ sessionState.getRemoteIdentityKey());
+
+ if (sessionState.hasUnacknowledgedPreKeyMessage()) {
+ UnacknowledgedPreKeyMessageItems items = sessionState.getUnacknowledgedPreKeyMessageItems();
+ int localRegistrationId = sessionState.getLocalRegistrationId();
+
+ ciphertextMessage = new PreKeyWhisperMessage(sessionVersion, localRegistrationId, items.getPreKeyId(),
+ items.getSignedPreKeyId(), items.getBaseKey(),
+ sessionState.getLocalIdentityKey(),
+ (WhisperMessage) ciphertextMessage);
+ }
+
+ sessionState.setSenderChainKey(chainKey.getNextChainKey());
+ sessionStore.storeSession(recipientId, deviceId, sessionRecord);
+ return ciphertextMessage;
+ }
+ }
+
+ /**
+ * Decrypt a message.
+ *
+ * @param ciphertext The {@link PreKeyWhisperMessage} to decrypt.
+ *
+ * @return The plaintext.
+ * @throws InvalidMessageException if the input is not valid ciphertext.
+ * @throws DuplicateMessageException if the input is a message that has already been received.
+ * @throws LegacyMessageException if the input is a message formatted by a protocol version that
+ * is no longer supported.
+ * @throws InvalidKeyIdException when there is no local {@link org.whispersystems.libaxolotl.state.PreKeyRecord}
+ * that corresponds to the PreKey ID in the message.
+ * @throws InvalidKeyException when the message is formatted incorrectly.
+ * @throws UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted.
+ */
+ public byte[] decrypt(PreKeyWhisperMessage ciphertext)
+ throws DuplicateMessageException, LegacyMessageException, InvalidMessageException,
+ InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException
+ {
+ return decrypt(ciphertext, new NullDecryptionCallback());
+ }
+
+ /**
+ * Decrypt a message.
+ *
+ * @param ciphertext The {@link PreKeyWhisperMessage} to decrypt.
+ * @param callback A callback that is triggered after decryption is complete,
+ * but before the updated session state has been committed to the session
+ * DB. This allows some implementations to store the committed plaintext
+ * to a DB first, in case they are concerned with a crash happening between
+ * the time the session state is updated but before they're able to store
+ * the plaintext to disk.
+ *
+ * @return The plaintext.
+ * @throws InvalidMessageException if the input is not valid ciphertext.
+ * @throws DuplicateMessageException if the input is a message that has already been received.
+ * @throws LegacyMessageException if the input is a message formatted by a protocol version that
+ * is no longer supported.
+ * @throws InvalidKeyIdException when there is no local {@link org.whispersystems.libaxolotl.state.PreKeyRecord}
+ * that corresponds to the PreKey ID in the message.
+ * @throws InvalidKeyException when the message is formatted incorrectly.
+ * @throws UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted.
+ */
+ public byte[] decrypt(PreKeyWhisperMessage ciphertext, DecryptionCallback callback)
+ throws DuplicateMessageException, LegacyMessageException, InvalidMessageException,
+ InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException
+ {
+ synchronized (SESSION_LOCK) {
+ SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
+ Optional<Integer> unsignedPreKeyId = sessionBuilder.process(sessionRecord, ciphertext);
+ byte[] plaintext = decrypt(sessionRecord, ciphertext.getWhisperMessage());
+
+ callback.handlePlaintext(plaintext);
+
+ sessionStore.storeSession(recipientId, deviceId, sessionRecord);
+
+ if (unsignedPreKeyId.isPresent()) {
+ preKeyStore.removePreKey(unsignedPreKeyId.get());
+ }
+
+ return plaintext;
+ }
+ }
+
+ /**
+ * Decrypt a message.
+ *
+ * @param ciphertext The {@link WhisperMessage} to decrypt.
+ *
+ * @return The plaintext.
+ * @throws InvalidMessageException if the input is not valid ciphertext.
+ * @throws DuplicateMessageException if the input is a message that has already been received.
+ * @throws LegacyMessageException if the input is a message formatted by a protocol version that
+ * is no longer supported.
+ * @throws NoSessionException if there is no established session for this contact.
+ */
+ public byte[] decrypt(WhisperMessage ciphertext)
+ throws InvalidMessageException, DuplicateMessageException, LegacyMessageException,
+ NoSessionException
+ {
+ return decrypt(ciphertext, new NullDecryptionCallback());
+ }
+
+ /**
+ * Decrypt a message.
+ *
+ * @param ciphertext The {@link WhisperMessage} to decrypt.
+ * @param callback A callback that is triggered after decryption is complete,
+ * but before the updated session state has been committed to the session
+ * DB. This allows some implementations to store the committed plaintext
+ * to a DB first, in case they are concerned with a crash happening between
+ * the time the session state is updated but before they're able to store
+ * the plaintext to disk.
+ *
+ * @return The plaintext.
+ * @throws InvalidMessageException if the input is not valid ciphertext.
+ * @throws DuplicateMessageException if the input is a message that has already been received.
+ * @throws LegacyMessageException if the input is a message formatted by a protocol version that
+ * is no longer supported.
+ * @throws NoSessionException if there is no established session for this contact.
+ */
+ public byte[] decrypt(WhisperMessage ciphertext, DecryptionCallback callback)
+ throws InvalidMessageException, DuplicateMessageException, LegacyMessageException,
+ NoSessionException
+ {
+ synchronized (SESSION_LOCK) {
+
+ if (!sessionStore.containsSession(recipientId, deviceId)) {
+ throw new NoSessionException("No session for: " + recipientId + ", " + deviceId);
+ }
+
+ SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId);
+ byte[] plaintext = decrypt(sessionRecord, ciphertext);
+
+ callback.handlePlaintext(plaintext);
+
+ sessionStore.storeSession(recipientId, deviceId, sessionRecord);
+
+ return plaintext;
+ }
+ }
+
+ private byte[] decrypt(SessionRecord sessionRecord, WhisperMessage ciphertext)
+ throws DuplicateMessageException, LegacyMessageException, InvalidMessageException
+ {
+ synchronized (SESSION_LOCK) {
+ Iterator<SessionState> previousStates = sessionRecord.getPreviousSessionStates().iterator();
+ List<Exception> exceptions = new LinkedList<>();
+
+ try {
+ SessionState sessionState = new SessionState(sessionRecord.getSessionState());
+ byte[] plaintext = decrypt(sessionState, ciphertext);
+
+ sessionRecord.setState(sessionState);
+ return plaintext;
+ } catch (InvalidMessageException e) {
+ exceptions.add(e);
+ }
+
+ while (previousStates.hasNext()) {
+ try {
+ SessionState promotedState = new SessionState(previousStates.next());
+ byte[] plaintext = decrypt(promotedState, ciphertext);
+
+ previousStates.remove();
+ sessionRecord.promoteState(promotedState);
+
+ return plaintext;
+ } catch (InvalidMessageException e) {
+ exceptions.add(e);
+ }
+ }
+
+ throw new InvalidMessageException("No valid sessions.", exceptions);
+ }
+ }
+
+ private byte[] decrypt(SessionState sessionState, WhisperMessage ciphertextMessage)
+ throws InvalidMessageException, DuplicateMessageException, LegacyMessageException
+ {
+ if (!sessionState.hasSenderChain()) {
+ throw new InvalidMessageException("Uninitialized session!");
+ }
+
+ if (ciphertextMessage.getMessageVersion() != sessionState.getSessionVersion()) {
+ throw new InvalidMessageException(String.format("Message version %d, but session version %d",
+ ciphertextMessage.getMessageVersion(),
+ sessionState.getSessionVersion()));
+ }
+
+ int messageVersion = ciphertextMessage.getMessageVersion();
+ ECPublicKey theirEphemeral = ciphertextMessage.getSenderRatchetKey();
+ int counter = ciphertextMessage.getCounter();
+ ChainKey chainKey = getOrCreateChainKey(sessionState, theirEphemeral);
+ MessageKeys messageKeys = getOrCreateMessageKeys(sessionState, theirEphemeral,
+ chainKey, counter);
+
+ ciphertextMessage.verifyMac(messageVersion,
+ sessionState.getRemoteIdentityKey(),
+ sessionState.getLocalIdentityKey(),
+ messageKeys.getMacKey());
+
+ byte[] plaintext = getPlaintext(messageVersion, messageKeys, ciphertextMessage.getBody());
+
+ sessionState.clearUnacknowledgedPreKeyMessage();
+
+ return plaintext;
+ }
+
+ public int getRemoteRegistrationId() {
+ synchronized (SESSION_LOCK) {
+ SessionRecord record = sessionStore.loadSession(recipientId, deviceId);
+ return record.getSessionState().getRemoteRegistrationId();
+ }
+ }
+
+ public int getSessionVersion() {
+ synchronized (SESSION_LOCK) {
+ if (!sessionStore.containsSession(recipientId, deviceId)) {
+ throw new IllegalStateException(String.format("No session for (%d, %d)!", recipientId, deviceId));
+ }
+
+ SessionRecord record = sessionStore.loadSession(recipientId, deviceId);
+ return record.getSessionState().getSessionVersion();
+ }
+ }
+
+ private ChainKey getOrCreateChainKey(SessionState sessionState, ECPublicKey theirEphemeral)
+ throws InvalidMessageException
+ {
+ try {
+ if (sessionState.hasReceiverChain(theirEphemeral)) {
+ return sessionState.getReceiverChainKey(theirEphemeral);
+ } else {
+ RootKey rootKey = sessionState.getRootKey();
+ ECKeyPair ourEphemeral = sessionState.getSenderRatchetKeyPair();
+ Pair<RootKey, ChainKey> receiverChain = rootKey.createChain(theirEphemeral, ourEphemeral);
+ ECKeyPair ourNewEphemeral = Curve.generateKeyPair();
+ Pair<RootKey, ChainKey> senderChain = receiverChain.first().createChain(theirEphemeral, ourNewEphemeral);
+
+ sessionState.setRootKey(senderChain.first());
+ sessionState.addReceiverChain(theirEphemeral, receiverChain.second());
+ sessionState.setPreviousCounter(Math.max(sessionState.getSenderChainKey().getIndex()-1, 0));
+ sessionState.setSenderChain(ourNewEphemeral, senderChain.second());
+
+ return receiverChain.second();
+ }
+ } catch (InvalidKeyException e) {
+ throw new InvalidMessageException(e);
+ }
+ }
+
+ private MessageKeys getOrCreateMessageKeys(SessionState sessionState,
+ ECPublicKey theirEphemeral,
+ ChainKey chainKey, int counter)
+ throws InvalidMessageException, DuplicateMessageException
+ {
+ if (chainKey.getIndex() > counter) {
+ if (sessionState.hasMessageKeys(theirEphemeral, counter)) {
+ return sessionState.removeMessageKeys(theirEphemeral, counter);
+ } else {
+ throw new DuplicateMessageException("Received message with old counter: " +
+ chainKey.getIndex() + " , " + counter);
+ }
+ }
+
+ if (counter - chainKey.getIndex() > 2000) {
+ throw new InvalidMessageException("Over 2000 messages into the future!");
+ }
+
+ while (chainKey.getIndex() < counter) {
+ MessageKeys messageKeys = chainKey.getMessageKeys();
+ sessionState.setMessageKeys(theirEphemeral, messageKeys);
+ chainKey = chainKey.getNextChainKey();
+ }
+
+ sessionState.setReceiverChainKey(theirEphemeral, chainKey.getNextChainKey());
+ return chainKey.getMessageKeys();
+ }
+
+ private byte[] getCiphertext(int version, MessageKeys messageKeys, byte[] plaintext) {
+ try {
+ Cipher cipher;
+
+ if (version >= 3) {
+ cipher = getCipher(Cipher.ENCRYPT_MODE, messageKeys.getCipherKey(), messageKeys.getIv());
+ } else {
+ cipher = getCipher(Cipher.ENCRYPT_MODE, messageKeys.getCipherKey(), messageKeys.getCounter());
+ }
+
+ return cipher.doFinal(plaintext);
+ } catch (IllegalBlockSizeException | BadPaddingException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ private byte[] getPlaintext(int version, MessageKeys messageKeys, byte[] cipherText)
+ throws InvalidMessageException
+ {
+ try {
+ Cipher cipher;
+
+ if (version >= 3) {
+ cipher = getCipher(Cipher.DECRYPT_MODE, messageKeys.getCipherKey(), messageKeys.getIv());
+ } else {
+ cipher = getCipher(Cipher.DECRYPT_MODE, messageKeys.getCipherKey(), messageKeys.getCounter());
+ }
+
+ return cipher.doFinal(cipherText);
+ } catch (IllegalBlockSizeException | BadPaddingException e) {
+ throw new InvalidMessageException(e);
+ }
+ }
+
+ private Cipher getCipher(int mode, SecretKeySpec key, int counter) {
+ try {
+ Cipher cipher = Cipher.getInstance("AES/CTR/NoPadding");
+
+ byte[] ivBytes = new byte[16];
+ ByteUtil.intToByteArray(ivBytes, 0, counter);
+
+ IvParameterSpec iv = new IvParameterSpec(ivBytes);
+ cipher.init(mode, key, iv);
+
+ return cipher;
+ } catch (NoSuchAlgorithmException | NoSuchPaddingException | java.security.InvalidKeyException |
+ InvalidAlgorithmParameterException e)
+ {
+ throw new AssertionError(e);
+ }
+ }
+
+ private Cipher getCipher(int mode, SecretKeySpec key, IvParameterSpec iv) {
+ try {
+ Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
+ cipher.init(mode, key, iv);
+ return cipher;
+ } catch (NoSuchAlgorithmException | NoSuchPaddingException | java.security.InvalidKeyException |
+ InvalidAlgorithmParameterException e)
+ {
+ throw new AssertionError(e);
+ }
+ }
+
+ public static interface DecryptionCallback {
+ public void handlePlaintext(byte[] plaintext);
+ }
+
+ private static class NullDecryptionCallback implements DecryptionCallback {
+ @Override
+ public void handlePlaintext(byte[] plaintext) {}
+ }
+} \ No newline at end of file