From 79d3c3b067aa24991503966e0d62be192447f4cb Mon Sep 17 00:00:00 2001
From: Giovanni Cimolin da Silva <giovannicimolin@gmail.com>
Date: Mon, 27 Apr 2020 16:10:03 -0300
Subject: [PATCH] Add implementation of access token endpoint

This commit adds support for the LTI 1.3 Access token endpoint, as detailed in the IMS Security Framework.
The token is generated using the consumer's private key (stored in the XBlock) after validating the message sent by the LTI Tool using it's public key.

Signed-off-by: Giovanni Cimolin da Silva <giovannicimolin@gmail.com>
---
 lti_consumer/lti_1p3/README.md                |  18 ++
 lti_consumer/lti_1p3/constants.py             |  10 +
 lti_consumer/lti_1p3/consumer.py              | 127 +++++++---
 lti_consumer/lti_1p3/exceptions.py            |  38 +++
 lti_consumer/lti_1p3/key_handlers.py          | 184 ++++++++++++++
 lti_consumer/lti_1p3/tests/__init__.py        |   0
 .../tests/test_consumer.py}                   |  65 ++++-
 .../lti_1p3/tests/test_key_handlers.py        | 229 ++++++++++++++++++
 lti_consumer/lti_1p3/tests/utils.py           |  12 +
 lti_consumer/lti_consumer.py                  |  87 ++++++-
 lti_consumer/plugin/urls.py                   |   6 +
 lti_consumer/plugin/views.py                  |  23 ++
 .../templates/html/lti_1p3_studio.html        |   4 +-
 lti_consumer/tests/unit/test_lti_consumer.py  | 162 ++++++++++++-
 lti_consumer/utils.py                         |  12 +
 15 files changed, 933 insertions(+), 44 deletions(-)
 create mode 100644 lti_consumer/lti_1p3/README.md
 create mode 100644 lti_consumer/lti_1p3/exceptions.py
 create mode 100644 lti_consumer/lti_1p3/key_handlers.py
 create mode 100644 lti_consumer/lti_1p3/tests/__init__.py
 rename lti_consumer/{tests/unit/test_lti_1p3_consumer.py => lti_1p3/tests/test_consumer.py} (82%)
 create mode 100644 lti_consumer/lti_1p3/tests/test_key_handlers.py
 create mode 100644 lti_consumer/lti_1p3/tests/utils.py

diff --git a/lti_consumer/lti_1p3/README.md b/lti_consumer/lti_1p3/README.md
new file mode 100644
index 0000000..021c45d
--- /dev/null
+++ b/lti_consumer/lti_1p3/README.md
@@ -0,0 +1,18 @@
+LTI 1.3 Consumer Class
+-
+
+This implements a LTI 1.3 compliant consumer class which is request agnostic and can
+be reused in different contexts (XBlock, Django plugin, and even on other frameworks).
+
+This doesn't implement any data storage, just the methods required for handling LTI messages
+and Access Tokens.
+
+Features:
+- LTI 1.3 Launch with full OIDC flow
+- Support for custom parameters claim
+- Support for launch presentation claim
+- Access token creation
+
+This implementation was based on the following IMS Global Documents:
+- LTI 1.3 Core Specification: http://www.imsglobal.org/spec/lti/v1p3/
+- IMS Global Security Framework: https://www.imsglobal.org/spec/security/v1p0/
diff --git a/lti_consumer/lti_1p3/constants.py b/lti_consumer/lti_1p3/constants.py
index c0d93a0..d6b26e8 100644
--- a/lti_consumer/lti_1p3/constants.py
+++ b/lti_consumer/lti_1p3/constants.py
@@ -33,3 +33,13 @@ LTI_1P3_ROLE_MAP = {
         'http://purl.imsglobal.org/vocab/lis/v2/institution/person#Student'
     ],
 }
+
+
+LTI_1P3_ACCESS_TOKEN_REQUIRED_CLAIMS = set([
+    "grant_type",
+    "client_assertion_type",
+    "client_assertion",
+    "scope",
+])
+
+LTI_1P3_ACCESS_TOKEN_SCOPES = []
diff --git a/lti_consumer/lti_1p3/consumer.py b/lti_consumer/lti_1p3/consumer.py
index f5fb01c..11444ee 100644
--- a/lti_consumer/lti_1p3/consumer.py
+++ b/lti_consumer/lti_1p3/consumer.py
@@ -1,18 +1,17 @@
 """
 LTI 1.3 Consumer implementation
 """
-import json
-import time
-
 # Quality checks failing due to know pylint bug
 from six.moves.urllib.parse import urlencode
 
-from Crypto.PublicKey import RSA
-from jwkest.jwk import RSAKey
-from jwkest.jws import JWS
-from jwkest import jwk
-
-from .constants import LTI_1P3_ROLE_MAP, LTI_BASE_MESSAGE
+from . import exceptions
+from .constants import (
+    LTI_1P3_ROLE_MAP,
+    LTI_BASE_MESSAGE,
+    LTI_1P3_ACCESS_TOKEN_REQUIRED_CLAIMS,
+    LTI_1P3_ACCESS_TOKEN_SCOPES,
+)
+from .key_handlers import ToolKeyHandler, PlatformKeyHandler
 
 
 class LtiConsumer1p3:
@@ -27,7 +26,9 @@ class LtiConsumer1p3:
             client_id,
             deployment_id,
             rsa_key,
-            rsa_key_id
+            rsa_key_id,
+            tool_key=None,
+            tool_keyset_url=None,
     ):
         """
         Initialize LTI 1.3 Consumer class
@@ -38,14 +39,13 @@ class LtiConsumer1p3:
         self.client_id = client_id
         self.deployment_id = deployment_id
 
-        # Generate JWK from RSA key
-        self.jwk = RSAKey(
-            # Using the same key ID as client id
-            # This way we can easily serve multiple public
-            # keys on teh same endpoint and keep all
-            # LTI 1.3 blocks working
-            kid=rsa_key_id,
-            key=RSA.import_key(rsa_key)
+        # Set up platform message signature class
+        self.key_handler = PlatformKeyHandler(rsa_key, rsa_key_id)
+
+        # Set up tool public key verification class
+        self.tool_jwt = ToolKeyHandler(
+            public_key=tool_key,
+            keyset_url=tool_keyset_url
         )
 
         # IMS LTI Claim data
@@ -53,17 +53,6 @@ class LtiConsumer1p3:
         self.lti_claim_launch_presentation = None
         self.lti_claim_custom_parameters = None
 
-    def _encode_and_sign(self, message):
-        """
-        Encode and sign JSON with RSA key
-        """
-        # The class instance that sets up the signing operation
-        # An RS 256 key is required for LTI 1.3
-        _jws = JWS(message, alg="RS256", cty="JWT")
-
-        # Encode and sign LTI message
-        return _jws.sign_compact([self.jwk])
-
     @staticmethod
     def _get_user_roles(role):
         """
@@ -256,15 +245,12 @@ class LtiConsumer1p3:
         if self.lti_claim_custom_parameters:
             lti_message.update(self.lti_claim_custom_parameters)
 
-        # Add `exp` and `iat` JWT attributes
-        lti_message.update({
-            "iat": int(round(time.time())),
-            "exp": int(round(time.time()) + 3600)
-        })
-
         return {
             "state": preflight_response.get("state"),
-            "id_token": self._encode_and_sign(lti_message)
+            "id_token": self.key_handler.encode_and_sign(
+                message=lti_message,
+                expiration=300
+            )
         }
 
     def get_public_keyset(self):
@@ -290,3 +276,72 @@ class LtiConsumer1p3:
             assert response.get("redirect_uri") == self.launch_url
         except AssertionError:
             raise ValueError("Preflight reponse failed validation")
+
+        return self.key_handler.get_public_jwk()
+
+    def access_token(self, token_request_data):
+        """
+        Validate request and return JWT access token.
+
+        This complies to IMS Security Framework and accepts a JWT
+        as a secret for the client credentials grant.
+        See this section:
+        https://www.imsglobal.org/spec/security/v1p0/#securing_web_services
+
+        Full spec reference:
+        https://www.imsglobal.org/spec/security/v1p0/
+
+        Parameters:
+            token_request_data: Dict of parameters sent by LTI tool as form_data.
+
+        Returns:
+            A dict containing the JSON response containing a JWT and some extra
+            parameters required by LTI tools. This token gives access to all
+            supported LTI Scopes from this tool.
+        """
+        # Check if all required claims are present
+        for required_claim in LTI_1P3_ACCESS_TOKEN_REQUIRED_CLAIMS:
+            if required_claim not in token_request_data.keys():
+                raise exceptions.MissingRequiredClaim()
+
+        # Check that grant type is `client_credentials`
+        if token_request_data['grant_type'] != 'client_credentials':
+            raise exceptions.UnsupportedGrantType()
+
+        # Validate JWT token
+        self.tool_jwt.validate_and_decode(
+            token_request_data['client_assertion']
+        )
+
+        # Check scopes and only return valid and supported ones
+        valid_scopes = []
+        requested_scopes = token_request_data['scope'].split(' ')
+
+        for scope in requested_scopes:
+            # TODO: Add additional checks for permitted scopes
+            # Currently there are no scopes, because there is no use for
+            # these access tokens until a tool needs to access the LMS.
+            # LTI Advantage extensions make use of this.
+            if scope in LTI_1P3_ACCESS_TOKEN_SCOPES:
+                valid_scopes.append(scope)
+
+        # Scopes are space separated as described in
+        # https://tools.ietf.org/html/rfc6749
+        scopes_str = " ".join(valid_scopes)
+
+        # This response is compliant with RFC 6749
+        # https://tools.ietf.org/html/rfc6749#section-4.4.3
+        return {
+            "access_token": self.key_handler.encode_and_sign(
+                {
+                    "sub": self.client_id,
+                    "scopes": scopes_str
+                },
+                # Create token valid for 3600 seconds (1h) as per specification
+                # https://www.imsglobal.org/spec/security/v1p0/#expires_in-values-and-renewing-the-access-token
+                expiration=3600
+            ),
+            "token_type": "bearer",
+            "expires_in": 3600,
+            "scope": scopes_str
+        }
diff --git a/lti_consumer/lti_1p3/exceptions.py b/lti_consumer/lti_1p3/exceptions.py
new file mode 100644
index 0000000..5f55bac
--- /dev/null
+++ b/lti_consumer/lti_1p3/exceptions.py
@@ -0,0 +1,38 @@
+"""
+Custom exceptions for LTI 1.3 consumer
+
+# TODO: Improve exception documentation and output.
+"""
+# pylint: disable=missing-docstring
+
+
+class TokenSignatureExpired(Exception):
+    pass
+
+
+class NoSuitableKeys(Exception):
+    pass
+
+
+class UnknownClientId(Exception):
+    pass
+
+
+class MalformedJwtToken(Exception):
+    pass
+
+
+class MissingRequiredClaim(Exception):
+    pass
+
+
+class UnsupportedGrantType(Exception):
+    pass
+
+
+class InvalidRsaKey(Exception):
+    pass
+
+
+class RsaKeyNotSet(Exception):
+    pass
diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py
new file mode 100644
index 0000000..9bb582e
--- /dev/null
+++ b/lti_consumer/lti_1p3/key_handlers.py
@@ -0,0 +1,184 @@
+"""
+LTI 1.3 - Access token library
+
+This handles validating messages sent by the tool and generating
+access token with LTI scopes.
+"""
+import codecs
+import copy
+import time
+import json
+
+from Crypto.PublicKey import RSA
+from jwkest import BadSyntax, WrongNumberOfParts, jwk
+from jwkest.jwk import RSAKey, load_jwks_from_url
+from jwkest.jws import JWS, NoSuitableSigningKeys
+from jwkest.jwt import JWT
+
+from . import exceptions
+
+
+class ToolKeyHandler(object):
+    """
+    LTI 1.3 Tool Jwt Handler.
+
+    Uses a tool public keys or keysets URL to retrieve
+    a key and validate a message sent by the tool.
+
+    This is primarily used by the Access Token endpoint
+    in order to validate the JWT Signature of messages
+    signed with the tools signature.
+    """
+    def __init__(self, public_key=None, keyset_url=None):
+        """
+        Instance message validator
+
+        Import a public key from the tool by either using a keyset url
+        or a combination of public key + key id.
+
+        Keyset URL takes precedence because it makes key rotation easier to do.
+        """
+        # Only store keyset URL to avoid blocking the class
+        # instancing on an external url, which is only used
+        # when validating a token.
+        self.keyset_url = keyset_url
+        self.public_key = None
+
+        # Import from public key
+        if public_key:
+            try:
+                new_key = RSAKey(use='sig')
+
+                # Unescape key before importing it
+                raw_key = codecs.decode(public_key, 'unicode_escape')
+
+                # Import Key and save to internal state
+                new_key.load_key(RSA.import_key(raw_key))
+                self.public_key = new_key
+            except ValueError:
+                raise exceptions.InvalidRsaKey()
+
+    def _get_keyset(self, kid=None):
+        """
+        Get keyset from available sources.
+
+        If using a RSA key, forcefully set the key id
+        to match the one from the JWT token.
+        """
+        keyset = []
+
+        if self.keyset_url:
+            # TODO: Improve support for keyset handling, handle errors.
+            keyset.extend(load_jwks_from_url(self.keyset_url))
+
+        if self.public_key and kid:
+            # Fill in key id of stored key.
+            # This is needed because if the JWS is signed with a
+            # key with a kid, pyjwkest doesn't match them with
+            # keys without kid (kid=None) and fails verification
+            self.public_key.kid = kid
+
+            # Add to keyset
+            keyset.append(self.public_key)
+
+        return keyset
+
+    def validate_and_decode(self, token):
+        """
+        Check if a message sent by the tool is valid.
+
+        From https://www.imsglobal.org/spec/security/v1p0/#using-oauth-2-0-client-credentials-grant:
+
+        The authorization server decodes the JWT and MUST validate the values for the
+        iss, sub, exp, aud and jti claims.
+        """
+        try:
+            # Get KID from JWT header
+            jwt = JWT().unpack(token)
+
+            # Verify message signature
+            message = JWS().verify_compact(
+                token,
+                keys=self._get_keyset(
+                    jwt.headers.get('kid')
+                )
+            )
+
+            # If message is valid, check expiration from JWT
+            if 'exp' in message and message['exp'] < time.time():
+                raise exceptions.TokenSignatureExpired()
+
+            # TODO: Validate other JWT claims
+
+            # Else returns decoded message
+            return message
+
+        except NoSuitableSigningKeys:
+            raise exceptions.NoSuitableKeys()
+        except BadSyntax:
+            raise exceptions.MalformedJwtToken()
+        except WrongNumberOfParts:
+            raise exceptions.MalformedJwtToken()
+
+
+class PlatformKeyHandler(object):
+    """
+    Platform RSA Key handler.
+
+    This class loads the platform key and is responsible for
+    encoding JWT messages and exporting public keys.
+    """
+    def __init__(self, key_pem, kid=None):
+        """
+        Import Key when instancing class if a key is present.
+        """
+        self.key = None
+
+        if key_pem:
+            # Import JWK from RSA key
+            try:
+                self.key = RSAKey(
+                    # Using the same key ID as client id
+                    # This way we can easily serve multiple public
+                    # keys on teh same endpoint and keep all
+                    # LTI 1.3 blocks working
+                    kid=kid,
+                    key=RSA.import_key(key_pem)
+                )
+            except ValueError:
+                raise exceptions.InvalidRsaKey()
+
+    def encode_and_sign(self, message, expiration=None):
+        """
+        Encode and sign JSON with RSA key
+        """
+        if not self.key:
+            raise exceptions.RsaKeyNotSet()
+
+        _message = copy.deepcopy(message)
+
+        # Set iat and exp if expiration is set
+        if expiration:
+            _message.update({
+                "iat": int(round(time.time())),
+                "exp": int(round(time.time()) + expiration),
+            })
+
+        # The class instance that sets up the signing operation
+        # An RS 256 key is required for LTI 1.3
+        _jws = JWS(_message, alg="RS256", cty="JWT")
+
+        # Encode and sign LTI message
+        return _jws.sign_compact([self.key])
+
+    def get_public_jwk(self):
+        """
+        Export Public JWK
+        """
+        public_keys = jwk.KEYS()
+
+        # Only append to keyset if a key exists
+        if self.key:
+            public_keys.append(self.key)
+
+        return json.loads(public_keys.dump_jwks())
diff --git a/lti_consumer/lti_1p3/tests/__init__.py b/lti_consumer/lti_1p3/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lti_consumer/tests/unit/test_lti_1p3_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py
similarity index 82%
rename from lti_consumer/tests/unit/test_lti_1p3_consumer.py
rename to lti_consumer/lti_1p3/tests/test_consumer.py
index 747daae..e3b0e7e 100644
--- a/lti_consumer/tests/unit/test_lti_1p3_consumer.py
+++ b/lti_consumer/lti_1p3/tests/test_consumer.py
@@ -15,6 +15,7 @@ from jwkest.jwk import load_jwks
 from jwkest.jws import JWS
 
 from lti_consumer.lti_1p3.consumer import LtiConsumer1p3
+from lti_consumer.lti_1p3 import exceptions
 
 
 # Variables required for testing and verification
@@ -47,7 +48,9 @@ class TestLti1p3Consumer(TestCase):
             client_id=CLIENT_ID,
             deployment_id=DEPLOYMENT_ID,
             rsa_key=RSA_KEY,
-            rsa_key_id=RSA_KEY_ID
+            rsa_key_id=RSA_KEY_ID,
+            # Use the same key for testing purposes
+            tool_key=RSA_KEY
         )
 
     def _setup_lti_user(self):
@@ -321,3 +324,63 @@ class TestLti1p3Consumer(TestCase):
         """
         with self.assertRaises(ValueError):
             self.lti_consumer.set_custom_parameters("invalid")
+
+    def test_access_token_missing_params(self):
+        """
+        Check if access token with missing request data raises.
+        """
+        with self.assertRaises(exceptions.MissingRequiredClaim):
+            self.lti_consumer.access_token({})
+
+    def test_access_token_invalid_jwt(self):
+        """
+        Check if access token with invalid request data raises.
+        """
+        request_data = {
+            "grant_type": "client_credentials",
+            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
+            # This should be a valid JWT
+            "client_assertion": "invalid-jwt",
+            # Scope can be empty
+            "scope": "",
+        }
+
+        with self.assertRaises(exceptions.MalformedJwtToken):
+            self.lti_consumer.access_token(request_data)
+
+    def test_access_token(self):
+        """
+        Check if a valid access token is returned.
+
+        Since we're using the same key for both tool and
+        platform here, we can make use of the internal
+        _decode_token to check validity.
+        """
+        # Generate a dummy, but valid JWT
+        token = self.lti_consumer.key_handler.encode_and_sign(
+            {
+                "test": "test"
+            },
+            expiration=1000
+        )
+
+        request_data = {
+            # We don't actually care about these 2 first values
+            "grant_type": "client_credentials",
+            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
+            # This should be a valid JWT
+            "client_assertion": token,
+            # Scope can be empty
+            "scope": "",
+        }
+
+        response = self.lti_consumer.access_token(request_data)
+
+        # Check response contents
+        self.assertIn('access_token', response)
+        self.assertEqual(response.get('token_type'), 'bearer')
+        self.assertEqual(response.get('expires_in'), 3600)
+        self.assertEqual(response.get('scope'), '')
+
+        # Check if token is valid
+        self._decode_token(response.get('access_token'))
diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py
new file mode 100644
index 0000000..86d696c
--- /dev/null
+++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py
@@ -0,0 +1,229 @@
+"""
+Unit tests for LTI 1.3 consumer implementation
+"""
+from __future__ import absolute_import, unicode_literals
+
+import json
+import ddt
+
+from mock import patch
+from django.test.testcases import TestCase
+
+from Crypto.PublicKey import RSA
+from jwkest.jwk import RSAKey, load_jwks
+from jwkest.jws import JWS
+
+from lti_consumer.lti_1p3.key_handlers import PlatformKeyHandler, ToolKeyHandler
+from lti_consumer.lti_1p3 import exceptions
+from .utils import create_jwt
+
+
+@ddt.ddt
+class TestPlatformKeyHandler(TestCase):
+    """
+    Unit tests for PlatformKeyHandler
+    """
+    def setUp(self):
+        super(TestPlatformKeyHandler, self).setUp()
+
+        self.rsa_key_id = "1"
+        self.rsa_key = RSA.generate(2048).export_key('PEM')
+
+        # Set up consumer
+        self.key_handler = PlatformKeyHandler(
+            key_pem=self.rsa_key,
+            kid=self.rsa_key_id
+        )
+
+    def _decode_token(self, token):
+        """
+        Checks for a valid signarute and decodes JWT signed LTI message
+
+        This also touches the public keyset method.
+        """
+        public_keyset = self.key_handler.get_public_jwk()
+        key_set = load_jwks(json.dumps(public_keyset))
+
+        return JWS().verify_compact(token, keys=key_set)
+
+    def test_encode_and_sign(self):
+        """
+        Test if a message was correctly signed with RSA key.
+        """
+        message = {
+            "test": "test"
+        }
+        signed_token = self.key_handler.encode_and_sign(message)
+        self.assertEqual(
+            self._decode_token(signed_token),
+            message
+        )
+
+    # pylint: disable=unused-argument
+    @patch('time.time', return_value=1000)
+    def test_encode_and_sign_with_exp(self, mock_time):
+        """
+        Test if a message was correctly signed and has exp and iat parameters.
+        """
+        message = {
+            "test": "test"
+        }
+
+        signed_token = self.key_handler.encode_and_sign(
+            message,
+            expiration=1000
+        )
+
+        self.assertEqual(
+            self._decode_token(signed_token),
+            {
+                "test": "test",
+                "iat": 1000,
+                "exp": 2000
+            }
+        )
+
+    def test_invalid_rsa_key(self):
+        """
+        Check that class raises when trying to import invalid RSA Key.
+        """
+        with self.assertRaises(exceptions.InvalidRsaKey):
+            PlatformKeyHandler(key_pem="invalid PEM input")
+
+    def test_empty_rsa_key(self):
+        """
+        Check that class doesn't fail instancing when not using a key.
+        """
+        empty_key_handler = PlatformKeyHandler(key_pem='')
+
+        # Trying to encode a message should fail
+        with self.assertRaises(exceptions.RsaKeyNotSet):
+            empty_key_handler.encode_and_sign({})
+
+        # Public JWK should return an empty value
+        self.assertEqual(
+            empty_key_handler.get_public_jwk(),
+            {'keys': []}
+        )
+
+
+@ddt.ddt
+class TestToolKeyHandler(TestCase):
+    """
+    Unit tests for ToolKeyHandler
+    """
+    def setUp(self):
+        super(TestToolKeyHandler, self).setUp()
+
+        self.rsa_key_id = "1"
+
+        # Generate RSA and save exports
+        rsa_key = RSA.generate(2048)
+        self.key = RSAKey(
+            key=rsa_key,
+            kid=self.rsa_key_id
+        )
+        self.public_key = rsa_key.publickey().export_key()
+
+        # Key handler
+        self.key_handler = None
+
+    def _setup_key_handler(self):
+        """
+        Set up a instance of the key handler.
+        """
+        self.key_handler = ToolKeyHandler(public_key=self.public_key)
+
+    def test_import_rsa_key(self):
+        """
+        Check if the class is correctly instanced using a valid RSA key.
+        """
+        self._setup_key_handler()
+
+    def test_import_invalid_rsa_key(self):
+        """
+        Check if the class errors out when using a invalid RSA key.
+        """
+        with self.assertRaises(exceptions.InvalidRsaKey):
+            ToolKeyHandler(public_key="invalid-key")
+
+    def test_get_empty_keyset(self):
+        """
+        Test getting an empty keyset.
+        """
+        key_handler = ToolKeyHandler()
+
+        self.assertEqual(
+            # pylint: disable=protected-access
+            key_handler._get_keyset(),
+            []
+        )
+
+    def test_get_keyset_with_pub_key(self):
+        """
+        Check that getting a keyset from a RSA key.
+        """
+        self._setup_key_handler()
+
+        # pylint: disable=protected-access
+        keyset = self.key_handler._get_keyset(kid=self.rsa_key_id)
+        self.assertEqual(len(keyset), 1)
+        self.assertEqual(
+            keyset[0].kid,
+            self.rsa_key_id
+        )
+
+    # pylint: disable=unused-argument
+    @patch('time.time', return_value=1000)
+    def test_validate_and_decode(self, mock_time):
+        """
+        Check that the validate and decode works.
+        """
+        self._setup_key_handler()
+
+        message = {
+            "test": "test_message",
+            "iat": 1000,
+            "exp": 1200,
+        }
+        signed = create_jwt(self.key, message)
+
+        # Decode and check results
+        decoded_message = self.key_handler.validate_and_decode(signed)
+        self.assertEqual(decoded_message, message)
+
+    # pylint: disable=unused-argument
+    @patch('time.time', return_value=1000)
+    def test_validate_and_decode_expired(self, mock_time):
+        """
+        Check that the validate and decode raises when signature expires.
+        """
+        self._setup_key_handler()
+
+        message = {
+            "test": "test_message",
+            "iat": 900,
+            "exp": 910,
+        }
+        signed = create_jwt(self.key, message)
+
+        # Decode and check results
+        with self.assertRaises(exceptions.TokenSignatureExpired):
+            self.key_handler.validate_and_decode(signed)
+
+    def test_validate_and_decode_no_keys(self):
+        """
+        Check that the validate and decode raises when no keys are found.
+        """
+        key_handler = ToolKeyHandler()
+
+        message = {
+            "test": "test_message",
+            "iat": 900,
+            "exp": 910,
+        }
+        signed = create_jwt(self.key, message)
+
+        # Decode and check results
+        with self.assertRaises(exceptions.NoSuitableKeys):
+            key_handler.validate_and_decode(signed)
diff --git a/lti_consumer/lti_1p3/tests/utils.py b/lti_consumer/lti_1p3/tests/utils.py
new file mode 100644
index 0000000..3a76d16
--- /dev/null
+++ b/lti_consumer/lti_1p3/tests/utils.py
@@ -0,0 +1,12 @@
+"""
+Test utils
+"""
+from jwkest.jws import JWS
+
+
+def create_jwt(key, message):
+    """
+    Uses private key to create a JWS from a dict.
+    """
+    jws = JWS(message, alg="RS256", cty="JWT")
+    return jws.sign_compact([key])
diff --git a/lti_consumer/lti_consumer.py b/lti_consumer/lti_consumer.py
index 3f6f855..442ad3e 100644
--- a/lti_consumer/lti_consumer.py
+++ b/lti_consumer/lti_consumer.py
@@ -58,10 +58,9 @@ import uuid
 from collections import namedtuple
 from importlib import import_module
 
-import six.moves.urllib.error
-import six.moves.urllib.parse
-import six
 import bleach
+import six
+from six.moves.urllib import parse
 from Crypto.PublicKey import RSA
 from django.utils import timezone
 from webob import Response
@@ -74,12 +73,21 @@ from xblockutils.studio_editable import StudioEditableXBlockMixin
 
 from .exceptions import LtiError
 from .lti import LtiConsumer
+from .lti_1p3.exceptions import (
+    UnsupportedGrantType,
+    MalformedJwtToken,
+    MissingRequiredClaim,
+    NoSuitableKeys,
+    TokenSignatureExpired,
+    UnknownClientId,
+)
 from .lti_1p3.consumer import LtiConsumer1p3
 from .oauth import log_authorization_header
 from .outcomes import OutcomeService
 from .utils import (
     _,
     get_lms_base,
+    get_lms_lti_access_token_link,
     get_lms_lti_keyset_link,
     get_lms_lti_launch_link,
 )
@@ -306,6 +314,11 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock):
         default='',
         scope=Scope.settings
     )
+    lti_1p3_tool_public_key_id = String(
+        display_name=_("LTI 1.3 Tool Public Key ID"),
+        default='',
+        scope=Scope.settings
+    )
     # Client ID and block key
     lti_1p3_client_id = String(
         display_name=_("LTI 1.3 Block Client ID"),
@@ -492,6 +505,7 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock):
         'display_name', 'description',
         # LTI 1.3 variables
         'lti_version', 'lti_1p3_launch_url', 'lti_1p3_oidc_url', 'lti_1p3_tool_public_key',
+        'lti_1p3_tool_public_key_id',
         # LTI 1.1 variables
         'lti_id', 'launch_url',
         # Other parameters
@@ -801,8 +815,12 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock):
             lti_launch_url=self.lti_1p3_launch_url,
             client_id=self.lti_1p3_client_id,
             deployment_id="1",
+            # XBlock Private RSA Key
             rsa_key=self.lti_1p3_block_key,
-            rsa_key_id=self.lti_1p3_client_id
+            rsa_key_id=self.lti_1p3_client_id,
+            # LTI 1.3 Tool key/keyset url
+            tool_key=self.lti_1p3_tool_public_key,
+            tool_keyset_url=None,
         )
 
     def studio_view(self, context):
@@ -857,6 +875,7 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock):
             "deployment_id": "1",
             "keyset_url": get_lms_lti_keyset_link(self.location),  # pylint: disable=no-member
             "oidc_callback": get_lms_lti_launch_link(),
+            "token_url": get_lms_lti_access_token_link(self.location),  # pylint: disable=no-member
             "launch_url": self.lti_1p3_launch_url,
         }
         fragment.add_content(loader.render_mako_template('/templates/html/lti_1p3_studio.html', context))
@@ -997,6 +1016,66 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock):
             )
         return Response(status=404)
 
+    @XBlock.handler
+    def lti_1p3_access_token(self, request, suffix=''):  # pylint: disable=unused-argument
+        """
+        XBlock handler for creating access tokens for the LTI 1.3 tool.
+
+        This endpoint is only valid when a LTI 1.3 tool is being used.
+
+        Returns:
+            webob.response:
+                Either an access token or error message detailing the failure.
+                All responses are RFC 6749 compliant.
+
+        References:
+            Sucess: https://tools.ietf.org/html/rfc6749#section-4.4.3
+            Failure: https://tools.ietf.org/html/rfc6749#section-5.2
+        """
+        if self.lti_version != "lti_1p3":
+            return Response(status=404)
+        elif request.method != "POST":
+            return Response(status=405)
+
+        lti_consumer = self._get_lti1p3_consumer()
+        try:
+            token = lti_consumer.access_token(
+                dict(parse.parse_qsl(
+                    request.body.decode('utf-8'),
+                    keep_blank_values=True
+                ))
+            )
+            # The returned `token` is compliant with RFC 6749 so we just
+            # need to return a 200 OK response with the token as Json body
+            return Response(json_body=token, content_type="application/json")
+
+        # Handle errors and return a proper response
+        # pylint: disable=bare-except
+        except MissingRequiredClaim:
+            # Missing request attibutes
+            return Response(
+                json_body={"error": "invalid_request"},
+                status=400
+            )
+        except (MalformedJwtToken, TokenSignatureExpired):
+            # Triggered when a invalid grant token is used
+            return Response(
+                json_body={"error": "invalid_grant"},
+                status=400,
+            )
+        except (NoSuitableKeys, UnknownClientId):
+            # Client ID is not registered in the block or
+            # isn't possible to validate token using available keys.
+            return Response(
+                json_body={"error": "invalid_client"},
+                status=400,
+            )
+        except UnsupportedGrantType:
+            return Response(
+                json_body={"error": "unsupported_grant_type"},
+                status=400,
+            )
+
     @XBlock.handler
     def outcome_service_handler(self, request, suffix=''):  # pylint: disable=unused-argument
         """
diff --git a/lti_consumer/plugin/urls.py b/lti_consumer/plugin/urls.py
index 8bfc878..7b209e5 100644
--- a/lti_consumer/plugin/urls.py
+++ b/lti_consumer/plugin/urls.py
@@ -10,6 +10,7 @@ from django.conf.urls import url
 from .views import (
     public_keyset_endpoint,
     launch_gate_endpoint,
+    access_token_endpoint
 )
 
 
@@ -23,5 +24,10 @@ urlpatterns = [
         'lti_consumer/v1/launch/(?:/(?P<suffix>.*))?$',
         launch_gate_endpoint,
         name='lti_consumer.launch_gate'
+    ),
+    url(
+        'lti_consumer/v1/token/{}$'.format(settings.USAGE_ID_PATTERN),
+        access_token_endpoint,
+        name='lti_consumer.access_token'
     )
 ]
diff --git a/lti_consumer/plugin/views.py b/lti_consumer/plugin/views.py
index 14aa4c9..5d9312f 100644
--- a/lti_consumer/plugin/views.py
+++ b/lti_consumer/plugin/views.py
@@ -3,6 +3,8 @@ LTI consumer plugin passthrough views
 """
 
 from django.http import HttpResponse
+from django.views.decorators.csrf import csrf_exempt
+from django.views.decorators.http import require_http_methods
 
 from opaque_keys.edx.keys import UsageKey  # pylint: disable=import-error
 from lms.djangoapps.courseware.module_render import (  # pylint: disable=import-error
@@ -11,6 +13,7 @@ from lms.djangoapps.courseware.module_render import (  # pylint: disable=import-
 )
 
 
+@require_http_methods(["GET"])
 def public_keyset_endpoint(request, usage_id=None):
     """
     Gate endpoint to fetch public keysets from a problem
@@ -32,6 +35,7 @@ def public_keyset_endpoint(request, usage_id=None):
         return HttpResponse(status=404)
 
 
+@require_http_methods(["GET", "POST"])
 def launch_gate_endpoint(request, suffix):
     """
     Gate endpoint that triggers LTI launch endpoint XBlock handler
@@ -54,3 +58,22 @@ def launch_gate_endpoint(request, suffix):
         )
     except:  # pylint: disable=bare-except
         return HttpResponse(status=404)
+
+
+@csrf_exempt
+@require_http_methods(["POST"])
+def access_token_endpoint(request, usage_id=None):
+    """
+    Gate endpoint to enable tools to retrieve access tokens
+    """
+    try:
+        usage_key = UsageKey.from_string(usage_id)
+
+        return handle_xblock_callback_noauth(
+            request=request,
+            course_id=str(usage_key.course_key),
+            usage_id=str(usage_key),
+            handler='lti_1p3_access_token'
+        )
+    except:  # pylint: disable=bare-except
+        return HttpResponse(status=404)
diff --git a/lti_consumer/templates/html/lti_1p3_studio.html b/lti_consumer/templates/html/lti_1p3_studio.html
index e16c28a..2fe1f2f 100644
--- a/lti_consumer/templates/html/lti_1p3_studio.html
+++ b/lti_consumer/templates/html/lti_1p3_studio.html
@@ -29,8 +29,8 @@
         </p>
 
         <p>
-            <b>OAuth URL: </b>
-            N/A
+            <b>OAuth Token URL: </b>
+            ${token_url}
         </p>
 
         <p>
diff --git a/lti_consumer/tests/unit/test_lti_consumer.py b/lti_consumer/tests/unit/test_lti_consumer.py
index 39a0531..504838f 100644
--- a/lti_consumer/tests/unit/test_lti_consumer.py
+++ b/lti_consumer/tests/unit/test_lti_consumer.py
@@ -5,13 +5,16 @@ Unit tests for LtiConsumerXBlock
 from __future__ import absolute_import
 
 from datetime import timedelta
+import json
 import uuid
 
 import ddt
 import six
+from six.moves.urllib import parse
 from Crypto.PublicKey import RSA
 from django.test.testcases import TestCase
 from django.utils import timezone
+from jwkest.jwk import RSAKey
 from mock import Mock, PropertyMock, patch
 
 from lti_consumer.exceptions import LtiError
@@ -19,6 +22,8 @@ from lti_consumer.lti_consumer import LtiConsumerXBlock, parse_handler_suffix
 from lti_consumer.tests.unit import test_utils
 from lti_consumer.tests.unit.test_utils import (FAKE_USER_ID, make_request,
                                                 make_xblock)
+from lti_consumer.lti_1p3.tests.utils import create_jwt
+
 
 HTML_PROBLEM_PROGRESS = '<div class="problem-progress">'
 HTML_ERROR_MESSAGE = '<h3 class="error_message">'
@@ -435,6 +440,15 @@ class TestStudentView(TestLtiConsumerXBlock):
 
         self.assertNotIn(HTML_ERROR_MESSAGE, fragment.content)
 
+    def test_author_view(self):
+        """
+        Test that the `author_view` is the same as student view when using LTI 1.1.
+        """
+        self.assertEqual(
+            self.xblock.student_view({}).content,
+            self.xblock.author_view({}).content
+        )
+
 
 class TestLtiLaunchHandler(TestLtiConsumerXBlock):
     """
@@ -931,7 +945,6 @@ class TestLtiConsumer1p3XBlock(TestCase):
             'lti_1p3_client_id': '1',
             'lti_1p3_launch_url': 'http://tool.example/launch',
             'lti_1p3_oidc_url': 'http://tool.example/oidc',
-            'lti_1p3_tool_public_key': '',
             # We need to set the values below because they are not automatically
             # generated until the user selects `lti_version == 'lti_1p3'` on the
             # Studio configuration view.
@@ -1054,3 +1067,150 @@ class TestLtiConsumer1p3XBlock(TestCase):
         response = self.xblock.author_view({})
         self.assertIn(self.xblock.lti_1p3_client_id, response.content)
         self.assertIn("https://example.com", response.content)
+
+
+# pylint: disable=unused-argument
+@patch('lti_consumer.utils.get_lms_base', return_value="https://example.com")
+@patch('lti_consumer.lti_consumer.get_lms_base', return_value="https://example.com")
+class TestLti1p3AccessTokenEndpoint(TestCase):
+    """
+    Unit tests for LtiConsumerXBlock Access Token endpoint when using an LTI 1.3.
+    """
+    def setUp(self):
+        super(TestLti1p3AccessTokenEndpoint, self).setUp()
+
+        self.rsa_key_id = "1"
+        # Generate RSA and save exports
+        rsa_key = RSA.generate(2048)
+        self.key = RSAKey(
+            key=rsa_key,
+            kid=self.rsa_key_id
+        )
+        self.public_key = rsa_key.publickey().export_key()
+
+        self.xblock_attributes = {
+            'lti_version': 'lti_1p3',
+            'lti_1p3_launch_url': 'http://tool.example/launch',
+            'lti_1p3_oidc_url': 'http://tool.example/oidc',
+            # We need to set the values below because they are not automatically
+            # generated until the user selects `lti_version == 'lti_1p3'` on the
+            # Studio configuration view.
+            'lti_1p3_client_id': self.rsa_key_id,
+            'lti_1p3_block_key': rsa_key.export_key('PEM'),
+            # Use same key for tool key to make testing easier
+            'lti_1p3_tool_public_key': self.public_key,
+        }
+        self.xblock = make_xblock('lti_consumer', LtiConsumerXBlock, self.xblock_attributes)
+
+    def test_access_token_endpoint_when_using_lti_1p1(self, *args, **kwargs):
+        """
+        Test that the LTI 1.3 access token endpoind is unavailable when using 1.1.
+        """
+        self.xblock.lti_version = 'lti_1p1'
+        self.xblock.save()
+
+        request = make_request(json.dumps({}), 'POST')
+        request.content_type = 'application/json'
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 404)
+
+    def test_access_token_endpoint_no_post(self, *args, **kwargs):
+        """
+        Test that the LTI 1.3 access token endpoind is unavailable when using 1.1.
+        """
+        request = make_request('', 'GET')
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 405)
+
+    def test_access_token_missing_claims(self, *args, **kwargs):
+        """
+        Test request with missing parameters.
+        """
+        request = make_request(json.dumps({}), 'POST')
+        request.content_type = 'application/json'
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json_body, {'error': 'invalid_request'})
+
+    def test_access_token_malformed(self, *args, **kwargs):
+        """
+        Test request with invalid JWT.
+        """
+        request = make_request(
+            parse.urlencode({
+                "grant_type": "client_credentials",
+                "client_assertion_type": "something",
+                "client_assertion": "invalid-jwt",
+                "scope": "",
+            }),
+            'POST'
+        )
+        request.content_type = 'application/x-www-form-urlencoded'
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json_body, {'error': 'invalid_grant'})
+
+    def test_access_token_invalid_grant(self, *args, **kwargs):
+        """
+        Test request with invalid grant.
+        """
+        request = make_request(
+            parse.urlencode({
+                "grant_type": "password",
+                "client_assertion_type": "something",
+                "client_assertion": "invalit-jwt",
+                "scope": "",
+            }),
+            'POST'
+        )
+        request.content_type = 'application/x-www-form-urlencoded'
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json_body, {'error': 'unsupported_grant_type'})
+
+    def test_access_token_invalid_client(self, *args, **kwargs):
+        """
+        Test request with valid JWT but no matching key to check signature.
+        """
+        self.xblock.lti_1p3_tool_public_key = ''
+        self.xblock.save()
+
+        jwt = create_jwt(self.key, {})
+        request = make_request(
+            parse.urlencode({
+                "grant_type": "client_credentials",
+                "client_assertion_type": "something",
+                "client_assertion": jwt,
+                "scope": "",
+            }),
+            'POST'
+        )
+        request.content_type = 'application/x-www-form-urlencoded'
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json_body, {'error': 'invalid_client'})
+
+    def test_access_token(self, *args, **kwargs):
+        """
+        Test request with valid JWT.
+        """
+        jwt = create_jwt(self.key, {})
+        request = make_request(
+            parse.urlencode({
+                "grant_type": "client_credentials",
+                "client_assertion_type": "something",
+                "client_assertion": jwt,
+                "scope": "",
+            }),
+            'POST'
+        )
+        request.content_type = 'application/x-www-form-urlencoded'
+
+        response = self.xblock.lti_1p3_access_token(request)
+        self.assertEqual(response.status_code, 200)
diff --git a/lti_consumer/utils.py b/lti_consumer/utils.py
index 9e61613..3db2d5b 100644
--- a/lti_consumer/utils.py
+++ b/lti_consumer/utils.py
@@ -46,3 +46,15 @@ def get_lms_lti_launch_link():
     return u"{lms_base}/api/lti_consumer/v1/launch/".format(
         lms_base=get_lms_base(),
     )
+
+
+def get_lms_lti_access_token_link(location):
+    """
+    Returns an LMS link to LTI Launch endpoint
+
+    :param location: the location of the block
+    """
+    return u"{lms_base}/api/lti_consumer/v1/token/{location}".format(
+        lms_base=get_lms_base(),
+        location=text_type(location),
+    )
-- 
GitLab