Skip to content

Commit

Permalink
KNOX-2641 Fix invalid session handle issue with Hive HA (apache#481)
Browse files Browse the repository at this point in the history
* KNOX-2641 Fix invalid session handle issue with Hive HA
  • Loading branch information
moresandeep authored Aug 12, 2021
1 parent 3a2eabc commit eecf3ca
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ protected void executeRequestWrapper(HttpUriRequest outboundRequest,
}

@Override
protected void outboundResponseWrapper(final HttpServletRequest inboundRequest, HttpServletResponse outboundResponse) {
setKnoxHaCookie(inboundRequest, outboundResponse);
protected void outboundResponseWrapper(final HttpUriRequest outboundRequest, final HttpServletRequest inboundRequest, final HttpServletResponse outboundResponse) {
setKnoxHaCookie(outboundRequest, inboundRequest, outboundResponse);
}

@Override
Expand Down Expand Up @@ -235,8 +235,8 @@ private Optional<URI> setBackendfromHaCookie(HttpUriRequest outboundRequest, Htt
return Optional.empty();
}

private void setKnoxHaCookie(HttpServletRequest inboundRequest,
HttpServletResponse outboundResponse) {
private void setKnoxHaCookie(final HttpUriRequest outboundRequest, final HttpServletRequest inboundRequest,
final HttpServletResponse outboundResponse) {
if (stickySessionsEnabled) {
List<Cookie> serviceHaCookies = Collections.emptyList();
if(inboundRequest.getCookies() != null) {
Expand All @@ -250,8 +250,21 @@ private void setKnoxHaCookie(HttpServletRequest inboundRequest,
&& hashToUrlLookup.containsKey(serviceHaCookies.get(0).getValue())) {
return;
} else {
String url = haProvider.getActiveURL(getServiceRole());
String cookieValue = urlToHashLookup.get(url);

/**
* Due to concurrency issues haProvider.getActiveURL() will not return the accurate list
* This will cause issues where original request goes to host-1 and cookie is set for host-2 - because
* haProvider.getActiveURL() returned host-2. To prevent this from happening we need to make sure
* we set cookie for the endpoint that was served and not rely on haProvider.getActiveURL().
* let LBing logic take care of rotating urls.
**/
final List<String> urls = haProvider.getURLs(getServiceRole())
.stream()
.filter(u -> u.contains(outboundRequest.getURI().getHost()))
.collect(Collectors.toList());

final String cookieValue = urlToHashLookup.get(urls.get(0));

Cookie stickySessionCookie = new Cookie(stickySessionCookieName, cookieValue);
stickySessionCookie.setPath(inboundRequest.getContextPath());
stickySessionCookie.setMaxAge(-1);
Expand Down Expand Up @@ -371,5 +384,4 @@ private URI updateHostURL(final URI source, final String host) throws URISyntaxE
uriBuilder.setPort(newUri.getPort());
return uriBuilder.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,7 @@ public interface HaDispatchMessages {

@Message(level = MessageLevel.ERROR, text = "Error setting non-loadbalanced url to outbound request")
void errorSettingActiveUrl();

@Message(level = MessageLevel.ERROR, text = "Unsupported encoding, cause: {0}")
void unsupportedEncodingException(String cause);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
*/
package org.apache.knox.gateway.ha.provider.impl;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.apache.knox.gateway.ha.provider.HaDescriptor;
import org.apache.knox.gateway.ha.provider.HaProvider;
import org.apache.knox.gateway.ha.provider.HaServiceConfig;
Expand All @@ -25,10 +30,6 @@
import org.apache.knox.gateway.ha.provider.impl.i18n.HaMessages;
import org.apache.knox.gateway.i18n.messages.MessagesFactory;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

public class DefaultHaProvider implements HaProvider {

private static final HaMessages LOG = MessagesFactory.get(HaMessages.class);
Expand All @@ -37,6 +38,8 @@ public class DefaultHaProvider implements HaProvider {

private ConcurrentHashMap<String, URLManager> haServices;

private ReentrantReadWriteLock rwl = new ReentrantReadWriteLock(true);

public DefaultHaProvider(HaDescriptor descriptor) {
if ( descriptor == null ) {
throw new IllegalArgumentException("Descriptor can not be null");
Expand Down Expand Up @@ -66,38 +69,59 @@ public boolean isHaEnabled(String serviceName) {

@Override
public String getActiveURL(String serviceName) {
if ( haServices.containsKey(serviceName) ) {
return haServices.get(serviceName).getActiveURL();
rwl.readLock().lock();
try {
if (haServices.containsKey(serviceName)) {
return haServices.get(serviceName).getActiveURL();
}
LOG.noActiveUrlFound(serviceName);
return null;
} finally {
rwl.readLock().unlock();
}
LOG.noActiveUrlFound(serviceName);
return null;
}

@Override
public void setActiveURL(String serviceName, String url) {
if ( haServices.containsKey(serviceName) ) {
haServices.get(serviceName).setActiveURL(url);
} else {
LOG.noServiceFound(serviceName);
rwl.writeLock().lock();
try {
if (haServices.containsKey(serviceName)) {
haServices.get(serviceName).setActiveURL(url);
} else {
LOG.noServiceFound(serviceName);
}
}
finally {
rwl.writeLock().unlock();
}

}

@Override
public void markFailedURL(String serviceName, String url) {
if ( haServices.containsKey(serviceName) ) {
haServices.get(serviceName).markFailed(url);
} else {
LOG.noServiceFound(serviceName);
rwl.writeLock().lock();
try {
if (haServices.containsKey(serviceName)) {
haServices.get(serviceName).markFailed(url);
} else {
LOG.noServiceFound(serviceName);
}
} finally {
rwl.writeLock().unlock();
}
}

@Override
public void makeNextActiveURLAvailable(String serviceName) {
if ( haServices.containsKey(serviceName) ) {
haServices.get(serviceName).makeNextActiveURLAvailable();
} else {
LOG.noServiceFound(serviceName);
rwl.writeLock().lock();
try {
if (haServices.containsKey(serviceName)) {
haServices.get(serviceName).makeNextActiveURLAvailable();
} else {
LOG.noServiceFound(serviceName);
}
} finally {
rwl.writeLock().unlock();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.knox.gateway.ha.dispatch;

import static org.easymock.EasyMock.capture;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;

import javax.servlet.FilterConfig;
import javax.servlet.ServletContext;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.codec.digest.DigestUtils;
import org.apache.http.Header;
import org.apache.http.HeaderElement;
import org.apache.http.HttpEntity;
import org.apache.http.HttpStatus;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.params.BasicHttpParams;
import org.apache.knox.gateway.config.GatewayConfig;
import org.apache.knox.gateway.ha.provider.HaDescriptor;
import org.apache.knox.gateway.ha.provider.HaProvider;
import org.apache.knox.gateway.ha.provider.HaServletContextListener;
import org.apache.knox.gateway.ha.provider.impl.DefaultHaProvider;
import org.apache.knox.gateway.ha.provider.impl.HaDescriptorFactory;
import org.apache.knox.gateway.servlet.SynchronousServletOutputStreamAdapter;
import org.easymock.Capture;
import org.easymock.EasyMock;
import org.easymock.IAnswer;
import org.junit.Assert;
import org.junit.Test;

public class ConfigurableHADispatchTest {

/**
* Test whether the dispatch url is correctly used in case where loadbalancing is enabled
* and sticky session is enabled making sure we dispatch requests based on the HA Provider logic and
* not based on URL rewrite logic.
*
* @throws Exception
*/
@Test
public void testHADispatchURL() throws Exception {
String serviceName = "HIVE";
HaDescriptor descriptor = HaDescriptorFactory.createDescriptor();
descriptor.addServiceConfig(HaDescriptorFactory.createServiceConfig(serviceName, "true", "1", "1000", null, null, "true", "true", null, null));
HaProvider provider = new DefaultHaProvider(descriptor);
URI uri1 = new URI("http://host1.valid");
URI uri2 = new URI("http://host2.valid");
URI uri3 = new URI("http://host3.valid");
ArrayList<String> urlList = new ArrayList<>();
urlList.add(uri1.toString());
urlList.add(uri2.toString());
urlList.add(uri3.toString());
provider.addHaService(serviceName, urlList);


HttpServletRequest inboundRequest = EasyMock.createNiceMock(HttpServletRequest.class);
EasyMock.expect(inboundRequest.getRequestURL()).andReturn(new StringBuffer(provider.getActiveURL(serviceName))).anyTimes();
EasyMock.replay(inboundRequest);

ConfigurableHADispatch dispatch = new ConfigurableHADispatch();
dispatch.setHaProvider(provider);
dispatch.setServiceRole(serviceName);
dispatch.init();

/* make sure the dispatch URL is always active URL */
Assert.assertEquals(provider.getActiveURL(serviceName), dispatch.getDispatchUrl(inboundRequest).toString());
}

/**
* This tests ensure that in case where HA is configured.
* the host the the request is dispatched is the same host for
* which HA cookie is set.
*
* @throws Exception
*/
@Test
public void testSetCookieHeader() throws Exception {
String serviceName = "HIVE";
HaDescriptor descriptor = HaDescriptorFactory.createDescriptor();
descriptor.addServiceConfig(HaDescriptorFactory.createServiceConfig(serviceName, "true", "1", "1000", null, null, "true", "true", null, null));
HaProvider provider = new DefaultHaProvider(descriptor);
URI uri1 = new URI( "http://host1.valid" );
URI uri2 = new URI( "http://host2.valid" );
ArrayList<String> urlList = new ArrayList<>();
urlList.add(uri1.toString());
urlList.add(uri2.toString());
provider.addHaService(serviceName, urlList);
FilterConfig filterConfig = EasyMock.createNiceMock(FilterConfig.class);
ServletContext servletContext = EasyMock.createNiceMock(ServletContext.class);

EasyMock.expect(filterConfig.getServletContext()).andReturn(servletContext).anyTimes();
EasyMock.expect(servletContext.getAttribute(HaServletContextListener.PROVIDER_ATTRIBUTE_NAME)).andReturn(provider).anyTimes();

BasicHttpParams params = new BasicHttpParams();

HttpUriRequest outboundRequest = EasyMock.createNiceMock(HttpRequestBase.class);
EasyMock.expect(outboundRequest.getMethod()).andReturn( "GET" ).anyTimes();
EasyMock.expect(outboundRequest.getURI()).andReturn( uri1 ).anyTimes();
EasyMock.expect(outboundRequest.getParams()).andReturn( params ).anyTimes();

/* dispatched url is the active HA url */
String activeURL = provider.getActiveURL(serviceName);

/* backend request */
HttpServletRequest inboundRequest = EasyMock.createNiceMock(HttpServletRequest.class);
EasyMock.expect(inboundRequest.getRequestURL()).andReturn( new StringBuffer(activeURL)).once();
EasyMock.expect(inboundRequest.getAttribute("dispatch.ha.failover.counter")).andReturn(new AtomicInteger(0)).once();
EasyMock.expect(inboundRequest.getAttribute("dispatch.ha.failover.counter")).andReturn(new AtomicInteger(1)).once();

/* backend response */
CloseableHttpResponse inboundResponse = EasyMock.createNiceMock(CloseableHttpResponse.class);
final StatusLine statusLine = EasyMock.createNiceMock(StatusLine.class);
final HttpEntity entity = EasyMock.createNiceMock(HttpEntity.class);
final Header header = EasyMock.createNiceMock(Header.class);
final ServletContext context = EasyMock.createNiceMock(ServletContext.class);
final GatewayConfig config = EasyMock.createNiceMock(GatewayConfig.class);
final ByteArrayInputStream backendResponse = new ByteArrayInputStream("knox-backend".getBytes(
StandardCharsets.UTF_8));


EasyMock.expect(inboundResponse.getStatusLine()).andReturn(statusLine).anyTimes();
EasyMock.expect(statusLine.getStatusCode()).andReturn(HttpStatus.SC_OK).anyTimes();
EasyMock.expect(inboundResponse.getEntity()).andReturn(entity).anyTimes();
EasyMock.expect(inboundResponse.getAllHeaders()).andReturn(new Header[0]).anyTimes();
EasyMock.expect(inboundRequest.getServletContext()).andReturn(context).anyTimes();
EasyMock.expect(entity.getContent()).andReturn(backendResponse).anyTimes();
EasyMock.expect(entity.getContentType()).andReturn(header).anyTimes();
EasyMock.expect(header.getElements()).andReturn(new HeaderElement[]{}).anyTimes();
EasyMock.expect(entity.getContentLength()).andReturn(4L).anyTimes();
EasyMock.expect(context.getAttribute(GatewayConfig.GATEWAY_CONFIG_ATTRIBUTE)).andReturn(config).anyTimes();

Capture<Cookie> captureCookieValue = EasyMock.newCapture();
HttpServletResponse outboundResponse = EasyMock.createNiceMock(HttpServletResponse.class);
EasyMock.expect(outboundResponse.getOutputStream()).andAnswer( new IAnswer<SynchronousServletOutputStreamAdapter>() {
@Override
public SynchronousServletOutputStreamAdapter answer() {
return new SynchronousServletOutputStreamAdapter() {
@Override
public void write( int b ) throws IOException {
/* do nothing */
}
};
}
}).once();

outboundResponse.addCookie(capture(captureCookieValue));

CloseableHttpClient mockHttpClient = EasyMock.createNiceMock(CloseableHttpClient.class);
EasyMock.expect(mockHttpClient.execute(outboundRequest)).andReturn(inboundResponse).anyTimes();

EasyMock.replay(filterConfig, servletContext, outboundRequest, inboundRequest,
outboundResponse, mockHttpClient, inboundResponse,
statusLine, entity, header, context, config);


Assert.assertEquals(uri1.toString(), provider.getActiveURL(serviceName));
ConfigurableHADispatch dispatch = new ConfigurableHADispatch();
dispatch.setHttpClient(mockHttpClient);
dispatch.setHaProvider(provider);
dispatch.setServiceRole(serviceName);
dispatch.init();
try {
dispatch.executeRequestWrapper(outboundRequest, inboundRequest, outboundResponse);
} catch (IOException e) {
//this is expected after the failover limit is reached
}
/* make sure the url is ladbalanced */
Assert.assertEquals(uri2.toString(), provider.getActiveURL(serviceName));
/* make sure the HA backend URL hash in set-cookie is for active URL (which was in the dispatch request) */
Assert.assertEquals(DigestUtils.sha256Hex(activeURL), captureCookieValue.getValue().getValue());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ protected void executeRequestWrapper(HttpUriRequest outboundRequest,
* to modify any outgoing
* response i.e. cookies
*/
protected void outboundResponseWrapper(final HttpServletRequest inboundRequest, HttpServletResponse outboundResponse) {
protected void outboundResponseWrapper(final HttpUriRequest outboundRequest, final HttpServletRequest inboundRequest, HttpServletResponse outboundResponse) {
/* no-op */
}

Expand Down Expand Up @@ -188,7 +188,7 @@ protected HttpResponse executeOutboundRequest( HttpUriRequest outboundRequest )

protected void writeOutboundResponse(HttpUriRequest outboundRequest, HttpServletRequest inboundRequest, HttpServletResponse outboundResponse, HttpResponse inboundResponse) throws IOException {
/* in case any changes to outbound response are needed */
outboundResponseWrapper(inboundRequest, outboundResponse);
outboundResponseWrapper(outboundRequest, inboundRequest, outboundResponse);
// Copy the client respond header to the server respond.
outboundResponse.setStatus(inboundResponse.getStatusLine().getStatusCode());
copyResponseHeaderFields(outboundResponse, inboundResponse);
Expand Down

0 comments on commit eecf3ca

Please sign in to comment.