diff options
author | Moxie Marlinspike <moxie@thoughtcrime.org> | 2014-11-24 12:54:30 -0800 |
---|---|---|
committer | Moxie Marlinspike <moxie@thoughtcrime.org> | 2014-11-24 12:54:30 -0800 |
commit | 60800e155612bea797eed93c67046a23d26054cc (patch) | |
tree | d88368c1c26162e27e790195133ca2b526597afe /src/main/java/org/whispersystems/libaxolotl/state/SessionState.java |
Break out into separate repo.
Diffstat (limited to 'src/main/java/org/whispersystems/libaxolotl/state/SessionState.java')
-rw-r--r-- | src/main/java/org/whispersystems/libaxolotl/state/SessionState.java | 509 |
1 files changed, 509 insertions, 0 deletions
diff --git a/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java b/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java new file mode 100644 index 00000000..9b2b1e2a --- /dev/null +++ b/src/main/java/org/whispersystems/libaxolotl/state/SessionState.java @@ -0,0 +1,509 @@ +/** + * Copyright (C) 2014 Open Whisper Systems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package org.whispersystems.libaxolotl.state; + +import android.util.Log; + +import com.google.protobuf.ByteString; + +import org.whispersystems.libaxolotl.IdentityKey; +import org.whispersystems.libaxolotl.IdentityKeyPair; +import org.whispersystems.libaxolotl.InvalidKeyException; +import org.whispersystems.libaxolotl.ecc.Curve; +import org.whispersystems.libaxolotl.ecc.ECKeyPair; +import org.whispersystems.libaxolotl.ecc.ECPrivateKey; +import org.whispersystems.libaxolotl.ecc.ECPublicKey; +import org.whispersystems.libaxolotl.kdf.HKDF; +import org.whispersystems.libaxolotl.ratchet.ChainKey; +import org.whispersystems.libaxolotl.ratchet.MessageKeys; +import org.whispersystems.libaxolotl.ratchet.RootKey; +import org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure.Chain; +import org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure.PendingKeyExchange; +import org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure.PendingPreKey; +import org.whispersystems.libaxolotl.util.Pair; +import org.whispersystems.libaxolotl.util.guava.Optional; + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.SecretKeySpec; + +import static org.whispersystems.libaxolotl.state.StorageProtos.SessionStructure; + +public class SessionState { + + private SessionStructure sessionStructure; + + public SessionState() { + this.sessionStructure = SessionStructure.newBuilder().build(); + } + + public SessionState(SessionStructure sessionStructure) { + this.sessionStructure = sessionStructure; + } + + public SessionState(SessionState copy) { + this.sessionStructure = copy.sessionStructure.toBuilder().build(); + } + + public SessionStructure getStructure() { + return sessionStructure; + } + + public byte[] getAliceBaseKey() { + return this.sessionStructure.getAliceBaseKey().toByteArray(); + } + + public void setAliceBaseKey(byte[] aliceBaseKey) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setAliceBaseKey(ByteString.copyFrom(aliceBaseKey)) + .build(); + } + + public void setSessionVersion(int version) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setSessionVersion(version) + .build(); + } + + public int getSessionVersion() { + int sessionVersion = this.sessionStructure.getSessionVersion(); + + if (sessionVersion == 0) return 2; + else return sessionVersion; + } + + public void setRemoteIdentityKey(IdentityKey identityKey) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setRemoteIdentityPublic(ByteString.copyFrom(identityKey.serialize())) + .build(); + } + + public void setLocalIdentityKey(IdentityKey identityKey) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setLocalIdentityPublic(ByteString.copyFrom(identityKey.serialize())) + .build(); + } + + public IdentityKey getRemoteIdentityKey() { + try { + if (!this.sessionStructure.hasRemoteIdentityPublic()) { + return null; + } + + return new IdentityKey(this.sessionStructure.getRemoteIdentityPublic().toByteArray(), 0); + } catch (InvalidKeyException e) { + Log.w("SessionRecordV2", e); + return null; + } + } + + public IdentityKey getLocalIdentityKey() { + try { + return new IdentityKey(this.sessionStructure.getLocalIdentityPublic().toByteArray(), 0); + } catch (InvalidKeyException e) { + throw new AssertionError(e); + } + } + + public int getPreviousCounter() { + return sessionStructure.getPreviousCounter(); + } + + public void setPreviousCounter(int previousCounter) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setPreviousCounter(previousCounter) + .build(); + } + + public RootKey getRootKey() { + return new RootKey(HKDF.createFor(getSessionVersion()), + this.sessionStructure.getRootKey().toByteArray()); + } + + public void setRootKey(RootKey rootKey) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setRootKey(ByteString.copyFrom(rootKey.getKeyBytes())) + .build(); + } + + public ECPublicKey getSenderRatchetKey() { + try { + return Curve.decodePoint(sessionStructure.getSenderChain().getSenderRatchetKey().toByteArray(), 0); + } catch (InvalidKeyException e) { + throw new AssertionError(e); + } + } + + public ECKeyPair getSenderRatchetKeyPair() { + ECPublicKey publicKey = getSenderRatchetKey(); + ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getSenderChain() + .getSenderRatchetKeyPrivate() + .toByteArray()); + + return new ECKeyPair(publicKey, privateKey); + } + + public boolean hasReceiverChain(ECPublicKey senderEphemeral) { + return getReceiverChain(senderEphemeral) != null; + } + + public boolean hasSenderChain() { + return sessionStructure.hasSenderChain(); + } + + private Pair<Chain,Integer> getReceiverChain(ECPublicKey senderEphemeral) { + List<Chain> receiverChains = sessionStructure.getReceiverChainsList(); + int index = 0; + + for (Chain receiverChain : receiverChains) { + try { + ECPublicKey chainSenderRatchetKey = Curve.decodePoint(receiverChain.getSenderRatchetKey().toByteArray(), 0); + + if (chainSenderRatchetKey.equals(senderEphemeral)) { + return new Pair<>(receiverChain,index); + } + } catch (InvalidKeyException e) { + Log.w("SessionRecordV2", e); + } + + index++; + } + + return null; + } + + public ChainKey getReceiverChainKey(ECPublicKey senderEphemeral) { + Pair<Chain,Integer> receiverChainAndIndex = getReceiverChain(senderEphemeral); + Chain receiverChain = receiverChainAndIndex.first(); + + if (receiverChain == null) { + return null; + } else { + return new ChainKey(HKDF.createFor(getSessionVersion()), + receiverChain.getChainKey().getKey().toByteArray(), + receiverChain.getChainKey().getIndex()); + } + } + + public void addReceiverChain(ECPublicKey senderRatchetKey, ChainKey chainKey) { + Chain.ChainKey chainKeyStructure = Chain.ChainKey.newBuilder() + .setKey(ByteString.copyFrom(chainKey.getKey())) + .setIndex(chainKey.getIndex()) + .build(); + + Chain chain = Chain.newBuilder() + .setChainKey(chainKeyStructure) + .setSenderRatchetKey(ByteString.copyFrom(senderRatchetKey.serialize())) + .build(); + + this.sessionStructure = this.sessionStructure.toBuilder().addReceiverChains(chain).build(); + + if (this.sessionStructure.getReceiverChainsList().size() > 5) { + this.sessionStructure = this.sessionStructure.toBuilder() + .removeReceiverChains(0) + .build(); + } + } + + public void setSenderChain(ECKeyPair senderRatchetKeyPair, ChainKey chainKey) { + Chain.ChainKey chainKeyStructure = Chain.ChainKey.newBuilder() + .setKey(ByteString.copyFrom(chainKey.getKey())) + .setIndex(chainKey.getIndex()) + .build(); + + Chain senderChain = Chain.newBuilder() + .setSenderRatchetKey(ByteString.copyFrom(senderRatchetKeyPair.getPublicKey().serialize())) + .setSenderRatchetKeyPrivate(ByteString.copyFrom(senderRatchetKeyPair.getPrivateKey().serialize())) + .setChainKey(chainKeyStructure) + .build(); + + this.sessionStructure = this.sessionStructure.toBuilder().setSenderChain(senderChain).build(); + } + + public ChainKey getSenderChainKey() { + Chain.ChainKey chainKeyStructure = sessionStructure.getSenderChain().getChainKey(); + return new ChainKey(HKDF.createFor(getSessionVersion()), + chainKeyStructure.getKey().toByteArray(), chainKeyStructure.getIndex()); + } + + + public void setSenderChainKey(ChainKey nextChainKey) { + Chain.ChainKey chainKey = Chain.ChainKey.newBuilder() + .setKey(ByteString.copyFrom(nextChainKey.getKey())) + .setIndex(nextChainKey.getIndex()) + .build(); + + Chain chain = sessionStructure.getSenderChain().toBuilder() + .setChainKey(chainKey).build(); + + this.sessionStructure = this.sessionStructure.toBuilder().setSenderChain(chain).build(); + } + + public boolean hasMessageKeys(ECPublicKey senderEphemeral, int counter) { + Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral); + Chain chain = chainAndIndex.first(); + + if (chain == null) { + return false; + } + + List<Chain.MessageKey> messageKeyList = chain.getMessageKeysList(); + + for (Chain.MessageKey messageKey : messageKeyList) { + if (messageKey.getIndex() == counter) { + return true; + } + } + + return false; + } + + public MessageKeys removeMessageKeys(ECPublicKey senderEphemeral, int counter) { + Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral); + Chain chain = chainAndIndex.first(); + + if (chain == null) { + return null; + } + + List<Chain.MessageKey> messageKeyList = new LinkedList<>(chain.getMessageKeysList()); + Iterator<Chain.MessageKey> messageKeyIterator = messageKeyList.iterator(); + MessageKeys result = null; + + while (messageKeyIterator.hasNext()) { + Chain.MessageKey messageKey = messageKeyIterator.next(); + + if (messageKey.getIndex() == counter) { + result = new MessageKeys(new SecretKeySpec(messageKey.getCipherKey().toByteArray(), "AES"), + new SecretKeySpec(messageKey.getMacKey().toByteArray(), "HmacSHA256"), + new IvParameterSpec(messageKey.getIv().toByteArray()), + messageKey.getIndex()); + + messageKeyIterator.remove(); + break; + } + } + + Chain updatedChain = chain.toBuilder().clearMessageKeys() + .addAllMessageKeys(messageKeyList) + .build(); + + this.sessionStructure = this.sessionStructure.toBuilder() + .setReceiverChains(chainAndIndex.second(), updatedChain) + .build(); + + return result; + } + + public void setMessageKeys(ECPublicKey senderEphemeral, MessageKeys messageKeys) { + Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral); + Chain chain = chainAndIndex.first(); + Chain.MessageKey messageKeyStructure = Chain.MessageKey.newBuilder() + .setCipherKey(ByteString.copyFrom(messageKeys.getCipherKey().getEncoded())) + .setMacKey(ByteString.copyFrom(messageKeys.getMacKey().getEncoded())) + .setIndex(messageKeys.getCounter()) + .setIv(ByteString.copyFrom(messageKeys.getIv().getIV())) + .build(); + + Chain updatedChain = chain.toBuilder() + .addMessageKeys(messageKeyStructure) + .build(); + + this.sessionStructure = this.sessionStructure.toBuilder() + .setReceiverChains(chainAndIndex.second(), updatedChain) + .build(); + } + + public void setReceiverChainKey(ECPublicKey senderEphemeral, ChainKey chainKey) { + Pair<Chain,Integer> chainAndIndex = getReceiverChain(senderEphemeral); + Chain chain = chainAndIndex.first(); + + Chain.ChainKey chainKeyStructure = Chain.ChainKey.newBuilder() + .setKey(ByteString.copyFrom(chainKey.getKey())) + .setIndex(chainKey.getIndex()) + .build(); + + Chain updatedChain = chain.toBuilder().setChainKey(chainKeyStructure).build(); + + this.sessionStructure = this.sessionStructure.toBuilder() + .setReceiverChains(chainAndIndex.second(), updatedChain) + .build(); + } + + public void setPendingKeyExchange(int sequence, + ECKeyPair ourBaseKey, + ECKeyPair ourRatchetKey, + IdentityKeyPair ourIdentityKey) + { + PendingKeyExchange structure = + PendingKeyExchange.newBuilder() + .setSequence(sequence) + .setLocalBaseKey(ByteString.copyFrom(ourBaseKey.getPublicKey().serialize())) + .setLocalBaseKeyPrivate(ByteString.copyFrom(ourBaseKey.getPrivateKey().serialize())) + .setLocalRatchetKey(ByteString.copyFrom(ourRatchetKey.getPublicKey().serialize())) + .setLocalRatchetKeyPrivate(ByteString.copyFrom(ourRatchetKey.getPrivateKey().serialize())) + .setLocalIdentityKey(ByteString.copyFrom(ourIdentityKey.getPublicKey().serialize())) + .setLocalIdentityKeyPrivate(ByteString.copyFrom(ourIdentityKey.getPrivateKey().serialize())) + .build(); + + this.sessionStructure = this.sessionStructure.toBuilder() + .setPendingKeyExchange(structure) + .build(); + } + + public int getPendingKeyExchangeSequence() { + return sessionStructure.getPendingKeyExchange().getSequence(); + } + + public ECKeyPair getPendingKeyExchangeBaseKey() throws InvalidKeyException { + ECPublicKey publicKey = Curve.decodePoint(sessionStructure.getPendingKeyExchange() + .getLocalBaseKey().toByteArray(), 0); + + ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getPendingKeyExchange() + .getLocalBaseKeyPrivate() + .toByteArray()); + + return new ECKeyPair(publicKey, privateKey); + } + + public ECKeyPair getPendingKeyExchangeRatchetKey() throws InvalidKeyException { + ECPublicKey publicKey = Curve.decodePoint(sessionStructure.getPendingKeyExchange() + .getLocalRatchetKey().toByteArray(), 0); + + ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getPendingKeyExchange() + .getLocalRatchetKeyPrivate() + .toByteArray()); + + return new ECKeyPair(publicKey, privateKey); + } + + public IdentityKeyPair getPendingKeyExchangeIdentityKey() throws InvalidKeyException { + IdentityKey publicKey = new IdentityKey(sessionStructure.getPendingKeyExchange() + .getLocalIdentityKey().toByteArray(), 0); + + ECPrivateKey privateKey = Curve.decodePrivatePoint(sessionStructure.getPendingKeyExchange() + .getLocalIdentityKeyPrivate() + .toByteArray()); + + return new IdentityKeyPair(publicKey, privateKey); + } + + public boolean hasPendingKeyExchange() { + return sessionStructure.hasPendingKeyExchange(); + } + + public void setUnacknowledgedPreKeyMessage(Optional<Integer> preKeyId, int signedPreKeyId, ECPublicKey baseKey) { + PendingPreKey.Builder pending = PendingPreKey.newBuilder() + .setSignedPreKeyId(signedPreKeyId) + .setBaseKey(ByteString.copyFrom(baseKey.serialize())); + + if (preKeyId.isPresent()) { + pending.setPreKeyId(preKeyId.get()); + } + + this.sessionStructure = this.sessionStructure.toBuilder() + .setPendingPreKey(pending.build()) + .build(); + } + + public boolean hasUnacknowledgedPreKeyMessage() { + return this.sessionStructure.hasPendingPreKey(); + } + + public UnacknowledgedPreKeyMessageItems getUnacknowledgedPreKeyMessageItems() { + try { + Optional<Integer> preKeyId; + + if (sessionStructure.getPendingPreKey().hasPreKeyId()) { + preKeyId = Optional.of(sessionStructure.getPendingPreKey().getPreKeyId()); + } else { + preKeyId = Optional.absent(); + } + + return + new UnacknowledgedPreKeyMessageItems(preKeyId, + sessionStructure.getPendingPreKey().getSignedPreKeyId(), + Curve.decodePoint(sessionStructure.getPendingPreKey() + .getBaseKey() + .toByteArray(), 0)); + } catch (InvalidKeyException e) { + throw new AssertionError(e); + } + } + + public void clearUnacknowledgedPreKeyMessage() { + this.sessionStructure = this.sessionStructure.toBuilder() + .clearPendingPreKey() + .build(); + } + + public void setRemoteRegistrationId(int registrationId) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setRemoteRegistrationId(registrationId) + .build(); + } + + public int getRemoteRegistrationId() { + return this.sessionStructure.getRemoteRegistrationId(); + } + + public void setLocalRegistrationId(int registrationId) { + this.sessionStructure = this.sessionStructure.toBuilder() + .setLocalRegistrationId(registrationId) + .build(); + } + + public int getLocalRegistrationId() { + return this.sessionStructure.getLocalRegistrationId(); + } + + public byte[] serialize() { + return sessionStructure.toByteArray(); + } + + public static class UnacknowledgedPreKeyMessageItems { + private final Optional<Integer> preKeyId; + private final int signedPreKeyId; + private final ECPublicKey baseKey; + + public UnacknowledgedPreKeyMessageItems(Optional<Integer> preKeyId, + int signedPreKeyId, + ECPublicKey baseKey) + { + this.preKeyId = preKeyId; + this.signedPreKeyId = signedPreKeyId; + this.baseKey = baseKey; + } + + + public Optional<Integer> getPreKeyId() { + return preKeyId; + } + + public int getSignedPreKeyId() { + return signedPreKeyId; + } + + public ECPublicKey getBaseKey() { + return baseKey; + } + } +} |