View Javadoc
1   package org.metricshub.winrm.service.client;
2   
3   import jakarta.xml.ws.BindingProvider;
4   import jakarta.xml.ws.WebServiceException;
5   import jakarta.xml.ws.handler.Handler;
6   import jakarta.xml.ws.soap.SOAPFaultException;
7   /*-
8    * ╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲
9    * WinRM Java Client
10   * ჻჻჻჻჻჻
11   * Copyright 2023 - 2024 Metricshub
12   * ჻჻჻჻჻჻
13   * Licensed under the Apache License, Version 2.0 (the "License");
14   * you may not use this file except in compliance with the License.
15   * You may obtain a copy of the License at
16   *
17   *      http://www.apache.org/licenses/LICENSE-2.0
18   *
19   * Unless required by applicable law or agreed to in writing, software
20   * distributed under the License is distributed on an "AS IS" BASIS,
21   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22   * See the License for the specific language governing permissions and
23   * limitations under the License.
24   * ╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱╲╱
25   */
26  
27  import java.io.IOException;
28  import java.lang.reflect.InvocationHandler;
29  import java.lang.reflect.InvocationTargetException;
30  import java.lang.reflect.Method;
31  import java.net.URL;
32  import java.nio.file.Path;
33  import java.util.Arrays;
34  import java.util.Collections;
35  import java.util.HashMap;
36  import java.util.LinkedList;
37  import java.util.List;
38  import java.util.Map;
39  import java.util.Objects;
40  import java.util.Queue;
41  import java.util.concurrent.ConcurrentHashMap;
42  import java.util.stream.Collectors;
43  import java.util.stream.Stream;
44  import javax.net.ssl.TrustManager;
45  import javax.xml.namespace.QName;
46  import org.apache.cxf.Bus;
47  import org.apache.cxf.binding.soap.SoapBindingConstants;
48  import org.apache.cxf.configuration.jsse.TLSClientParameters;
49  import org.apache.cxf.endpoint.Client;
50  import org.apache.cxf.frontend.ClientProxy;
51  import org.apache.cxf.jaxws.JaxWsProxyFactoryBean;
52  import org.apache.cxf.message.Message;
53  import org.apache.cxf.service.model.ServiceInfo;
54  import org.apache.cxf.transport.http.HTTPConduitFactory;
55  import org.apache.cxf.transport.http.asyncclient.AsyncHTTPConduit;
56  import org.apache.cxf.transports.http.configuration.HTTPClientPolicy;
57  import org.apache.cxf.ws.addressing.WSAddressingFeature;
58  import org.apache.cxf.ws.addressing.WSAddressingFeature.AddressingResponses;
59  import org.apache.cxf.ws.addressing.policy.MetadataConstants;
60  import org.apache.cxf.ws.policy.PolicyConstants;
61  import org.apache.http.auth.AuthSchemeProvider;
62  import org.apache.http.auth.Credentials;
63  import org.apache.http.auth.NTCredentials;
64  import org.apache.http.client.config.AuthSchemes;
65  import org.apache.http.config.Registry;
66  import org.apache.http.config.RegistryBuilder;
67  import org.apache.http.impl.auth.KerberosSchemeFactory;
68  import org.apache.neethi.Policy;
69  import org.apache.neethi.builders.PrimitiveAssertion;
70  import org.metricshub.winrm.Utils;
71  import org.metricshub.winrm.WinRMHttpProtocolEnum;
72  import org.metricshub.winrm.service.WinRMEndpoint;
73  import org.metricshub.winrm.service.WinRMWebService;
74  import org.metricshub.winrm.service.WinRMWebServiceClient;
75  import org.metricshub.winrm.service.client.auth.AuthenticationEnum;
76  import org.metricshub.winrm.service.client.auth.TrustAllX509Manager;
77  import org.metricshub.winrm.service.client.auth.kerberos.KerberosUtils;
78  import org.metricshub.winrm.service.client.auth.ntlm.NTCredentialsWithEncryption;
79  import org.metricshub.winrm.service.client.auth.ntlm.NtlmMasqAsSpnegoSchemeFactory;
80  import org.metricshub.winrm.service.client.encryption.AsyncHttpEncryptionAwareConduitFactory;
81  import org.metricshub.winrm.service.client.encryption.DecryptAndVerifyInInterceptor;
82  import org.metricshub.winrm.service.client.encryption.SignAndEncryptOutInterceptor;
83  
84  public class WinRMInvocationHandler implements InvocationHandler {
85  
86  	public static final String WSMAN_SCHEMA_NAMESPACE = "http://schemas.dmtf.org/wbem/wsman/1/wsman.xsd";
87  
88  	private static final long PAUSE_TIME_MILLISECONDS = 500;
89  	private static final int MAX_RETRY = 3;
90  
91  	private static final URL WSDL_LOCATION_URL =
92  		WinRMWebServiceClient.class.getClassLoader().getResource("wsdl/WinRM.wsdl");
93  
94  	private static final QName SERVICE = new QName(WSMAN_SCHEMA_NAMESPACE, "WinRMWebServiceClient");
95  
96  	private static final QName PORT = new QName(WSMAN_SCHEMA_NAMESPACE, "WinRMPort");
97  
98  	private static final List<String> CONTENT_TYPE_LIST = Collections.singletonList("application/soap+xml;charset=UTF-8");
99  
100 	@SuppressWarnings("rawtypes")
101 	private static final List<Handler> HANDLER_CHAIN = Arrays.asList(new StripShellResponseHandler());
102 
103 	private static final Registry<AuthSchemeProvider> AUTH_SCHEME_REGISTRY = RegistryBuilder
104 		.<AuthSchemeProvider>create()
105 		.register(AuthSchemes.SPNEGO, new NtlmMasqAsSpnegoSchemeFactory())
106 		.register(AuthSchemes.KERBEROS, new KerberosSchemeFactory(true))
107 		.build();
108 
109 	private static final Policy POLICY;
110 
111 	static {
112 		POLICY = new Policy();
113 		POLICY.addAssertion(new PrimitiveAssertion(MetadataConstants.USING_ADDRESSING_2004_QNAME));
114 	}
115 
116 	private static final WSAddressingFeature WS_ADDRESSING_FEATURE;
117 
118 	static {
119 		WS_ADDRESSING_FEATURE = new WSAddressingFeature();
120 		WS_ADDRESSING_FEATURE.setResponses(AddressingResponses.ANONYMOUS);
121 	}
122 
123 	private static final TLSClientParameters TLS_CLIENT_PARAMETERS;
124 
125 	static {
126 		TLS_CLIENT_PARAMETERS = new TLSClientParameters();
127 		TLS_CLIENT_PARAMETERS.setDisableCNCheck(true);
128 		// Accept all certificates
129 		TLS_CLIENT_PARAMETERS.setTrustManagers(new TrustManager[] { new TrustAllX509Manager() });
130 	}
131 
132 	private static final Map<CredentialsMapKey, Credentials> CREDENTIALS = new ConcurrentHashMap<>();
133 
134 	private final WinRMWebService winRMWebService;
135 	private final WinRMEndpoint winRMEndpoint;
136 	private final long timeout;
137 	private final String resourceUri;
138 	private final Path ticketCache;
139 	private final Queue<AuthenticationEnum> authenticationsQueue;
140 	private AuthenticationEnum authentication;
141 	private Client wsClient;
142 
143 	/**
144 	 * WinRMInvocationHandler constructor
145 	 *
146 	 * @param winRMEndpoint Endpoint with credentials (mandatory)
147 	 * @param bus Apache CXF Bus (mandatory)
148 	 * @param timeout Timeout used for Connection, Connection Request and Receive Request in milliseconds
149 	 * @param resourceUri The enumerate resource URI
150 	 * @param ticketCache The Ticket Cache path
151 	 * @param authentications List of authentications. (mandatory)
152 	 */
153 	public WinRMInvocationHandler(
154 		final WinRMEndpoint winRMEndpoint,
155 		final Bus bus,
156 		final long timeout,
157 		final String resourceUri,
158 		final Path ticketCache,
159 		final List<AuthenticationEnum> authentications
160 	) {
161 		Utils.checkNonNull(winRMEndpoint, "winRMEndpoint");
162 		Utils.checkNonNull(bus, "bus");
163 		Utils.checkNonNull(authentications, "authentications");
164 
165 		this.winRMEndpoint = winRMEndpoint;
166 		this.timeout = timeout;
167 		this.resourceUri = resourceUri;
168 		this.ticketCache = ticketCache;
169 		authenticationsQueue = authentications.stream().collect(Collectors.toCollection(LinkedList::new));
170 
171 		winRMWebService = createWinRMWebService(winRMEndpoint, bus);
172 
173 		final AuthCredentials authCredentials = computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
174 
175 		authentication = authCredentials.getAuthentication();
176 
177 		wsClient =
178 			getWebServiceClient(winRMEndpoint, timeout, resourceUri, winRMWebService, authCredentials.getCredentials());
179 	}
180 
181 	public Client getClient() {
182 		return wsClient;
183 	}
184 
185 	@Override
186 	public Object invoke(final Object proxy, final Method method, final Object[] args) throws Throwable {
187 		Utils.checkNonNull(method, "method");
188 
189 		try {
190 			return invokeMethod(method, args);
191 		} catch (final RetryTgtExpirationException e) {
192 			// retry with a new TGT in case of current TGT expiration
193 			authentication = null;
194 
195 			Credentials credentials;
196 			try {
197 				credentials =
198 					KerberosUtils.createCredentials(winRMEndpoint.getUsername(), winRMEndpoint.getPassword(), ticketCache);
199 
200 				CREDENTIALS.put(new CredentialsMapKey(winRMEndpoint, ticketCache, AuthenticationEnum.KERBEROS), credentials);
201 				// Normally that should not happen as any other exception on KERBEROs should had been throw
202 				// at the first KERBEROS call
203 			} catch (final Exception e1) {
204 				if (continueToRetry()) {
205 					final AuthCredentials authCredentials = computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
206 
207 					authentication = authCredentials.getAuthentication();
208 					credentials = authCredentials.getCredentials();
209 				} else {
210 					throw e1;
211 				}
212 			}
213 
214 			wsClient = getWebServiceClient(winRMEndpoint, timeout, resourceUri, winRMWebService, credentials);
215 
216 			return invoke(proxy, method, args);
217 		} catch (final RetryAuthenticationException e) {
218 			if (continueToRetry()) {
219 				final AuthCredentials authCredentials = computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
220 
221 				authentication = authCredentials.getAuthentication();
222 
223 				wsClient =
224 					getWebServiceClient(winRMEndpoint, timeout, resourceUri, winRMWebService, authCredentials.getCredentials());
225 
226 				return invoke(proxy, method, args);
227 			}
228 
229 			// No more retries
230 			final Throwable cause = e.getCause();
231 			if (cause instanceof SOAPFaultException) {
232 				throw new RuntimeException("KERBEROS with encryption over HTTP is not implemented.", cause);
233 			}
234 			throw cause;
235 		}
236 	}
237 
238 	// this function is only needed for the unit testing
239 	boolean continueToRetry() {
240 		return !authenticationsQueue.isEmpty();
241 	}
242 
243 	Object invokeMethod(final Method method, final Object[] args)
244 		throws IllegalAccessException, RetryAuthenticationException {
245 		Throwable firstEx = null;
246 		int retry = 0;
247 
248 		while (retry < MAX_RETRY) {
249 			retry++;
250 
251 			try {
252 				return method.invoke(winRMWebService, args);
253 			} catch (final InvocationTargetException ite) {
254 				final Throwable targetEx = ite.getTargetException();
255 
256 				if (targetEx instanceof SOAPFaultException) {
257 					// Could retry with a different authentication than NTLM
258 					// because it could be a "WstxEOFException: Unexpected EOF in prolog"
259 					// due to a KERBEROS with HTTP and AllowUnencrypted=false
260 					if (winRMEndpoint.getProtocol() == WinRMHttpProtocolEnum.HTTP && authentication != AuthenticationEnum.NTLM) {
261 						throw new RetryAuthenticationException(targetEx);
262 					}
263 					throw (SOAPFaultException) targetEx;
264 				}
265 
266 				if (!(targetEx instanceof WebServiceException)) {
267 					throw new IllegalStateException("Failure when calling " + createCallInfos(method, args), targetEx);
268 				}
269 
270 				final WebServiceException wsEx = (WebServiceException) targetEx;
271 
272 				if (!(wsEx.getCause() instanceof IOException)) {
273 					throw new RuntimeException(
274 						"Exception occurred while making WinRM WebService call " + createCallInfos(method, args),
275 						wsEx
276 					);
277 				}
278 
279 				if (
280 					wsEx.getCause().getMessage() != null &&
281 					wsEx.getCause().getMessage().startsWith("Authorization loop detected on Conduit")
282 				) {
283 					final RuntimeException authEx = new RuntimeException(
284 						String.format(
285 							"Authentication error on %s with user name \"%s\"",
286 							winRMEndpoint.getEndpoint(),
287 							winRMEndpoint.getRawUsername()
288 						)
289 					);
290 
291 					// Could be due to a TGT expiration
292 					if (authentication == AuthenticationEnum.KERBEROS) {
293 						throw new RetryTgtExpirationException(authEx);
294 					}
295 					// Could retry with a different authentication
296 					throw new RetryAuthenticationException(authEx);
297 				}
298 
299 				if (firstEx == null) {
300 					firstEx = wsEx;
301 				}
302 
303 				if (retry < MAX_RETRY) {
304 					try {
305 						Utils.sleep(PAUSE_TIME_MILLISECONDS);
306 					} catch (final InterruptedException ie) {
307 						Thread.currentThread().interrupt();
308 						throw new RuntimeException(
309 							"Exception occured while making WinRM WebService call " + createCallInfos(method, args),
310 							ie
311 						);
312 					}
313 				}
314 			}
315 		}
316 
317 		throw new RuntimeException(
318 			String.format("failed task \"%s\" after %d attempts", createCallInfos(method, args), MAX_RETRY),
319 			firstEx
320 		);
321 	}
322 
323 	static String createCallInfos(final Method method, final Object[] args) {
324 		final String name = method != null && method.getName() != null ? method.getName() : Utils.EMPTY;
325 		return args == null
326 			? name
327 			: Stream
328 				.concat(Stream.of(name), Stream.of(args))
329 				.filter(Objects::nonNull)
330 				.map(Object::toString)
331 				.collect(Collectors.joining(" "));
332 	}
333 
334 	static Credentials createCredentials(
335 		final WinRMEndpoint winRMEndpoint,
336 		final AuthenticationEnum authentication,
337 		final Path ticketCache
338 	) {
339 		switch (authentication) {
340 			case KERBEROS:
341 				return KerberosUtils.createCredentials(winRMEndpoint.getUsername(), winRMEndpoint.getPassword(), ticketCache);
342 			case NTLM:
343 			default:
344 				final String password = String.valueOf(winRMEndpoint.getPassword());
345 				return winRMEndpoint.getProtocol() == WinRMHttpProtocolEnum.HTTP
346 					? new NTCredentialsWithEncryption(winRMEndpoint.getUsername(), password, null, winRMEndpoint.getDomain())
347 					: new NTCredentials(winRMEndpoint.getUsername(), password, null, winRMEndpoint.getDomain());
348 		}
349 	}
350 
351 	static AuthCredentials computeCredentials(
352 		final WinRMEndpoint winRMEndpoint,
353 		final Path ticketCache,
354 		final Queue<AuthenticationEnum> authenticationsQueue
355 	) {
356 		try {
357 			final AuthenticationEnum authenticationEnum = authenticationsQueue.remove();
358 
359 			final Credentials credentials = CREDENTIALS.compute(
360 				new CredentialsMapKey(winRMEndpoint, ticketCache, authenticationEnum),
361 				(user, cred) -> cred != null ? cred : createCredentials(winRMEndpoint, authenticationEnum, ticketCache)
362 			);
363 
364 			return new AuthCredentials(authenticationEnum, credentials);
365 		} catch (final Exception e) {
366 			// if there's still retry
367 			if (!authenticationsQueue.isEmpty()) {
368 				return computeCredentials(winRMEndpoint, ticketCache, authenticationsQueue);
369 			}
370 			throw e;
371 		}
372 	}
373 
374 	static WinRMWebService createWinRMWebService(final WinRMEndpoint winRMEndpoint, final Bus bus) {
375 		final JaxWsProxyFactoryBean jaxWsProxyFactoryBean = new JaxWsProxyFactoryBean();
376 		jaxWsProxyFactoryBean.setServiceName(SERVICE);
377 		jaxWsProxyFactoryBean.setEndpointName(PORT);
378 		jaxWsProxyFactoryBean.setBus(bus);
379 		jaxWsProxyFactoryBean.setServiceClass(WinRMWebService.class);
380 		jaxWsProxyFactoryBean.setAddress(winRMEndpoint.getEndpoint());
381 		jaxWsProxyFactoryBean.getFeatures().add(WS_ADDRESSING_FEATURE);
382 		jaxWsProxyFactoryBean.setBindingId(SoapBindingConstants.SOAP12_BINDING_ID);
383 		jaxWsProxyFactoryBean.getClientFactoryBean().getServiceFactory().setWsdlURL(WSDL_LOCATION_URL);
384 
385 		return jaxWsProxyFactoryBean.create(WinRMWebService.class);
386 	}
387 
388 	static Client getWebServiceClient(
389 		final WinRMEndpoint winRMEndpoint,
390 		final long timeout,
391 		final String enumerateResourceUri,
392 		final WinRMWebService winRMWebService,
393 		final Credentials credentials
394 	) {
395 		final Client client = ClientProxy.getClient(winRMWebService);
396 
397 		if (enumerateResourceUri != null) {
398 			final WSManHeaderInterceptor interceptor = new WSManHeaderInterceptor(enumerateResourceUri);
399 			client.getOutInterceptors().add(interceptor);
400 		}
401 
402 		client.getInInterceptors().add(new DecryptAndVerifyInInterceptor());
403 		client.getOutInterceptors().add(new SignAndEncryptOutInterceptor());
404 
405 		// this is different to endpoint properties
406 		client
407 			.getEndpoint()
408 			.getEndpointInfo()
409 			.setProperty(HTTPConduitFactory.class.getName(), new AsyncHttpEncryptionAwareConduitFactory());
410 
411 		final ServiceInfo serviceInfo = client.getEndpoint().getEndpointInfo().getService();
412 		serviceInfo.setProperty("soap.force.doclit.bare", true);
413 
414 		final BindingProvider bindingProvider = (BindingProvider) winRMWebService;
415 		bindingProvider.getBinding().setHandlerChain(HANDLER_CHAIN);
416 		bindingProvider.getRequestContext().put(PolicyConstants.POLICY_OVERRIDE, POLICY);
417 		bindingProvider.getRequestContext().put("http.autoredirect", true);
418 
419 		bindingProvider.getRequestContext().put(BindingProvider.ENDPOINT_ADDRESS_PROPERTY, winRMEndpoint.getEndpoint());
420 
421 		final Map<String, List<String>> headers = new HashMap<>();
422 		headers.put("Content-Type", CONTENT_TYPE_LIST);
423 
424 		bindingProvider.getRequestContext().put(Message.PROTOCOL_HEADERS, headers);
425 
426 		// Setup timeouts
427 		final HTTPClientPolicy httpClientPolicy = new HTTPClientPolicy();
428 		httpClientPolicy.setConnectionTimeout(timeout);
429 		httpClientPolicy.setConnectionRequestTimeout(timeout);
430 		httpClientPolicy.setReceiveTimeout(timeout);
431 		httpClientPolicy.setAllowChunking(false);
432 
433 		bindingProvider.getRequestContext().put(Credentials.class.getName(), credentials);
434 		bindingProvider.getRequestContext().put(AuthSchemeProvider.class.getName(), AUTH_SCHEME_REGISTRY);
435 
436 		final AsyncHTTPConduit asyncHTTPConduit = (AsyncHTTPConduit) client.getConduit();
437 		asyncHTTPConduit.setClient(httpClientPolicy);
438 		asyncHTTPConduit.getClient().setAutoRedirect(true);
439 		asyncHTTPConduit.setTlsClientParameters(TLS_CLIENT_PARAMETERS);
440 
441 		return client;
442 	}
443 
444 	static class RetryAuthenticationException extends Exception {
445 
446 		private static final long serialVersionUID = 1L;
447 
448 		RetryAuthenticationException(final Throwable throwable) {
449 			super(throwable);
450 		}
451 	}
452 
453 	static class RetryTgtExpirationException extends RetryAuthenticationException {
454 
455 		private static final long serialVersionUID = 1L;
456 
457 		RetryTgtExpirationException(final Throwable throwable) {
458 			super(throwable);
459 		}
460 	}
461 
462 	static class AuthCredentials {
463 
464 		private final AuthenticationEnum authentication;
465 		private final Credentials credentials;
466 
467 		AuthCredentials(final AuthenticationEnum authentication, final Credentials credentials) {
468 			this.authentication = authentication;
469 			this.credentials = credentials;
470 		}
471 
472 		public AuthenticationEnum getAuthentication() {
473 			return authentication;
474 		}
475 
476 		public Credentials getCredentials() {
477 			return credentials;
478 		}
479 
480 		@Override
481 		public int hashCode() {
482 			return Objects.hash(authentication, credentials);
483 		}
484 
485 		@Override
486 		public boolean equals(final Object obj) {
487 			if (this == obj) {
488 				return true;
489 			}
490 			if (obj == null) {
491 				return false;
492 			}
493 			if (!(obj instanceof AuthCredentials)) {
494 				return false;
495 			}
496 			final AuthCredentials other = (AuthCredentials) obj;
497 			return authentication == other.authentication && Objects.equals(credentials, other.credentials);
498 		}
499 	}
500 
501 	static class CredentialsMapKey {
502 
503 		private final String canonizedRawUsername;
504 		private final char[] password;
505 		private final Path ticketCache;
506 		private final AuthenticationEnum authentication;
507 
508 		CredentialsMapKey(
509 			final WinRMEndpoint winRMEndpoint,
510 			final Path ticketCache,
511 			final AuthenticationEnum authentication
512 		) {
513 			this.ticketCache = ticketCache;
514 			this.authentication = authentication;
515 
516 			password = winRMEndpoint.getPassword();
517 			canonizedRawUsername =
518 				winRMEndpoint.getRawUsername() != null
519 					? winRMEndpoint.getRawUsername().replaceAll("\\s", Utils.EMPTY).toUpperCase()
520 					: null;
521 		}
522 
523 		@Override
524 		public int hashCode() {
525 			final int prime = 31;
526 			int result = 1;
527 			result = prime * result + Arrays.hashCode(password);
528 			result = prime * result + Objects.hash(authentication, canonizedRawUsername, ticketCache);
529 			return result;
530 		}
531 
532 		@Override
533 		public boolean equals(final Object obj) {
534 			if (this == obj) {
535 				return true;
536 			}
537 			if (obj == null) {
538 				return false;
539 			}
540 			if (!(obj instanceof CredentialsMapKey)) {
541 				return false;
542 			}
543 			final CredentialsMapKey other = (CredentialsMapKey) obj;
544 			return (
545 				authentication == other.authentication &&
546 				Objects.equals(canonizedRawUsername, other.canonizedRawUsername) &&
547 				Arrays.equals(password, other.password) &&
548 				Objects.equals(ticketCache, other.ticketCache)
549 			);
550 		}
551 	}
552 }