package net.unit8.jmeter.protocol.websocket.sampler; import org.apache.commons.lang3.StringUtils; import org.apache.jmeter.config.Argument; import org.apache.jmeter.config.Arguments; import org.apache.jmeter.config.ConfigTestElement; import org.apache.jmeter.protocol.http.util.EncoderCache; import org.apache.jmeter.protocol.http.util.HTTPArgument; import org.apache.jmeter.protocol.http.util.HTTPConstants; import org.apache.jmeter.samplers.AbstractSampler; import org.apache.jmeter.samplers.Entry; import org.apache.jmeter.samplers.SampleResult; import org.apache.jmeter.testelement.TestElement; import org.apache.jmeter.testelement.TestStateListener; import org.apache.jmeter.testelement.property.*; import org.apache.jmeter.threads.JMeterContextService; import org.apache.jorphan.logging.LoggingManager; import org.apache.jorphan.util.JOrphanUtils; import org.apache.log.Logger; import org.eclipse.jetty.util.ConcurrentHashSet; import org.eclipse.jetty.websocket.WebSocket; import org.eclipse.jetty.websocket.WebSocketClient; import org.eclipse.jetty.websocket.WebSocketClientFactory; import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URISyntaxException; import java.util.Arrays; import java.util.HashSet; import java.util.Set; import java.util.concurrent.Future; import java.util.concurrent.TimeoutException; import java.util.regex.Pattern; /** * The sampler for WebSocket. * @author kawasima */ public class WebSocketSampler extends AbstractSampler implements TestStateListener { private static final Logger log = LoggingManager.getLoggerForClass(); private static final Set<String> APPLIABLE_CONFIG_CLASSES = new HashSet<String>( Arrays.asList(new String[]{ "net.unit8.jmeter.protocol.websocket.control.gui.WebSocketSamplerGui", "org.apache.jmeter.config.gui.SimpleConfigGui"})); private static final String ARG_VAL_SEP = "="; // $NON-NLS-1$ private static final String QRY_SEP = "&"; // $NON-NLS-1$ private static final String QRY_PFX = "?"; // $NON-NLS-1$ private static final String WS_PREFIX = "ws://"; // $NON-NLS-1$ private static final String WSS_PREFIX = "wss://"; // $NON-NLS-1$ private static final String DEFAULT_PROTOCOL = "ws"; private static final int UNSPECIFIED_PORT = 0; private static final String UNSPECIFIED_PORT_AS_STRING = "0"; // $NON-NLS-1$ private static final int URL_UNSPECIFIED_PORT = -1; private WebSocket.Connection connection = null; private static final ConcurrentHashSet<WebSocket.Connection> samplerConnections = new ConcurrentHashSet<WebSocket.Connection>(); private boolean initialized = false; private String responseMessage; public static final String DOMAIN = "WebSocketSampler.domain"; public static final String PORT = "WebSocketSampler.port"; public static final String PATH = "WebSocketSampler.path"; public static final String PROTOCOL = "WebSocketSampler.protocol"; public static final String CONTENT_ENCODING = "WebSocketSampler.contentEncoding"; public static final String ARGUMENTS = "WebSocketSampler.arguments"; public static final String SEND_MESSAGE = "WebSocketSampler.sendMessage"; public static final String RECV_MESSAGE = "WebSocketSampler.recvMessage"; public static final String RECV_TIMEOUT = "WebSocketSampler.recvTimeout"; private static WebSocketClientFactory webSocketClientFactory = new WebSocketClientFactory(); public WebSocketSampler() { setArguments(new Arguments()); } public void initialize() throws Exception { URI uri = getUri(); WebSocketClient webSocketClient = webSocketClientFactory.newWebSocketClient(); final WebSocketSampler parent = this; final String threadName = JMeterContextService.getContext().getThread().getThreadName(); final Pattern regex = (getRecvMessage() != null) ? Pattern.compile(getRecvMessage()) : null; Future<WebSocket.Connection> futureConnection = webSocketClient.open(uri, new WebSocket.OnTextMessage() { @Override public void onMessage(String s) { synchronized (parent) { if (regex == null || regex.matcher(s).find()) { responseMessage = s; parent.notify(); } } } @Override public void onOpen(Connection connection) { log.debug("Connect " + threadName); } @Override public void onClose(int i, String s) { log.debug("Disconnect " + threadName); } }); connection = futureConnection.get(); samplerConnections.add(connection); initialized = true; } @Override public SampleResult sample(Entry entry) { SampleResult res = new SampleResult(); res.setSampleLabel(getName()); boolean isOK = false; if (!initialized) { try { initialize(); } catch (Exception e) { res.setResponseMessage(e.getMessage()); res.setSuccessful(false); return res; } } String message = getPropertyAsString(SEND_MESSAGE, "default message"); res.setSamplerData(message); res.sampleStart(); try { if (connection.isOpen()) { res.setDataEncoding(getContentEncoding()); connection.sendMessage(message); } else { initialize(); } synchronized (this) { wait(getRecvTimeout()); } if (responseMessage == null) { res.setResponseCode("204"); throw new TimeoutException("No content (probably timeout)."); } res.setResponseCodeOK(); res.setResponseData(responseMessage, getContentEncoding()); isOK = true; } catch (Exception e) { log.debug(e.getMessage()); res.setResponseMessage(e.getMessage()); } res.sampleEnd(); res.setSuccessful(isOK); return res; } @Override public void setName(String name) { if (name != null) setProperty(TestElement.NAME, name); } @Override public String getName() { return getPropertyAsString(TestElement.NAME); } @Override public void setComment(String comment){ setProperty(new StringProperty(TestElement.COMMENTS, comment)); } @Override public String getComment(){ return getProperty(TestElement.COMMENTS).getStringValue(); } public URI getUri() throws URISyntaxException { String path = this.getPath(); // Hack to allow entire URL to be provided in host field if (path.startsWith(WS_PREFIX) || path.startsWith(WSS_PREFIX)){ return new URI(path); } String domain = getDomain(); String protocol = getProtocol(); // HTTP URLs must be absolute, allow file to be relative if (!path.startsWith("/")){ // $NON-NLS-1$ path = "/" + path; // $NON-NLS-1$ } String queryString = getQueryString(getContentEncoding()); if(isProtocolDefaultPort()) { return new URI(protocol, null, domain, -1, path, queryString, null); } return new URI(protocol, null, domain, getPort(), path, queryString, null); } public void setPath(String path, String contentEncoding) { boolean fullUrl = path.startsWith(WS_PREFIX) || path.startsWith(WSS_PREFIX); if (!fullUrl) { int index = path.indexOf(QRY_PFX); if (index > -1) { setProperty(PATH, path.substring(0, index)); // Parse the arguments in querystring, assuming specified encoding for values parseArguments(path.substring(index + 1), contentEncoding); } else { setProperty(PATH, path); } } else { setProperty(PATH, path); } } public String getPath() { String p = getPropertyAsString(PATH); return encodeSpaces(p); } public void setPort(int value) { setProperty(new IntegerProperty(PORT, value)); } public static int getDefaultPort(String protocol,int port){ if (port==URL_UNSPECIFIED_PORT){ return protocol.equalsIgnoreCase(HTTPConstants.PROTOCOL_HTTP) ? HTTPConstants.DEFAULT_HTTP_PORT : protocol.equalsIgnoreCase(HTTPConstants.PROTOCOL_HTTPS) ? HTTPConstants.DEFAULT_HTTPS_PORT : port; } return port; } /** * Get the port number from the port string, allowing for trailing blanks. * * @return port number or UNSPECIFIED_PORT (== 0) */ public int getPortIfSpecified() { String port_s = getPropertyAsString(PORT, UNSPECIFIED_PORT_AS_STRING); try { return Integer.parseInt(port_s.trim()); } catch (NumberFormatException e) { return UNSPECIFIED_PORT; } } /** * Tell whether the default port for the specified protocol is used * * @return true if the default port number for the protocol is used, false otherwise */ public boolean isProtocolDefaultPort() { final int port = getPortIfSpecified(); final String protocol = getProtocol(); return port == UNSPECIFIED_PORT || ("ws".equalsIgnoreCase(protocol) && port == HTTPConstants.DEFAULT_HTTP_PORT) || ("wss".equalsIgnoreCase(protocol) && port == HTTPConstants.DEFAULT_HTTPS_PORT); } public int getPort() { final int port = getPortIfSpecified(); if (port == UNSPECIFIED_PORT) { String prot = getProtocol(); if ("wss".equalsIgnoreCase(prot)) { return HTTPConstants.DEFAULT_HTTPS_PORT; } if (!"ws".equalsIgnoreCase(prot)) { log.warn("Unexpected protocol: "+prot); // TODO - should this return something else? } return HTTPConstants.DEFAULT_HTTP_PORT; } return port; } public void setDomain(String value) { setProperty(DOMAIN, value); } public String getDomain() { return getPropertyAsString(DOMAIN); } public void setProtocol(String value) { setProperty(PROTOCOL, value.toLowerCase(java.util.Locale.ENGLISH)); } public String getProtocol() { String protocol = getPropertyAsString(PROTOCOL); if (protocol == null || protocol.length() == 0 ) { return DEFAULT_PROTOCOL; } return protocol; } public void setContentEncoding(String charsetName) { setProperty(CONTENT_ENCODING, charsetName); } public String getContentEncoding() { return getPropertyAsString(CONTENT_ENCODING); } public String getQueryString(String contentEncoding) { // Check if the sampler has a specified content encoding if(JOrphanUtils.isBlank(contentEncoding)) { // We use the encoding which should be used according to the HTTP spec, which is UTF-8 contentEncoding = EncoderCache.URL_ARGUMENT_ENCODING; } StringBuilder buf = new StringBuilder(); PropertyIterator iter = getArguments().iterator(); boolean first = true; while (iter.hasNext()) { HTTPArgument item = null; Object objectValue = iter.next().getObjectValue(); try { item = (HTTPArgument) objectValue; } catch (ClassCastException e) { item = new HTTPArgument((Argument) objectValue); } final String encodedName = item.getEncodedName(); if (encodedName.length() == 0) { continue; // Skip parameters with a blank name (allows use of optional variables in parameter lists) } if (!first) { buf.append(QRY_SEP); } else { first = false; } buf.append(encodedName); if (item.getMetaData() == null) { buf.append(ARG_VAL_SEP); } else { buf.append(item.getMetaData()); } // Encode the parameter value in the specified content encoding try { buf.append(item.getEncodedValue(contentEncoding)); } catch(UnsupportedEncodingException e) { log.warn("Unable to encode parameter in encoding " + contentEncoding + ", parameter value not included in query string"); } } return buf.toString(); } public void setSendMessage(String value) { setProperty(SEND_MESSAGE, value); } public String getSendMessage() { return getPropertyAsString(SEND_MESSAGE); } public void setRecvMessage(String value) { setProperty(RECV_MESSAGE, value); } public String getRecvMessage() { return getPropertyAsString(RECV_MESSAGE); } public void setRecvTimeout(long value) { setProperty(new LongProperty(RECV_TIMEOUT, value)); } public long getRecvTimeout() { return getPropertyAsLong(RECV_TIMEOUT, 20000L); } public void setArguments(Arguments value) { setProperty(new TestElementProperty(ARGUMENTS, value)); } public Arguments getArguments() { return (Arguments) getProperty(ARGUMENTS).getObjectValue(); } protected String encodeSpaces(String path) { return JOrphanUtils.replaceAllChars(path, ' ', "%20"); // $NON-NLS-1$ } public void parseArguments(String queryString, String contentEncoding) { String[] args = JOrphanUtils.split(queryString, QRY_SEP); for (int i = 0; i < args.length; i++) { // need to handle four cases: // - string contains name=value // - string contains name= // - string contains name // - empty string String metaData; // records the existance of an equal sign String name; String value; int length = args[i].length(); int endOfNameIndex = args[i].indexOf(ARG_VAL_SEP); if (endOfNameIndex != -1) {// is there a separator? // case of name=value, name= metaData = ARG_VAL_SEP; name = args[i].substring(0, endOfNameIndex); value = args[i].substring(endOfNameIndex + 1, length); } else { metaData = ""; name=args[i]; value=""; } if (name.length() > 0) { // If we know the encoding, we can decode the argument value, // to make it easier to read for the user if(!StringUtils.isEmpty(contentEncoding)) { addEncodedArgument(name, value, metaData, contentEncoding); } else { // If we do not know the encoding, we just use the encoded value // The browser has already done the encoding, so save the values as is addNonEncodedArgument(name, value, metaData); } } } } public void addEncodedArgument(String name, String value, String metaData, String contentEncoding) { if (log.isDebugEnabled()){ log.debug("adding argument: name: " + name + " value: " + value + " metaData: " + metaData + " contentEncoding: " + contentEncoding); } HTTPArgument arg = null; final boolean nonEmptyEncoding = !StringUtils.isEmpty(contentEncoding); if(nonEmptyEncoding) { arg = new HTTPArgument(name, value, metaData, true, contentEncoding); } else { arg = new HTTPArgument(name, value, metaData, true); } // Check if there are any difference between name and value and their encoded name and value String valueEncoded = null; if(nonEmptyEncoding) { try { valueEncoded = arg.getEncodedValue(contentEncoding); } catch (UnsupportedEncodingException e) { log.warn("Unable to get encoded value using encoding " + contentEncoding); valueEncoded = arg.getEncodedValue(); } } else { valueEncoded = arg.getEncodedValue(); } // If there is no difference, we mark it as not needing encoding if (arg.getName().equals(arg.getEncodedName()) && arg.getValue().equals(valueEncoded)) { arg.setAlwaysEncoded(false); } this.getArguments().addArgument(arg); } public void addEncodedArgument(String name, String value, String metaData) { this.addEncodedArgument(name, value, metaData, null); } public void addNonEncodedArgument(String name, String value, String metadata) { HTTPArgument arg = new HTTPArgument(name, value, metadata, false); arg.setAlwaysEncoded(false); this.getArguments().addArgument(arg); } public void addArgument(String name, String value) { this.getArguments().addArgument(new HTTPArgument(name, value)); } public void addArgument(String name, String value, String metadata) { this.getArguments().addArgument(new HTTPArgument(name, value, metadata)); } public boolean hasArguments() { return getArguments().getArgumentCount() > 0; } @Override public void testStarted() { testStarted(""); } @Override public void testStarted(String host) { try { webSocketClientFactory.start(); } catch(Exception e) { log.error("Can't start WebSocketClientFactory", e); } } @Override public void testEnded() { testEnded(""); } @Override public void testEnded(String host) { try { for(WebSocket.Connection connection : samplerConnections) { connection.close(); } webSocketClientFactory.stop(); } catch (Exception e) { log.error("sampler error when close.", e); } } /** * @see org.apache.jmeter.samplers.AbstractSampler#applies(org.apache.jmeter.config.ConfigTestElement) */ @Override public boolean applies(ConfigTestElement configElement) { String guiClass = configElement.getProperty(TestElement.GUI_CLASS).getStringValue(); return APPLIABLE_CONFIG_CLASSES.contains(guiClass); } }