From d10fc8f0c42b0d4b41b0cccb6811a401088b67dc Mon Sep 17 00:00:00 2001
From: Giovanni Cimolin da Silva <giovannicimolin@gmail.com>
Date: Thu, 30 Jul 2020 18:17:27 -0300
Subject: [PATCH] Add decoding and scope verification methods

---
 lti_consumer/lti_1p3/consumer.py              | 23 ++++++
 lti_consumer/lti_1p3/exceptions.py            |  4 +
 lti_consumer/lti_1p3/key_handlers.py          | 33 +++++++++
 lti_consumer/lti_1p3/tests/test_consumer.py   | 30 ++++++++
 .../lti_1p3/tests/test_key_handlers.py        | 74 +++++++++++++++++++
 setup.py                                      |  2 +-
 6 files changed, 165 insertions(+), 1 deletion(-)

diff --git a/lti_consumer/lti_1p3/consumer.py b/lti_consumer/lti_1p3/consumer.py
index ce6282c..071cff2 100644
--- a/lti_consumer/lti_1p3/consumer.py
+++ b/lti_consumer/lti_1p3/consumer.py
@@ -397,3 +397,26 @@ class LtiConsumer1p3:
             assert response.get("redirect_uri") == self.launch_url
         except AssertionError:
             raise exceptions.PreflightRequestValidationFailure()
+
+    def check_token(self, token, allowed_scopes=None):
+        """
+        Check if token has access to allowed scopes.
+        """
+        token_contents = self.key_handler.validate_and_decode(
+            token,
+            # The issuer of the token is the platform
+            iss=self.iss,
+        )
+        # Tokens are space separated
+        token_scopes = token_contents['scopes'].split(' ')
+
+        # Check if token has permission for the requested scope,
+        # and throws exception if not.
+        # If `allowed_scopes` is empty, return true (just check
+        # token validity).
+        if allowed_scopes:
+            return any(
+                [scope in allowed_scopes for scope in token_scopes]
+            )
+
+        return True
diff --git a/lti_consumer/lti_1p3/exceptions.py b/lti_consumer/lti_1p3/exceptions.py
index a0d2c70..47169d3 100644
--- a/lti_consumer/lti_1p3/exceptions.py
+++ b/lti_consumer/lti_1p3/exceptions.py
@@ -34,6 +34,10 @@ class UnsupportedGrantType(Lti1p3Exception):
     pass
 
 
+class InvalidClaimValue(Lti1p3Exception):
+    pass
+
+
 class InvalidRsaKey(Lti1p3Exception):
     pass
 
diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py
index daafc76..a3397da 100644
--- a/lti_consumer/lti_1p3/key_handlers.py
+++ b/lti_consumer/lti_1p3/key_handlers.py
@@ -182,3 +182,36 @@ class PlatformKeyHandler:
             public_keys.append(self.key)
 
         return json.loads(public_keys.dump_jwks())
+
+    def validate_and_decode(self, token, iss=None, aud=None):
+        """
+        Check if a platform token is valid, and return allowed scopes.
+
+        Validates a token sent by the tool using the platform's RSA Key.
+        Optionally validate iss and aud claims if provided.
+        """
+        try:
+            # Verify message signature
+            message = JWS().verify_compact(token, keys=[self.key])
+
+            # If message is valid, check expiration from JWT
+            if 'exp' in message and message['exp'] < time.time():
+                raise exceptions.TokenSignatureExpired()
+
+            # Validate issuer claim (if present)
+            if iss:
+                if 'iss' not in message or message['iss'] != iss:
+                    raise exceptions.InvalidClaimValue()
+
+            # Validate audience claim (if present)
+            if aud:
+                if 'aud' not in message or aud not in message['aud']:
+                    raise exceptions.InvalidClaimValue()
+
+            # Else return token contents
+            return message
+
+        except NoSuitableSigningKeys:
+            raise exceptions.NoSuitableKeys()
+        except BadSyntax:
+            raise exceptions.MalformedJwtToken()
diff --git a/lti_consumer/lti_1p3/tests/test_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py
index 0941ea1..8387411 100644
--- a/lti_consumer/lti_1p3/tests/test_consumer.py
+++ b/lti_consumer/lti_1p3/tests/test_consumer.py
@@ -490,3 +490,33 @@ class TestLti1p3Consumer(TestCase):
 
         # Check if token is valid
         self._decode_token(response.get('access_token'))
+
+    def test_check_token_no_scopes(self):
+        """
+        Test if `check_token` method returns True for a valid token without scopes.
+        """
+        token = self.lti_consumer.key_handler.encode_and_sign({
+            "iss": ISS,
+            "scopes": "",
+        })
+        self.assertTrue(self.lti_consumer.check_token(token, None))
+
+    def test_check_token_with_allowed_scopes(self):
+        """
+        Test if `check_token` method returns True for a valid token with allowed scopes.
+        """
+        token = self.lti_consumer.key_handler.encode_and_sign({
+            "iss": ISS,
+            "scopes": "test"
+        })
+        self.assertTrue(self.lti_consumer.check_token(token, ['test', '123']))
+
+    def test_check_token_without_allowed_scopes(self):
+        """
+        Test if `check_token` method returns True for a valid token with allowed scopes.
+        """
+        token = self.lti_consumer.key_handler.encode_and_sign({
+            "iss": ISS,
+            "scopes": "test"
+        })
+        self.assertFalse(self.lti_consumer.check_token(token, ['123', ]))
diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py
index 82b6345..73b3a82 100644
--- a/lti_consumer/lti_1p3/tests/test_key_handlers.py
+++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py
@@ -106,6 +106,80 @@ class TestPlatformKeyHandler(TestCase):
             {'keys': []}
         )
 
+    # pylint: disable=unused-argument
+    @patch('time.time', return_value=1000)
+    def test_validate_and_decode(self, mock_time):
+        """
+        Test validate and decode with all parameters.
+        """
+        signed_token = self.key_handler.encode_and_sign(
+            {
+                "iss": "test-issuer",
+                "aud": "test-aud",
+            },
+            expiration=1000
+        )
+
+        self.assertEqual(
+            self.key_handler.validate_and_decode(signed_token),
+            {
+                "iss": "test-issuer",
+                "aud": "test-aud",
+                "iat": 1000,
+                "exp": 2000
+            }
+        )
+
+    # pylint: disable=unused-argument
+    @patch('time.time', return_value=1000)
+    def test_validate_and_decode_expired(self, mock_time):
+        """
+        Test validate and decode with all parameters.
+        """
+        signed_token = self.key_handler.encode_and_sign(
+            {},
+            expiration=-10
+        )
+
+        with self.assertRaises(exceptions.TokenSignatureExpired):
+            self.key_handler.validate_and_decode(signed_token)
+
+    def test_validate_and_decode_invalid_iss(self):
+        """
+        Test validate and decode with invalid iss.
+        """
+        signed_token = self.key_handler.encode_and_sign({"iss": "wrong"})
+
+        with self.assertRaises(exceptions.InvalidClaimValue):
+            self.key_handler.validate_and_decode(signed_token, iss="right")
+
+    def test_validate_and_decode_invalid_aud(self):
+        """
+        Test validate and decode with invalid aud.
+        """
+        signed_token = self.key_handler.encode_and_sign({"aud": "wrong"})
+
+        with self.assertRaises(exceptions.InvalidClaimValue):
+            self.key_handler.validate_and_decode(signed_token, aud="right")
+
+    def test_validate_and_decode_no_jwt(self):
+        """
+        Test validate and decode with invalid JWT.
+        """
+        with self.assertRaises(exceptions.MalformedJwtToken):
+            self.key_handler.validate_and_decode("1.2.3")
+
+    def test_validate_and_decode_no_keys(self):
+        """
+        Test validate and decode when no keys are available.
+        """
+        signed_token = self.key_handler.encode_and_sign({})
+        # Changing the KID so it doesn't match
+        self.key_handler.key.kid = "invalid_kid"
+
+        with self.assertRaises(exceptions.NoSuitableKeys):
+            self.key_handler.validate_and_decode(signed_token)
+
 
 @ddt.ddt
 class TestToolKeyHandler(TestCase):
diff --git a/setup.py b/setup.py
index b26d61c..783efd2 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@ with open('README.rst') as _f:
 
 setup(
     name='lti-consumer-xblock',
-    version='2.1.1',
+    version='2.1.2',
     description='This XBlock implements the consumer side of the LTI specification.',
     long_description=long_description,
     long_description_content_type='text/markdown',
-- 
GitLab