1 package org.metricshub.winrm.service.client.encryption;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import org.apache.cxf.helpers.IOUtils;
import org.apache.cxf.message.Message;
import org.metricshub.winrm.service.client.auth.ntlm.NTCredentialsWithEncryption;
import org.metricshub.winrm.service.client.auth.ntlm.NTLMEngineUtils;
32
33
34
35
36
37 public class Decryptor {
38
39 private final NTCredentialsWithEncryption credentials;
40 private byte[] rawBytes;
41 private byte[] encryptedPayloadBytes;
42 private int index;
43 private int lastBlockStart;
44 private int lastBlockEnd;
45 private byte[] signatureBytes;
46 private byte[] sealedBytes;
47 private byte[] unsealedBytes;
48
49 public Decryptor(final NTCredentialsWithEncryption credentials) {
50 this.credentials = credentials;
51 }
52
53 public void handle(final Message message) {
54 final Object contentType = message.get(Message.CONTENT_TYPE);
55
56 final boolean isEncrypted = contentType != null && contentType.toString().startsWith("multipart/encrypted");
57
58 if (isEncrypted) {
59 if (credentials == null) {
60 throw new IllegalStateException("Encrypted payload from server when no credentials with encryption known");
61 }
62 if (!credentials.isAuthenticated()) {
63 throw new IllegalStateException("Encrypted payload from server when not authenticated");
64 }
65
66 try {
67 decrypt(message);
68 } catch (final Exception e) {
69 throw new IllegalStateException(e);
70 }
71 } else {
72 if (credentials != null && credentials.isAuthenticated()) {
73 throw new IllegalStateException(
74 "Unencrypted payload from server when authenticated and encryption is required"
75 );
76 }
77 }
78 }
79
80 void decrypt(final Message message) throws IOException {
81 try (final InputStream in = message.getContent(InputStream.class)) {
82 rawBytes = IOUtils.readBytesFromStream(in);
83 }
84
85 unwrap();
86
87 final int signatureLength = (int) ByteArrayUtils.readLittleEndianUnsignedInt(encryptedPayloadBytes, 0);
88 signatureBytes = Arrays.copyOfRange(encryptedPayloadBytes, 4, 4 + signatureLength);
89 sealedBytes = Arrays.copyOfRange(encryptedPayloadBytes, 4 + signatureLength, encryptedPayloadBytes.length);
90
91 unseal();
92
93
94
95 verify();
96
97 message.setContent(InputStream.class, new ByteArrayInputStream(unsealedBytes));
98 }
99
100 private void verify() throws IOException {
101 final long seqNum = ByteArrayUtils.readLittleEndianUnsignedInt(signatureBytes, 12);
102 final int checkSumOffset = credentials.hasNegotiateFlag(NTLMEngineUtils.NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY)
103 ? 4
104 : 8;
105
106 final byte[] checksum = Arrays.copyOfRange(signatureBytes, checkSumOffset, 12);
107
108 try (final ByteArrayOutputStream signature = new ByteArrayOutputStream()) {
109 NtlmEncryptionUtils.calculateSignature(
110 unsealedBytes,
111 seqNum,
112 signature,
113 credentials,
114 NTCredentialsWithEncryption::getServerSigningKey,
115 credentials.getStatefulDecryptor()::update
116 );
117
118 final byte[] expectedChecksum = Arrays.copyOfRange(signature.toByteArray(), checkSumOffset, 12);
119 final long expectedSeqNum = ByteArrayUtils.readLittleEndianUnsignedInt(signature.toByteArray(), 12);
120
121 if (!Arrays.equals(checksum, expectedChecksum)) {
122 throw new IllegalStateException(
123 String.format(
124 "Checksum mismatch\n%s--\n%s",
125 ByteArrayUtils.formatHexDump(checksum),
126 ByteArrayUtils.formatHexDump(expectedChecksum)
127 )
128 );
129 }
130
131 if (expectedSeqNum != seqNum) {
132 throw new IllegalStateException(String.format("Sequence number mismatch: %d != %d", seqNum, expectedSeqNum));
133 }
134 }
135
136 credentials.getSequenceNumberIncoming().incrementAndGet();
137 }
138
139 void unwrap() {
140 index = 0;
141 skipOver(NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_CR);
142 skipUntil("\n" + NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_CR);
143 skipUntil("\r\n");
144
145
146
147 lastBlockStart = index;
148 lastBlockEnd = rawBytes.length - NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_END.length();
149 index = lastBlockEnd;
150 skipOver(NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_END);
151
152 encryptedPayloadBytes = Arrays.copyOfRange(rawBytes, lastBlockStart, lastBlockEnd);
153 }
154
155 void skipOver(final String s) {
156 skipOver(s.getBytes());
157 }
158
159 void skipOver(final byte[] expected) {
160 int i = 0;
161 while (i < expected.length) {
162 if (index >= rawBytes.length) {
163 throw new IllegalStateException(
164 String.format(
165 "Invalid format for response from server; terminated early (%d) when expecting '%s'\n%s",
166 i,
167 new String(expected),
168 ByteArrayUtils.formatHexDump(rawBytes)
169 )
170 );
171 }
172
173 if (expected[i++] != rawBytes[index++]) {
174 throw new IllegalStateException(
175 String.format(
176 "Invalid format for response from server; mismatch at position %d (%d) when expecting '%s'\n%s",
177 index,
178 i,
179 new String(expected),
180 ByteArrayUtils.formatHexDump(rawBytes)
181 )
182 );
183 }
184 }
185 }
186
187 void skipUntil(final String str) {
188 final byte[] expected = str.getBytes();
189 int nextBlock = index;
190 outer:while (true) {
191 for (int i = 0; i < expected.length && nextBlock + i < rawBytes.length; i++) {
192 if (nextBlock + i >= rawBytes.length) {
193 throw new IllegalStateException(
194 String.format(
195 "Invalid format for response from server; terminated early (%d) when looking for '%s'\n%s",
196 i,
197 new String(expected),
198 ByteArrayUtils.formatHexDump(rawBytes)
199 )
200 );
201 }
202 if (expected[i] != rawBytes[nextBlock + i]) {
203 nextBlock++;
204 continue outer;
205 }
206 }
207 lastBlockStart = index;
208 lastBlockEnd = nextBlock;
209 index = nextBlock + expected.length;
210 return;
211 }
212 }
213
214 private void unseal() {
215 unsealedBytes = credentials.getStatefulDecryptor().update(sealedBytes);
216 }
217 }