View Javadoc
1   package org.metricshub.winrm.service.client.encryption;
2   
3   /*-
4    * ╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲
5    * WinRM Java Client
6    * ჻჻჻჻჻჻
7    * Copyright 2023 - 2024 Metricshub
8    * ჻჻჻჻჻჻
9    * Licensed under the Apache License, Version 2.0 (the "License");
10   * you may not use this file except in compliance with the License.
11   * You may obtain a copy of the License at
12   *
13   *      http://www.apache.org/licenses/LICENSE-2.0
14   *
15   * Unless required by applicable law or agreed to in writing, software
16   * distributed under the License is distributed on an "AS IS" BASIS,
17   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18   * See the License for the specific language governing permissions and
19   * limitations under the License.
20   * ╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱
21   */
22  
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import org.apache.cxf.helpers.IOUtils;
import org.apache.cxf.message.Message;
import org.metricshub.winrm.service.client.auth.ntlm.NTCredentialsWithEncryption;
import org.metricshub.winrm.service.client.auth.ntlm.NTLMEngineUtils;
32  
33  /**
34   * Code from io.cloudsoft.winrm4j.client.encryption.NtlmEncryptionUtils.Decryptor
35   * release 0.12.3 @link https://github.com/cloudsoft/winrm4j
36   */
37  public class Decryptor {
38  
39  	private final NTCredentialsWithEncryption credentials;
40  	private byte[] rawBytes;
41  	private byte[] encryptedPayloadBytes;
42  	private int index;
43  	private int lastBlockStart;
44  	private int lastBlockEnd;
45  	private byte[] signatureBytes;
46  	private byte[] sealedBytes;
47  	private byte[] unsealedBytes;
48  
49  	public Decryptor(final NTCredentialsWithEncryption credentials) {
50  		this.credentials = credentials;
51  	}
52  
53  	public void handle(final Message message) {
54  		final Object contentType = message.get(Message.CONTENT_TYPE);
55  
56  		final boolean isEncrypted = contentType != null && contentType.toString().startsWith("multipart/encrypted");
57  
58  		if (isEncrypted) {
59  			if (credentials == null) {
60  				throw new IllegalStateException("Encrypted payload from server when no credentials with encryption known");
61  			}
62  			if (!credentials.isAuthenticated()) {
63  				throw new IllegalStateException("Encrypted payload from server when not authenticated");
64  			}
65  
66  			try {
67  				decrypt(message);
68  			} catch (final Exception e) {
69  				throw new IllegalStateException(e);
70  			}
71  		} else {
72  			if (credentials != null && credentials.isAuthenticated()) {
73  				throw new IllegalStateException(
74  					"Unencrypted payload from server when authenticated and encryption is required"
75  				);
76  			}
77  		}
78  	}
79  
80  	void decrypt(final Message message) throws IOException {
81  		try (final InputStream in = message.getContent(InputStream.class)) {
82  			rawBytes = IOUtils.readBytesFromStream(in);
83  		}
84  
85  		unwrap();
86  
87  		final int signatureLength = (int) ByteArrayUtils.readLittleEndianUnsignedInt(encryptedPayloadBytes, 0);
88  		signatureBytes = Arrays.copyOfRange(encryptedPayloadBytes, 4, 4 + signatureLength);
89  		sealedBytes = Arrays.copyOfRange(encryptedPayloadBytes, 4 + signatureLength, encryptedPayloadBytes.length);
90  
91  		unseal();
92  
93  		// should set length and type headers - but they don't seem to be needed!
94  
95  		verify();
96  
97  		message.setContent(InputStream.class, new ByteArrayInputStream(unsealedBytes));
98  	}
99  
100 	private void verify() throws IOException {
101 		final long seqNum = ByteArrayUtils.readLittleEndianUnsignedInt(signatureBytes, 12);
102 		final int checkSumOffset = credentials.hasNegotiateFlag(NTLMEngineUtils.NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY)
103 			? 4
104 			: 8;
105 
106 		final byte[] checksum = Arrays.copyOfRange(signatureBytes, checkSumOffset, 12);
107 
108 		try (final ByteArrayOutputStream signature = new ByteArrayOutputStream()) {
109 			NtlmEncryptionUtils.calculateSignature(
110 				unsealedBytes,
111 				seqNum,
112 				signature,
113 				credentials,
114 				NTCredentialsWithEncryption::getServerSigningKey,
115 				credentials.getStatefulDecryptor()::update
116 			);
117 
118 			final byte[] expectedChecksum = Arrays.copyOfRange(signature.toByteArray(), checkSumOffset, 12);
119 			final long expectedSeqNum = ByteArrayUtils.readLittleEndianUnsignedInt(signature.toByteArray(), 12);
120 
121 			if (!Arrays.equals(checksum, expectedChecksum)) {
122 				throw new IllegalStateException(
123 					String.format(
124 						"Checksum mismatch\n%s--\n%s",
125 						ByteArrayUtils.formatHexDump(checksum),
126 						ByteArrayUtils.formatHexDump(expectedChecksum)
127 					)
128 				);
129 			}
130 
131 			if (expectedSeqNum != seqNum) {
132 				throw new IllegalStateException(String.format("Sequence number mismatch: %d != %d", seqNum, expectedSeqNum));
133 			}
134 		}
135 
136 		credentials.getSequenceNumberIncoming().incrementAndGet();
137 	}
138 
139 	void unwrap() {
140 		index = 0;
141 		skipOver(NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_CR);
142 		skipUntil("\n" + NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_CR);
143 		skipUntil("\r\n");
144 
145 		// for credssh de-chunking might be needed, but not for ntlm
146 
147 		lastBlockStart = index;
148 		lastBlockEnd = rawBytes.length - NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_END.length();
149 		index = lastBlockEnd;
150 		skipOver(NtlmEncryptionUtils.ENCRYPTED_BOUNDARY_END);
151 
152 		encryptedPayloadBytes = Arrays.copyOfRange(rawBytes, lastBlockStart, lastBlockEnd);
153 	}
154 
155 	void skipOver(final String s) {
156 		skipOver(s.getBytes());
157 	}
158 
159 	void skipOver(final byte[] expected) {
160 		int i = 0;
161 		while (i < expected.length) {
162 			if (index >= rawBytes.length) {
163 				throw new IllegalStateException(
164 					String.format(
165 						"Invalid format for response from server; terminated early (%d) when expecting '%s'\n%s",
166 						i,
167 						new String(expected),
168 						ByteArrayUtils.formatHexDump(rawBytes)
169 					)
170 				);
171 			}
172 
173 			if (expected[i++] != rawBytes[index++]) {
174 				throw new IllegalStateException(
175 					String.format(
176 						"Invalid format for response from server; mismatch at position %d (%d) when expecting '%s'\n%s",
177 						index,
178 						i,
179 						new String(expected),
180 						ByteArrayUtils.formatHexDump(rawBytes)
181 					)
182 				);
183 			}
184 		}
185 	}
186 
187 	void skipUntil(final String str) {
188 		final byte[] expected = str.getBytes();
189 		int nextBlock = index;
190 		outer:while (true) {
191 			for (int i = 0; i < expected.length && nextBlock + i < rawBytes.length; i++) {
192 				if (nextBlock + i >= rawBytes.length) {
193 					throw new IllegalStateException(
194 						String.format(
195 							"Invalid format for response from server; terminated early (%d) when looking for '%s'\n%s",
196 							i,
197 							new String(expected),
198 							ByteArrayUtils.formatHexDump(rawBytes)
199 						)
200 					);
201 				}
202 				if (expected[i] != rawBytes[nextBlock + i]) {
203 					nextBlock++;
204 					continue outer;
205 				}
206 			}
207 			lastBlockStart = index;
208 			lastBlockEnd = nextBlock;
209 			index = nextBlock + expected.length;
210 			return;
211 		}
212 	}
213 
214 	private void unseal() {
215 		unsealedBytes = credentials.getStatefulDecryptor().update(sealedBytes);
216 	}
217 }