diff --git a/README.md b/README.md index 487c06d..c341369 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # JWTs for Java + This is not an officially supported Google product [![CircleCI](https://img.shields.io/circleci/project/github/auth0/java-jwt.svg?style=flat-square)](https://circleci.com/gh/auth0/java-jwt/tree/master) diff --git a/lib/src/main/java/oicclient/client/ClientUtils.java b/lib/src/main/java/oicclient/client/ClientUtils.java new file mode 100644 index 0000000..63d26f3 --- /dev/null +++ b/lib/src/main/java/oicclient/client/ClientUtils.java @@ -0,0 +1,23 @@ +package oicclient.client; + +import java.util.HashMap; +import java.util.Map; + +public class ClientUtils { + + private static final String RS256 = "RS256"; + private static final String HS256 = "HS256"; + private static final String ID_TOKEN = "idToken"; + private static final String USER_INFO = "userInfo"; + private static final String REQUEST_OBJECT = "requestObject"; + private static final String CLIENT_SECRET_JWT = "clientSecretJwt"; + private static final String PRIVATE_KEY_JWT = "privateKeyJwt"; + + public static final Map DEF_SIGN_ALG = new HashMap() {{ + put(ID_TOKEN, RS256); + put(USER_INFO, RS256); + put(REQUEST_OBJECT, RS256); + put(CLIENT_SECRET_JWT, HS256); + put(PRIVATE_KEY_JWT, RS256); + }}; +} diff --git a/lib/src/main/java/oicclient/client_auth/ClientAuthenticationMethod.java b/lib/src/main/java/oicclient/client_auth/ClientAuthenticationMethod.java new file mode 100644 index 0000000..95045f0 --- /dev/null +++ b/lib/src/main/java/oicclient/client_auth/ClientAuthenticationMethod.java @@ -0,0 +1,12 @@ +package oicclient.client_auth; + +import com.auth0.jwt.creators.Message; +import java.util.Map; +import oicclient.exceptions.AuthenticationFailure; +import oicclient.exceptions.NoMatchingKey; + +public abstract class ClientAuthenticationMethod { + + protected abstract void construct(Message request, ClientInfo clientInfo, + Map httpArgs, Map args) throws AuthenticationFailure, NoMatchingKey; +} diff --git a/lib/src/main/java/oicclient/client_auth/JWSAuthenticationMethod.java b/lib/src/main/java/oicclient/client_auth/JWSAuthenticationMethod.java new file mode 100644 index 0000000..7e4b1fd --- /dev/null +++ b/lib/src/main/java/oicclient/client_auth/JWSAuthenticationMethod.java @@ -0,0 +1,111 @@ +package oicclient.client_auth; + +import com.auth0.jwt.creators.Message; +import com.google.common.base.Strings; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import oicclient.client.ClientUtils; +import oicclient.clientinfo.ClientInfo; +import oicclient.exceptions.AuthenticationFailure; +import oicclient.exceptions.NoMatchingKey; + +public class JWSAuthenticationMethod extends ClientAuthenticationMethod { + + private static final String JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; + + protected String chooseAlgorithm(String context, Map args) throws AuthenticationFailure { + String algorithm = args.get("algorithm"); + if (Strings.isNullOrEmpty(algorithm)) { + algorithm = ClientUtils.DEF_SIGN_ALG.get(context); + if (Strings.isNullOrEmpty(algorithm)) { + throw new AuthenticationFailure("Missing algorithm specification"); + } + } + + return algorithm; + } + + protected List getSigningKey(String algorithm, ClientInfo clientInfo) { + return clientInfo.getKeyJar().getSigningKey(alg2KeyType(algorithm), algorithm); + } + + private Key getKeyByKid(String kid, String algorithm, ClientInfo clientInfo) throws NoMatchingKey { + Key key = clientInfo.getKeyJar().getKeyByKid(kid); + if (key != null) { + String keyType = alg2KeyType(algorithm); + if (!key.getKeyType().equals(keyType)) { + throw new NoMatchingKey("Wrong key type"); + } else { + return key; + } + } else { + throw new NoMatchingKey("No key with kid: " + kid); + } + } + + protected void construct(Message request, ClientInfo clientInfo, + Map httpArgs, Map args) throws AuthenticationFailure, NoMatchingKey { + String algorithm = null; + List audience = null; + Map cParams = request.getCParams(); + if (args.containsKey("clientAssertion")) { + cParams.put("clientAssertion", args.get("clientAssertion")); + if (args.containsKey("clientAssertionType")) { + cParams.put("clientAssertionType", args.get("clientAssertionType")); + } else { + cParams.put("clientAssertionType", JWT_BEARER); + } + } else if (cParams.containsKey("clientAssertion")) { + if (!cParams.containsKey("clientAssertionType")) { + cParams.put("clientAssertionType", JWT_BEARER); + } + } else { + if ((args.get("authenticationEndpoint") != null && args.get("authenticationEndpoint").equals("token")) || (args.get("authenticationEndpoint") != null && args.get("authenticationEndpoint").equals("refresh"))) { + algorithm = clientInfo.getRegistrationResponse().get("tokenEndpointAuthSigningAlg").get(0); + audience = clientInfo.getProviderInfo().get("tokenEndpoint"); + } else { + audience = clientInfo.getProviderInfo().get("issuer"); + } + } + + if (Strings.isNullOrEmpty(algorithm)) { + //how is this going to call a subclass? + algorithm = this.chooseAlgorithm(args); + } + + String ktype = alg2keytype(algorithm); + List signingKey = null; + if (args.containsKey("kid")) { + signingKey = Arrays.asList(this.getKeyByKid(args.get("kid"), algorithm, clientInfo)); + } else if (clientInfo.getKid().get("sig").containsKey(ktype)) { + Key key = this.getKeyByKid(clientInfo.getKid().get("sig").get(ktype), algorithm, clientInfo); + if (key != null) { + signingKey = Arrays.asList(key); + } else { + signingKey = this.getSigningKey(algorithm, clientInfo); + } + } else { + signingKey = this.getSigningKey(algorithm, clientInfo); + } + + int lifetime = -1; + if (!Strings.isNullOrEmpty(args.get("lifetime"))) { + lifetime = Integer.parseInt(args.get("lifetime")); + } + if(lifetime != -1) { + cParams.put("clientAssertion", assertionJWT(clientInfo.getClientId(), signingKey, audience, algorithm, lifetime)); + } else { + cParams.put("clientAssertion", assertionJWT(clientInfo.getClientId(), signingKey, audience, algorithm, 600)); + } + cParams.put("clientAssertionType", JWT_BEARER); + + cParams.remove("clientSecret"); + + if (cParams.get("clientId") != null && !cParams.get("clientId").getB()) { + cParams.remove("clientId"); + } + + request.setCParams(cParams); + } +} diff --git a/lib/src/main/java/oicclient/exceptions/AuthenticationFailure.java b/lib/src/main/java/oicclient/exceptions/AuthenticationFailure.java new file mode 100644 index 0000000..0a2ffcd --- /dev/null +++ b/lib/src/main/java/oicclient/exceptions/AuthenticationFailure.java @@ -0,0 +1,7 @@ +package oicclient.exceptions; + +public class AuthenticationFailure extends Exception{ + public AuthenticationFailure(String message) { + super(message); + } +} diff --git a/lib/src/main/java/oicclient/exceptions/NoMatchingKey.java b/lib/src/main/java/oicclient/exceptions/NoMatchingKey.java new file mode 100644 index 0000000..14001f8 --- /dev/null +++ b/lib/src/main/java/oicclient/exceptions/NoMatchingKey.java @@ -0,0 +1,7 @@ +package oicclient.exceptions; + +public class NoMatchingKey extends Exception { + public NoMatchingKey(String message) { + super(message); + } +}