diff --git a/lti_consumer/lti_1p3/consumer.py b/lti_consumer/lti_1p3/consumer.py index ce6282c465c65eb0f39792bdbd919f3413b8efa8..071cff206ca20390678e294cddc69876e1dd3d0a 100644 --- a/lti_consumer/lti_1p3/consumer.py +++ b/lti_consumer/lti_1p3/consumer.py @@ -397,3 +397,26 @@ class LtiConsumer1p3: assert response.get("redirect_uri") == self.launch_url except AssertionError: raise exceptions.PreflightRequestValidationFailure() + + def check_token(self, token, allowed_scopes=None): + """ + Check if token has access to allowed scopes. + """ + token_contents = self.key_handler.validate_and_decode( + token, + # The issuer of the token is the platform + iss=self.iss, + ) + # Tokens are space separated + token_scopes = token_contents['scopes'].split(' ') + + # Check if token has permission for the requested scope, + # and throws exception if not. + # If `allowed_scopes` is empty, return true (just check + # token validity). + if allowed_scopes: + return any( + [scope in allowed_scopes for scope in token_scopes] + ) + + return True diff --git a/lti_consumer/lti_1p3/exceptions.py b/lti_consumer/lti_1p3/exceptions.py index a0d2c70e3fb0bede122da225572991fc2933a73d..47169d3f8e51117a09774da30cfc8bd85f0ecb24 100644 --- a/lti_consumer/lti_1p3/exceptions.py +++ b/lti_consumer/lti_1p3/exceptions.py @@ -34,6 +34,10 @@ class UnsupportedGrantType(Lti1p3Exception): pass +class InvalidClaimValue(Lti1p3Exception): + pass + + class InvalidRsaKey(Lti1p3Exception): pass diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py index daafc761b30f7a1be165543b965123f761d5f836..a3397dacc30aefa1a4905b199245b9f69b9d6ecc 100644 --- a/lti_consumer/lti_1p3/key_handlers.py +++ b/lti_consumer/lti_1p3/key_handlers.py @@ -182,3 +182,36 @@ class PlatformKeyHandler: public_keys.append(self.key) return json.loads(public_keys.dump_jwks()) + + def validate_and_decode(self, token, iss=None, aud=None): + """ + Check if a platform token is valid, and return allowed scopes. + + Validates a token sent by the tool using the platform's RSA Key. + Optionally validate iss and aud claims if provided. + """ + try: + # Verify message signature + message = JWS().verify_compact(token, keys=[self.key]) + + # If message is valid, check expiration from JWT + if 'exp' in message and message['exp'] < time.time(): + raise exceptions.TokenSignatureExpired() + + # Validate issuer claim (if present) + if iss: + if 'iss' not in message or message['iss'] != iss: + raise exceptions.InvalidClaimValue() + + # Validate audience claim (if present) + if aud: + if 'aud' not in message or aud not in message['aud']: + raise exceptions.InvalidClaimValue() + + # Else return token contents + return message + + except NoSuitableSigningKeys: + raise exceptions.NoSuitableKeys() + except BadSyntax: + raise exceptions.MalformedJwtToken() diff --git a/lti_consumer/lti_1p3/tests/test_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py index 0941ea167c8d0f480becd51d45946cbdf9cd1f1a..8387411d72a014ce47800214c80893d4da975901 100644 --- a/lti_consumer/lti_1p3/tests/test_consumer.py +++ b/lti_consumer/lti_1p3/tests/test_consumer.py @@ -490,3 +490,33 @@ class TestLti1p3Consumer(TestCase): # Check if token is valid self._decode_token(response.get('access_token')) + + def test_check_token_no_scopes(self): + """ + Test if `check_token` method returns True for a valid token without scopes. + """ + token = self.lti_consumer.key_handler.encode_and_sign({ + "iss": ISS, + "scopes": "", + }) + self.assertTrue(self.lti_consumer.check_token(token, None)) + + def test_check_token_with_allowed_scopes(self): + """ + Test if `check_token` method returns True for a valid token with allowed scopes. + """ + token = self.lti_consumer.key_handler.encode_and_sign({ + "iss": ISS, + "scopes": "test" + }) + self.assertTrue(self.lti_consumer.check_token(token, ['test', '123'])) + + def test_check_token_without_allowed_scopes(self): + """ + Test if `check_token` method returns True for a valid token with allowed scopes. + """ + token = self.lti_consumer.key_handler.encode_and_sign({ + "iss": ISS, + "scopes": "test" + }) + self.assertFalse(self.lti_consumer.check_token(token, ['123', ])) diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py index 82b63451ff340d9b348fdfa92d3da826882ee245..73b3a826b7117ec0fe6c6d9fcade0d209ea17670 100644 --- a/lti_consumer/lti_1p3/tests/test_key_handlers.py +++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py @@ -106,6 +106,80 @@ class TestPlatformKeyHandler(TestCase): {'keys': []} ) + # pylint: disable=unused-argument + @patch('time.time', return_value=1000) + def test_validate_and_decode(self, mock_time): + """ + Test validate and decode with all parameters. + """ + signed_token = self.key_handler.encode_and_sign( + { + "iss": "test-issuer", + "aud": "test-aud", + }, + expiration=1000 + ) + + self.assertEqual( + self.key_handler.validate_and_decode(signed_token), + { + "iss": "test-issuer", + "aud": "test-aud", + "iat": 1000, + "exp": 2000 + } + ) + + # pylint: disable=unused-argument + @patch('time.time', return_value=1000) + def test_validate_and_decode_expired(self, mock_time): + """ + Test validate and decode with all parameters. + """ + signed_token = self.key_handler.encode_and_sign( + {}, + expiration=-10 + ) + + with self.assertRaises(exceptions.TokenSignatureExpired): + self.key_handler.validate_and_decode(signed_token) + + def test_validate_and_decode_invalid_iss(self): + """ + Test validate and decode with invalid iss. + """ + signed_token = self.key_handler.encode_and_sign({"iss": "wrong"}) + + with self.assertRaises(exceptions.InvalidClaimValue): + self.key_handler.validate_and_decode(signed_token, iss="right") + + def test_validate_and_decode_invalid_aud(self): + """ + Test validate and decode with invalid aud. + """ + signed_token = self.key_handler.encode_and_sign({"aud": "wrong"}) + + with self.assertRaises(exceptions.InvalidClaimValue): + self.key_handler.validate_and_decode(signed_token, aud="right") + + def test_validate_and_decode_no_jwt(self): + """ + Test validate and decode with invalid JWT. + """ + with self.assertRaises(exceptions.MalformedJwtToken): + self.key_handler.validate_and_decode("1.2.3") + + def test_validate_and_decode_no_keys(self): + """ + Test validate and decode when no keys are available. + """ + signed_token = self.key_handler.encode_and_sign({}) + # Changing the KID so it doesn't match + self.key_handler.key.kid = "invalid_kid" + + with self.assertRaises(exceptions.NoSuitableKeys): + self.key_handler.validate_and_decode(signed_token) + @ddt.ddt class TestToolKeyHandler(TestCase): diff --git a/setup.py b/setup.py index b26d61c5905e01185d3f78a0a32ae5dd71a3e240..783efd29b4e5f3d1a39961e85b8264b3c11d00f8 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ with open('README.rst') as _f: setup( name='lti-consumer-xblock', - version='2.1.1', + version='2.1.2', description='This XBlock implements the consumer side of the LTI specification.', long_description=long_description, long_description_content_type='text/markdown',