package com.apigee.callout.jwe; import com.apigee.flow.execution.ExecutionContext; import com.apigee.flow.execution.ExecutionResult; import com.apigee.flow.execution.IOIntensive; import com.apigee.flow.execution.spi.Execution; import com.apigee.flow.message.MessageContext; import java.util.Calendar; import java.util.Date; import java.util.Map; import java.util.HashMap; import java.util.Iterator; import java.io.IOException; import java.io.InputStream; import org.jose4j.jwk.JsonWebKey; import org.jose4j.jwe.JsonWebEncryption; import org.jose4j.jwe.KeyManagementAlgorithmIdentifiers; import org.jose4j.jwe.ContentEncryptionAlgorithmIdentifiers; import org.apache.commons.ssl.PKCS8Key; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.exception.ExceptionUtils; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang.text.StrSubstitutor; import java.security.PrivateKey; import java.security.spec.PKCS8EncodedKeySpec; import java.security.KeyFactory; import java.nio.charset.Charset; // import java.security.interfaces.RSAPrivateKey; import java.security.GeneralSecurityException; import java.security.NoSuchAlgorithmException; import java.security.spec.InvalidKeySpecException; import com.apigee.utils.TemplateString; @IOIntensive public class JweDecryptorCallout implements Execution { private final static String _varPrefix = "jwe_"; private Map<String,String> properties; // read-only public JweDecryptorCallout (Map properties) { // convert the untyped Map to a generic map Map<String,String> m = new HashMap<String,String>(); Iterator iterator = properties.keySet().iterator(); while(iterator.hasNext()){ Object key = iterator.next(); Object value = properties.get(key); if ((key instanceof String) && (value instanceof String)) { m.put((String) key, (String) value); } } this.properties = m; } private static InputStream getResourceAsStream(String resourceName) throws IOException { // forcibly prepend a slash if (!resourceName.startsWith("/")) { resourceName = "/" + resourceName; } if (!resourceName.startsWith("/resources")) { resourceName = "/resources" + resourceName; } InputStream in = JweDecryptorCallout.class.getResourceAsStream(resourceName); if (in == null) { throw new IOException("resource \"" + resourceName + "\" not found"); } return in; } private String getJweCompactSerialization(MessageContext msgCtxt) throws Exception { String jwe = (String) this.properties.get("jwe"); if (jwe == null || jwe.equals("")) { throw new IllegalStateException("jwe is not specified or is empty."); } jwe = resolvePropertyValue(jwe, msgCtxt); if (jwe == null || jwe.equals("")) { throw new IllegalStateException("jwe is null or empty."); } return jwe; } private String getSecretKey(MessageContext msgCtxt) throws Exception { String key = (String) this.properties.get("secret-key"); if (key == null || key.equals("")) { throw new IllegalStateException("secret-key is not specified or is empty."); } key = resolvePropertyValue(key, msgCtxt); if (key == null || key.equals("")) { throw new IllegalStateException("secret-key is null or empty."); } return key; } private String getAlgorithm(MessageContext msgCtxt) throws Exception { String algorithm = ((String) this.properties.get("algorithm")).trim(); if (algorithm == null || algorithm.equals("")) { throw new IllegalStateException("algorithm is not specified or is empty."); } algorithm = resolvePropertyValue(algorithm, msgCtxt); if (algorithm == null || algorithm.equals("")) { throw new IllegalStateException("issuer is not specified or is empty."); } JweUtils.validateJweAlgorithm(algorithm); return algorithm; } private String[] getAudience(MessageContext msgCtxt) throws Exception { String audience = (String) this.properties.get("audience"); if (audience == null || audience.equals("")) { // don't care. Audience is optional, per JWT Spec sec 4.1.3 return null; } String[] audiences = StringUtils.split(audience,","); for(int i=0; i<audiences.length; i++) { audiences[i] = resolvePropertyValue(audiences[i], msgCtxt); } return audiences; } private String getPrivateKeyPassword(MessageContext msgCtxt) { String password = (String) this.properties.get("private-key-password"); if (password == null || password.equals("")) { // don't care. Use of a password on the private key is optional. return null; } password = resolvePropertyValue(password, msgCtxt); if (password == null || password.equals("")) { return null; } return password; } private PrivateKey getPrivateKey(MessageContext msgCtxt) throws IOException, GeneralSecurityException, NoSuchAlgorithmException, InvalidKeySpecException { byte[] keyBytes = null; String privateKey = (String) this.properties.get("private-key"); String passwd = getPrivateKeyPassword(msgCtxt); if (privateKey==null) { String pemfile = (String) this.properties.get("pemfile"); if (pemfile == null || pemfile.equals("")) { throw new IllegalStateException("must specify pemfile or private-key when algorithm is RS*"); } pemfile = resolvePropertyValue(pemfile, msgCtxt); if (pemfile == null || pemfile.equals("")) { throw new IllegalStateException("pemfile resolves to nothing; invalid when algorithm is RS*"); } InputStream in = getResourceAsStream(pemfile); keyBytes = new byte[in.available()]; in.read(keyBytes); in.close(); } else { if (privateKey.equals("")) { throw new IllegalStateException("private-key must be non-empty"); } privateKey = resolvePropertyValue(privateKey, msgCtxt); if (privateKey==null || privateKey.equals("")) { throw new IllegalStateException("private-key variable resolves to empty; invalid when algorithm is RS*"); } privateKey = privateKey.trim(); // clear any leading whitespace on each line privateKey = privateKey.replaceAll("([\\r|\\n] +)","\n"); keyBytes = privateKey.getBytes(Charset.forName("UTF-8")); } // If the provided data is encrypted, we need a password to decrypt // it. If the InputStream is not encrypted, then the password is ignored // (can be null). The InputStream can be DER (raw ASN.1) or PEM (base64). PKCS8Key pkcs8 = new PKCS8Key( keyBytes, passwd.toCharArray() ); // If an unencrypted PKCS8 key was provided, then getDecryptedBytes() // actually returns exactly what was originally passed in (with no // changes). If an OpenSSL key was provided, it gets reformatted as // PKCS #8. byte[] decrypted = pkcs8.getDecryptedBytes(); PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec( decrypted ); // A Java PrivateKey object is born. PrivateKey pk = null; if ( pkcs8.isDSA() ) { pk = KeyFactory.getInstance( "DSA" ).generatePrivate( spec ); } else if ( pkcs8.isRSA() ) { pk = KeyFactory.getInstance( "RSA" ).generatePrivate( spec ); } return pk; } // If the value of a property value contains open and close curlies, eg, // {apiproxy.name} or ABC-{apikey}, then "resolve" the value by de-referencing // the context variables whose names appear between curlies. private String resolvePropertyValue(String spec, MessageContext msgCtxt) { if (spec.indexOf('{') > -1 && spec.indexOf('}')>-1) { // Replace ALL curly-braced items in the spec string with // the value of the corresponding context variable. TemplateString ts = new TemplateString(spec); Map<String,String> valuesMap = new HashMap<String,String>(); for (String s : ts.variableNames) { valuesMap.put(s, (String) msgCtxt.getVariable(s)); } StrSubstitutor sub = new StrSubstitutor(valuesMap); String resolvedString = sub.replace(ts.template); return resolvedString; } return spec; } private static final String varName(String s) { return _varPrefix + s; } public ExecutionResult execute(MessageContext msgCtxt, ExecutionContext exeCtxt) { try { msgCtxt.removeVariable(varName("error")); String jweText = getJweCompactSerialization(msgCtxt); String secretKey = getSecretKey(msgCtxt); String b64Key = Base64.encodeBase64String(secretKey.getBytes("UTF-8")); String jwkJson = "{\"kty\":\"oct\",\"k\":\""+ b64Key + "\"}"; JsonWebKey jwk = JsonWebKey.Factory.newJwk(jwkJson); JsonWebEncryption jwe = new JsonWebEncryption(); // Set the compact serialization on new Json Web Encryption object jwe.setCompactSerialization(jweText); jwe.setKey(jwk.getKey()); // Get the message that was encrypted in the JWE. This step // performs the actual decryption steps. String plaintext = jwe.getPlaintextString(); msgCtxt.setVariable(varName("plaintext"), plaintext); String foundAlgorithm = jwe.getEncryptionMethodHeaderParameter(); msgCtxt.setVariable(varName("algorithm"), foundAlgorithm); if (!StringUtils.isEmpty(foundAlgorithm)) { String requiredAlgorithm = getAlgorithm(msgCtxt); if (! foundAlgorithm.equals(requiredAlgorithm)) { msgCtxt.setVariable(varName("error"), String.format("Algorithm mismatch: found [%s], expected [%s]", foundAlgorithm, requiredAlgorithm)); return ExecutionResult.ABORT; } } } catch (Exception e) { msgCtxt.setVariable(varName("error"), "Exception " + e.toString()); msgCtxt.setVariable(varName("stacktrace"), ExceptionUtils.getStackTrace(e)); return ExecutionResult.ABORT; } return ExecutionResult.SUCCESS; } }