// Copyright 2018 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.android.things.iotcore; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; import static org.mockito.Mockito.when; import org.jose4j.lang.JoseException; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.PublicKey; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jws; import io.jsonwebtoken.JwsHeader; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.MalformedJwtException; import io.jsonwebtoken.SignatureException; import io.jsonwebtoken.UnsupportedJwtException; /** JwtGenerator unit tests. */ @RunWith(MockitoJUnitRunner.class) public class JwtGeneratorTest { private static final String JWT_AUDIENCE = "foo"; private static final KeyPair RSA_KEY_PAIR = generateKeyPair("RSA"); private static final KeyPair EC_KEY_PAIR = generateKeyPair("EC"); private static final Clock TEST_CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); private static final Duration TOKEN_LIFETIME = Duration.ofHours(1); @Mock private PrivateKey mMockPrivateKey; @Mock private PublicKey mMockPublicKey; // Generate key pairs for testing private static KeyPair generateKeyPair(String algorithm) { try { if (algorithm.equals("EC")) { KeyPairGenerator generator = KeyPairGenerator.getInstance(algorithm); generator.initialize(256); return generator.generateKeyPair(); } KeyPairGenerator generator = KeyPairGenerator.getInstance(algorithm); generator.initialize(2048); return generator.generateKeyPair(); } catch (NoSuchAlgorithmException e) { fail("Error generating test keypair"); } return null; // Satisfy compiler } @Test public void testConstructorRsaKeyAlgorithm() { KeyPair kp = new KeyPair(mMockPublicKey, mMockPrivateKey); when(mMockPrivateKey.getAlgorithm()).thenReturn("RSA"); when(mMockPrivateKey.getAlgorithm()).thenReturn("RSA"); assertThat(new JwtGenerator(kp, JWT_AUDIENCE, TOKEN_LIFETIME)).isNotNull(); int numPrivateGetAlgorithmCalled = Mockito.mockingDetails(mMockPrivateKey) .getInvocations() .size(); int numPublicGetAlgorithmCalled = Mockito.mockingDetails(mMockPublicKey) .getInvocations() .size(); assertThat(numPrivateGetAlgorithmCalled + numPublicGetAlgorithmCalled).isEqualTo(1); } @Test public void testConstructorEcKeyAlgorithm() { KeyPair kp = new KeyPair(mMockPublicKey, mMockPrivateKey); when(mMockPrivateKey.getAlgorithm()).thenReturn("EC"); when(mMockPrivateKey.getAlgorithm()).thenReturn("EC"); assertThat(new JwtGenerator(kp, JWT_AUDIENCE, TOKEN_LIFETIME)).isNotNull(); int numPrivateGetAlgorithmCalled = Mockito.mockingDetails(mMockPrivateKey) .getInvocations() .size(); int numPublicGetAlgorithmCalled = Mockito.mockingDetails(mMockPublicKey) .getInvocations() .size(); assertThat(numPrivateGetAlgorithmCalled + numPublicGetAlgorithmCalled).isEqualTo(1); } @Test public void testConstructorInvalidKeyAlgorithm() { KeyPair kp = new KeyPair(mMockPublicKey, mMockPrivateKey); when(mMockPrivateKey.getAlgorithm()).thenReturn("bad"); when(mMockPrivateKey.getAlgorithm()).thenReturn("bad"); try { new JwtGenerator(kp, JWT_AUDIENCE, TOKEN_LIFETIME); fail("JwtGenerator constructed with unsupported encryption algorithm"); } catch (IllegalArgumentException expected) { assertThat(expected).hasMessageThat().contains("unsupported"); int numPrivateGetAlgorithmCalled = Mockito.mockingDetails(mMockPrivateKey) .getInvocations() .size(); int numPublicGetAlgorithmCalled = Mockito.mockingDetails(mMockPublicKey) .getInvocations() .size(); assertThat(numPrivateGetAlgorithmCalled + numPublicGetAlgorithmCalled).isEqualTo(1); } } /** * Make sure Jwt created is formatted according to the Google Cloud IoT Core<a * href="https://cloud.google.com/iot/docs/how-tos/credentials/jwts#jwt_composition">spec</a>. */ @Test public void testCreateJwtRsa() throws JoseException { JwtGenerator jwtGenerator = new JwtGenerator(RSA_KEY_PAIR, JWT_AUDIENCE, TOKEN_LIFETIME, TEST_CLOCK); String rawJwt = jwtGenerator.createJwt(); // Validate JWT Jws<Claims> parsedJwt = Jwts.parser() .setSigningKey(RSA_KEY_PAIR.getPublic()) .parseClaimsJws(rawJwt); JwsHeader header = parsedJwt.getHeader(); Claims claims = parsedJwt.getBody(); assertThat(header.getAlgorithm()).isEqualTo("RS256"); assertThat(header.getType()).isEqualTo("JWT"); assertThat(claims.getAudience()).isEqualTo(JWT_AUDIENCE); // JWT requires time in seconds from epoch, not millis, so allow issue time within one // second. assertThat(claims.getIssuedAt().getTime()).isAtLeast(TEST_CLOCK.millis() - 1000); assertThat(claims.getIssuedAt().getTime()).isAtMost(TEST_CLOCK.millis() + 1000); // Check expiration time within one second of issue time + TOKEN_LIFETIME assertThat(claims.getExpiration().getTime()) .isLessThan(Clock.offset(TEST_CLOCK, TOKEN_LIFETIME.plusSeconds(1)).millis()); assertThat(claims.getExpiration().getTime()) .isAtLeast(Clock.offset(TEST_CLOCK, TOKEN_LIFETIME.minusSeconds(1)).millis()); } /** * Make sure Jwt created is formatted according to the Google Cloud IoT Core<a * href="https://cloud.google.com/iot/docs/how-tos/credentials/jwts#jwt_composition">spec</a>. */ @Test public void testCreateJwtEc() throws JoseException { JwtGenerator jwtGenerator = new JwtGenerator(EC_KEY_PAIR, JWT_AUDIENCE, TOKEN_LIFETIME, TEST_CLOCK); String rawJwt = jwtGenerator.createJwt(); // Validate JWT Jws<Claims> parsedJwt; try { parsedJwt = Jwts.parser() .setSigningKey(EC_KEY_PAIR.getPublic()) .parseClaimsJws(rawJwt); } catch (UnsupportedJwtException | MalformedJwtException | SignatureException e) { fail("Error parsing JWT: " + e); return; // Satisfy compiler } JwsHeader header = parsedJwt.getHeader(); Claims claims = parsedJwt.getBody(); assertThat(header.getAlgorithm()).isEqualTo("ES256"); assertThat(header.getType()).isEqualTo("JWT"); assertThat(claims.getAudience()).isEqualTo(JWT_AUDIENCE); // JWT requires time in seconds from epoch, not millis, so allow issue time within one // second. assertThat(claims.getIssuedAt().getTime()).isAtLeast(TEST_CLOCK.millis() - 1000); assertThat(claims.getIssuedAt().getTime()).isAtMost(TEST_CLOCK.millis() + 1000); // Check expiration time within one second of issue time + TOKEN_LIFETIME assertThat(claims.getExpiration().getTime()) .isLessThan(Clock.offset(TEST_CLOCK, TOKEN_LIFETIME.plusSeconds(1)).millis()); assertThat(claims.getExpiration().getTime()) .isAtLeast(Clock.offset(TEST_CLOCK, TOKEN_LIFETIME.minusSeconds(1)).millis()); } }