View Javadoc
1   package org.metricshub.winrm.service;
2   
3   /*-
4    * ╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲
5    * WinRM Java Client
6    * ჻჻჻჻჻჻
7    * Copyright 2023 - 2024 Metricshub
8    * ჻჻჻჻჻჻
9    * Licensed under the Apache License, Version 2.0 (the "License");
10   * you may not use this file except in compliance with the License.
11   * You may obtain a copy of the License at
12   *
13   *      http://www.apache.org/licenses/LICENSE-2.0
14   *
15   * Unless required by applicable law or agreed to in writing, software
16   * distributed under the License is distributed on an "AS IS" BASIS,
17   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18   * See the License for the specific language governing permissions and
19   * limitations under the License.
20   * ╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱
21   */
22  
import jakarta.xml.ws.BindingProvider;
import jakarta.xml.ws.WebServiceException;
import jakarta.xml.ws.handler.Handler;
import jakarta.xml.ws.soap.SOAPFaultException;
import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.net.ssl.TrustManager;
import javax.xml.namespace.QName;
import org.apache.cxf.Bus;
import org.apache.cxf.binding.soap.SoapBindingConstants;
import org.apache.cxf.configuration.jsse.TLSClientParameters;
import org.apache.cxf.endpoint.Client;
import org.apache.cxf.frontend.ClientProxy;
import org.apache.cxf.jaxws.JaxWsProxyFactoryBean;
import org.apache.cxf.message.Message;
import org.apache.cxf.service.model.ServiceInfo;
import org.apache.cxf.transport.http.HTTPConduitFactory;
import org.apache.cxf.transport.http.asyncclient.AsyncHTTPConduit;
import org.apache.cxf.transports.http.configuration.HTTPClientPolicy;
import org.apache.cxf.ws.addressing.WSAddressingFeature;
import org.apache.cxf.ws.addressing.WSAddressingFeature.AddressingResponses;
import org.apache.cxf.ws.addressing.policy.MetadataConstants;
import org.apache.cxf.ws.policy.PolicyConstants;
import org.apache.http.auth.AuthSchemeProvider;
import org.apache.http.auth.Credentials;
import org.apache.http.auth.NTCredentials;
import org.apache.http.client.config.AuthSchemes;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.impl.auth.KerberosSchemeFactory;
import org.apache.neethi.Policy;
import org.apache.neethi.builders.PrimitiveAssertion;
import org.metricshub.winrm.Utils;
import org.metricshub.winrm.WinRMHttpProtocolEnum;
import org.metricshub.winrm.service.client.auth.AuthenticationEnum;
import org.metricshub.winrm.service.client.auth.TrustAllX509Manager;
import org.metricshub.winrm.service.client.auth.kerberos.KerberosUtils;
import org.metricshub.winrm.service.client.auth.ntlm.NTCredentialsWithEncryption;
import org.metricshub.winrm.service.client.auth.ntlm.NtlmMasqAsSpnegoSchemeFactory;
import org.metricshub.winrm.service.client.encryption.AsyncHttpEncryptionAwareConduitFactory;
import org.metricshub.winrm.service.client.encryption.DecryptAndVerifyInInterceptor;
import org.metricshub.winrm.service.client.encryption.SignAndEncryptOutInterceptor;
80  
81  public class WinRMInvocationHandler implements InvocationHandler {
82  
83  	public static final String WSMAN_SCHEMA_NAMESPACE = "http://schemas.dmtf.org/wbem/wsman/1/wsman.xsd";
84  
85  	private static final long PAUSE_TIME_MILLISECONDS = 500;
86  	private static final int MAX_RETRY = 3;
87  
88  	private static final URL WSDL_LOCATION_URL =
89  		WinRMWebServiceClient.class.getClassLoader().getResource("wsdl/WinRM.wsdl");
90  
91  	private static final QName SERVICE = new QName(WSMAN_SCHEMA_NAMESPACE, "WinRMWebServiceClient");
92  
93  	private static final QName PORT = new QName(WSMAN_SCHEMA_NAMESPACE, "WinRMPort");
94  
95  	private static final List<String> CONTENT_TYPE_LIST = Collections.singletonList("application/soap+xml;charset=UTF-8");
96  
97  	@SuppressWarnings("rawtypes")
98  	private static final List<Handler> HANDLER_CHAIN = Arrays.asList(new StripShellResponseHandler());
99  
100 	private static final Registry<AuthSchemeProvider> AUTH_SCHEME_REGISTRY = RegistryBuilder
101 		.<AuthSchemeProvider>create()
102 		.register(AuthSchemes.SPNEGO, new NtlmMasqAsSpnegoSchemeFactory())
103 		.register(AuthSchemes.KERBEROS, new KerberosSchemeFactory(true))
104 		.build();
105 
106 	private static final Policy POLICY;
107 
108 	static {
109 		POLICY = new Policy();
110 		POLICY.addAssertion(new PrimitiveAssertion(MetadataConstants.USING_ADDRESSING_2004_QNAME));
111 	}
112 
113 	private static final WSAddressingFeature WS_ADDRESSING_FEATURE;
114 
115 	static {
116 		WS_ADDRESSING_FEATURE = new WSAddressingFeature();
117 		WS_ADDRESSING_FEATURE.setResponses(AddressingResponses.ANONYMOUS);
118 	}
119 
120 	private static final TLSClientParameters TLS_CLIENT_PARAMETERS;
121 
122 	static {
123 		TLS_CLIENT_PARAMETERS = new TLSClientParameters();
124 		TLS_CLIENT_PARAMETERS.setDisableCNCheck(true);
125 		// Accept all certificates
126 		TLS_CLIENT_PARAMETERS.setTrustManagers(new TrustManager[] { new TrustAllX509Manager() });
127 	}
128 
129 	private static final Map<CredentialsMapKey, Credentials> CREDENTIALS = new ConcurrentHashMap<>();
130 
131 	private final WinRMWebService winRMWebService;
132 	private final WinRMEndpoint winRMEndpoint;
133 	private final long timeout;
134 	private final String resourceUri;
135 	private final Path ticketCache;
136 	private final Queue<AuthenticationEnum> authenticationsQueue;
137 	private AuthenticationEnum authentication;
138 	private Client wsClient;
139 
140 	/**
141 	 * WinRMInvocationHandler constructor
142 	 *
143 	 * @param winRMEndpoint Endpoint with credentials (mandatory)
144 	 * @param bus Apache CXF Bus (mandatory)
145 	 * @param timeout Timeout used for Connection, Connection Request and Receive Request in milliseconds
146 	 * @param resourceUri The enumerate resource URI
147 	 * @param ticketCache The Ticket Cache path
148 	 * @param authentications List of authentications. (mandatory)
149 	 */
150 	public WinRMInvocationHandler(
151 		final WinRMEndpoint winRMEndpoint,
152 		final Bus bus,
153 		final long timeout,
154 		final String resourceUri,
155 		final Path ticketCache,
156 		final List<AuthenticationEnum> authentications
157 	) {
158 		Utils.checkNonNull(winRMEndpoint, "winRMEndpoint");
159 		Utils.checkNonNull(bus, "bus");
160 		Utils.checkNonNull(authentications, "authentications");
161 
162 		this.winRMEndpoint = winRMEndpoint;
163 		this.timeout = timeout;
164 		this.resourceUri = resourceUri;
165 		this.ticketCache = ticketCache;
166 		authenticationsQueue = authentications.stream().collect(Collectors.toCollection(LinkedList::new));
167 
168 		winRMWebService = createWinRMWebService(winRMEndpoint, bus);
169 
170 		final AuthCredentials authCredentials = computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
171 
172 		authentication = authCredentials.getAuthentication();
173 
174 		wsClient =
175 			getWebServiceClient(winRMEndpoint, timeout, resourceUri, winRMWebService, authCredentials.getCredentials());
176 	}
177 
178 	public Client getClient() {
179 		return wsClient;
180 	}
181 
182 	@Override
183 	public Object invoke(final Object proxy, final Method method, final Object[] args) throws Throwable {
184 		Utils.checkNonNull(method, "method");
185 
186 		try {
187 			return invokeMethod(method, args);
188 		} catch (final RetryTgtExpirationException e) {
189 			// retry with a new TGT in case of current TGT expiration
190 			authentication = null;
191 
192 			Credentials credentials;
193 			try {
194 				credentials =
195 					KerberosUtils.createCredentials(winRMEndpoint.getUsername(), winRMEndpoint.getPassword(), ticketCache);
196 
197 				CREDENTIALS.put(new CredentialsMapKey(winRMEndpoint, ticketCache, AuthenticationEnum.KERBEROS), credentials);
198 				// Normally that should not happen as any other exception on KERBEROs should had been throw
199 				// at the first KERBEROS call
200 			} catch (final Exception e1) {
201 				if (continueToRetry()) {
202 					final AuthCredentials authCredentials = computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
203 
204 					authentication = authCredentials.getAuthentication();
205 					credentials = authCredentials.getCredentials();
206 				} else {
207 					throw e1;
208 				}
209 			}
210 
211 			wsClient = getWebServiceClient(winRMEndpoint, timeout, resourceUri, winRMWebService, credentials);
212 
213 			return invoke(proxy, method, args);
214 		} catch (final RetryAuthenticationException e) {
215 			if (continueToRetry()) {
216 				final AuthCredentials authCredentials = computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
217 
218 				authentication = authCredentials.getAuthentication();
219 
220 				wsClient =
221 					getWebServiceClient(winRMEndpoint, timeout, resourceUri, winRMWebService, authCredentials.getCredentials());
222 
223 				return invoke(proxy, method, args);
224 			}
225 
226 			// No more retries
227 			final Throwable cause = e.getCause();
228 			if (cause instanceof SOAPFaultException) {
229 				throw new RuntimeException("KERBEROS with encryption over HTTP is not implemented.", cause);
230 			}
231 			throw cause;
232 		}
233 	}
234 
235 	// this function is only needed for the unit testing
236 	boolean continueToRetry() {
237 		return !authenticationsQueue.isEmpty();
238 	}
239 
240 	Object invokeMethod(final Method method, final Object[] args)
241 		throws IllegalAccessException, RetryAuthenticationException {
242 		Throwable firstEx = null;
243 		int retry = 0;
244 
245 		while (retry < MAX_RETRY) {
246 			retry++;
247 
248 			try {
249 				return method.invoke(winRMWebService, args);
250 			} catch (final InvocationTargetException ite) {
251 				final Throwable targetEx = ite.getTargetException();
252 
253 				if (targetEx instanceof SOAPFaultException) {
254 					// Could retry with a different authentication than NTLM
255 					// because it could be a "WstxEOFException: Unexpected EOF in prolog"
256 					// due to a KERBEROS with HTTP and AllowUnencrypted=false
257 					if (winRMEndpoint.getProtocol() == WinRMHttpProtocolEnum.HTTP && authentication != AuthenticationEnum.NTLM) {
258 						throw new RetryAuthenticationException(targetEx);
259 					}
260 					throw (SOAPFaultException) targetEx;
261 				}
262 
263 				if (!(targetEx instanceof WebServiceException)) {
264 					throw new IllegalStateException("Failure when calling " + createCallInfos(method, args), targetEx);
265 				}
266 
267 				final WebServiceException wsEx = (WebServiceException) targetEx;
268 
269 				if (!(wsEx.getCause() instanceof IOException)) {
270 					throw new RuntimeException(
271 						"Exception occurred while making WinRM WebService call " + createCallInfos(method, args),
272 						wsEx
273 					);
274 				}
275 
276 				if (
277 					wsEx.getCause().getMessage() != null &&
278 					wsEx.getCause().getMessage().startsWith("Authorization loop detected on Conduit")
279 				) {
280 					final RuntimeException authEx = new RuntimeException(
281 						String.format(
282 							"Authentication error on %s with user name \"%s\"",
283 							winRMEndpoint.getEndpoint(),
284 							winRMEndpoint.getRawUsername()
285 						)
286 					);
287 
288 					// Could be due to a TGT expiration
289 					if (authentication == AuthenticationEnum.KERBEROS) {
290 						throw new RetryTgtExpirationException(authEx);
291 					}
292 					// Could retry with a different authentication
293 					throw new RetryAuthenticationException(authEx);
294 				}
295 
296 				if (firstEx == null) {
297 					firstEx = wsEx;
298 				}
299 
300 				if (retry < MAX_RETRY) {
301 					try {
302 						Utils.sleep(PAUSE_TIME_MILLISECONDS);
303 					} catch (final InterruptedException ie) {
304 						Thread.currentThread().interrupt();
305 						throw new RuntimeException(
306 							"Exception occured while making WinRM WebService call " + createCallInfos(method, args),
307 							ie
308 						);
309 					}
310 				}
311 			}
312 		}
313 
314 		throw new RuntimeException(
315 			String.format("failed task \"%s\" after %d attempts", createCallInfos(method, args), MAX_RETRY),
316 			firstEx
317 		);
318 	}
319 
320 	static String createCallInfos(final Method method, final Object[] args) {
321 		final String name = method != null && method.getName() != null ? method.getName() : Utils.EMPTY;
322 		return args == null
323 			? name
324 			: Stream
325 				.concat(Stream.of(name), Stream.of(args))
326 				.filter(Objects::nonNull)
327 				.map(Object::toString)
328 				.collect(Collectors.joining(" "));
329 	}
330 
331 	static Credentials createCredentials(
332 		final WinRMEndpoint winRMEndpoint,
333 		final AuthenticationEnum authentication,
334 		final Path ticketCache
335 	) {
336 		switch (authentication) {
337 			case KERBEROS:
338 				return KerberosUtils.createCredentials(winRMEndpoint.getUsername(), winRMEndpoint.getPassword(), ticketCache);
339 			case NTLM:
340 			default:
341 				final String password = String.valueOf(winRMEndpoint.getPassword());
342 				return winRMEndpoint.getProtocol() == WinRMHttpProtocolEnum.HTTP
343 					? new NTCredentialsWithEncryption(winRMEndpoint.getUsername(), password, null, winRMEndpoint.getDomain())
344 					: new NTCredentials(winRMEndpoint.getUsername(), password, null, winRMEndpoint.getDomain());
345 		}
346 	}
347 
348 	static AuthCredentials computeCredentials(
349 		final WinRMEndpoint winRMEndpoint,
350 		final Path ticketCache,
351 		final Queue<AuthenticationEnum> authenticationsQueue
352 	) {
353 		try {
354 			final AuthenticationEnum authenticationEnum = authenticationsQueue.remove();
355 
356 			final Credentials credentials = CREDENTIALS.compute(
357 				new CredentialsMapKey(winRMEndpoint, ticketCache, authenticationEnum),
358 				(user, cred) -> cred != null ? cred : createCredentials(winRMEndpoint, authenticationEnum, ticketCache)
359 			);
360 
361 			return new AuthCredentials(authenticationEnum, credentials);
362 		} catch (final Exception e) {
363 			// if there's still retry
364 			if (!authenticationsQueue.isEmpty()) {
365 				return computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
366 			}
367 			throw e;
368 		}
369 	}
370 
371 	static WinRMWebService createWinRMWebService(final WinRMEndpoint winRMEndpoint, final Bus bus) {
372 		final JaxWsProxyFactoryBean jaxWsProxyFactoryBean = new JaxWsProxyFactoryBean();
373 		jaxWsProxyFactoryBean.setServiceName(SERVICE);
374 		jaxWsProxyFactoryBean.setEndpointName(PORT);
375 		jaxWsProxyFactoryBean.setBus(bus);
376 		jaxWsProxyFactoryBean.setServiceClass(WinRMWebService.class);
377 		jaxWsProxyFactoryBean.setAddress(winRMEndpoint.getEndpoint());
378 		jaxWsProxyFactoryBean.getFeatures().add(WS_ADDRESSING_FEATURE);
379 		jaxWsProxyFactoryBean.setBindingId(SoapBindingConstants.SOAP12_BINDING_ID);
380 		jaxWsProxyFactoryBean.getClientFactoryBean().getServiceFactory().setWsdlURL(WSDL_LOCATION_URL);
381 
382 		return jaxWsProxyFactoryBean.create(WinRMWebService.class);
383 	}
384 
385 	static Client getWebServiceClient(
386 		final WinRMEndpoint winRMEndpoint,
387 		final long timeout,
388 		final String enumerateResourceUri,
389 		final WinRMWebService winRMWebService,
390 		final Credentials credentials
391 	) {
392 		final Client client = ClientProxy.getClient(winRMWebService);
393 
394 		if (enumerateResourceUri != null) {
395 			final WSManHeaderInterceptor interceptor = new WSManHeaderInterceptor(enumerateResourceUri);
396 			client.getOutInterceptors().add(interceptor);
397 		}
398 
399 		client.getInInterceptors().add(new DecryptAndVerifyInInterceptor());
400 		client.getOutInterceptors().add(new SignAndEncryptOutInterceptor());
401 
402 		// this is different to endpoint properties
403 		client
404 			.getEndpoint()
405 			.getEndpointInfo()
406 			.setProperty(HTTPConduitFactory.class.getName(), new AsyncHttpEncryptionAwareConduitFactory());
407 
408 		final ServiceInfo serviceInfo = client.getEndpoint().getEndpointInfo().getService();
409 		serviceInfo.setProperty("soap.force.doclit.bare", true);
410 
411 		final BindingProvider bindingProvider = (BindingProvider) winRMWebService;
412 		bindingProvider.getBinding().setHandlerChain(HANDLER_CHAIN);
413 		bindingProvider.getRequestContext().put(PolicyConstants.POLICY_OVERRIDE, POLICY);
414 		bindingProvider.getRequestContext().put("http.autoredirect", true);
415 
416 		bindingProvider.getRequestContext().put(BindingProvider.ENDPOINT_ADDRESS_PROPERTY, winRMEndpoint.getEndpoint());
417 
418 		final Map<String, List<String>> headers = new HashMap<>();
419 		headers.put("Content-Type", CONTENT_TYPE_LIST);
420 
421 		bindingProvider.getRequestContext().put(Message.PROTOCOL_HEADERS, headers);
422 
423 		// Setup timeouts
424 		final HTTPClientPolicy httpClientPolicy = new HTTPClientPolicy();
425 		httpClientPolicy.setConnectionTimeout(timeout);
426 		httpClientPolicy.setConnectionRequestTimeout(timeout);
427 		httpClientPolicy.setReceiveTimeout(timeout);
428 		httpClientPolicy.setAllowChunking(false);
429 
430 		bindingProvider.getRequestContext().put(Credentials.class.getName(), credentials);
431 		bindingProvider.getRequestContext().put(AuthSchemeProvider.class.getName(), AUTH_SCHEME_REGISTRY);
432 
433 		final AsyncHTTPConduit asyncHTTPConduit = (AsyncHTTPConduit) client.getConduit();
434 		asyncHTTPConduit.setClient(httpClientPolicy);
435 		asyncHTTPConduit.getClient().setAutoRedirect(true);
436 		asyncHTTPConduit.setTlsClientParameters(TLS_CLIENT_PARAMETERS);
437 
438 		return client;
439 	}
440 
441 	static class RetryAuthenticationException extends Exception {
442 
443 		private static final long serialVersionUID = 1L;
444 
445 		RetryAuthenticationException(final Throwable throwable) {
446 			super(throwable);
447 		}
448 	}
449 
450 	static class RetryTgtExpirationException extends RetryAuthenticationException {
451 
452 		private static final long serialVersionUID = 1L;
453 
454 		RetryTgtExpirationException(final Throwable throwable) {
455 			super(throwable);
456 		}
457 	}
458 
459 	static class AuthCredentials {
460 
461 		private final AuthenticationEnum authentication;
462 		private final Credentials credentials;
463 
464 		AuthCredentials(final AuthenticationEnum authentication, final Credentials credentials) {
465 			this.authentication = authentication;
466 			this.credentials = credentials;
467 		}
468 
469 		public AuthenticationEnum getAuthentication() {
470 			return authentication;
471 		}
472 
473 		public Credentials getCredentials() {
474 			return credentials;
475 		}
476 
477 		@Override
478 		public int hashCode() {
479 			return Objects.hash(authentication, credentials);
480 		}
481 
482 		@Override
483 		public boolean equals(final Object obj) {
484 			if (this == obj) {
485 				return true;
486 			}
487 			if (obj == null) {
488 				return false;
489 			}
490 			if (!(obj instanceof AuthCredentials)) {
491 				return false;
492 			}
493 			final AuthCredentials other = (AuthCredentials) obj;
494 			return authentication == other.authentication && Objects.equals(credentials, other.credentials);
495 		}
496 	}
497 
498 	static class CredentialsMapKey {
499 
500 		private final String canonizedRawUsername;
501 		private final char[] password;
502 		private final Path ticketCache;
503 		private final AuthenticationEnum authentication;
504 
505 		CredentialsMapKey(
506 			final WinRMEndpoint winRMEndpoint,
507 			final Path ticketCache,
508 			final AuthenticationEnum authentication
509 		) {
510 			this.ticketCache = ticketCache;
511 			this.authentication = authentication;
512 
513 			password = winRMEndpoint.getPassword();
514 			canonizedRawUsername =
515 				winRMEndpoint.getRawUsername() != null
516 					? winRMEndpoint.getRawUsername().replaceAll("\\s", Utils.EMPTY).toUpperCase()
517 					: null;
518 		}
519 
520 		@Override
521 		public int hashCode() {
522 			final int prime = 31;
523 			int result = 1;
524 			result = prime * result + Arrays.hashCode(password);
525 			result = prime * result + Objects.hash(authentication, canonizedRawUsername, ticketCache);
526 			return result;
527 		}
528 
529 		@Override
530 		public boolean equals(final Object obj) {
531 			if (this == obj) {
532 				return true;
533 			}
534 			if (obj == null) {
535 				return false;
536 			}
537 			if (!(obj instanceof CredentialsMapKey)) {
538 				return false;
539 			}
540 			final CredentialsMapKey other = (CredentialsMapKey) obj;
541 			return (
542 				authentication == other.authentication &&
543 				Objects.equals(canonizedRawUsername, other.canonizedRawUsername) &&
544 				Arrays.equals(password, other.password) &&
545 				Objects.equals(ticketCache, other.ticketCache)
546 			);
547 		}
548 	}
549 }