diff --git a/lti_consumer/lti_1p3/README.md b/lti_consumer/lti_1p3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..021c45dae517c9cb548eb22eb6d5f2a379e29a88 --- /dev/null +++ b/lti_consumer/lti_1p3/README.md @@ -0,0 +1,18 @@ +LTI 1.3 Consumer Class +- + +This implements a LTI 1.3 compliant consumer class which is request agnostic and can +be reused in different contexts (XBlock, Django plugin, and even on other frameworks). + +This doesn't implement any data storage, just the methods required for handling LTI messages +and Access Tokens. + +Features: +- LTI 1.3 Launch with full OIDC flow +- Support for custom parameters claim +- Support for launch presentation claim +- Access token creation + +This implementation was based on the following IMS Global Documents: +- LTI 1.3 Core Specification: http://www.imsglobal.org/spec/lti/v1p3/ +- IMS Global Security Framework: https://www.imsglobal.org/spec/security/v1p0/ diff --git a/lti_consumer/lti_1p3/constants.py b/lti_consumer/lti_1p3/constants.py index c0d93a0644a987a389242e5ad44dfcbcfb4fff97..d6b26e85cdaca677106b7c3dcd99010acbe9be84 100644 --- a/lti_consumer/lti_1p3/constants.py +++ b/lti_consumer/lti_1p3/constants.py @@ -33,3 +33,13 @@ LTI_1P3_ROLE_MAP = { 'http://purl.imsglobal.org/vocab/lis/v2/institution/person#Student' ], } + + +LTI_1P3_ACCESS_TOKEN_REQUIRED_CLAIMS = set([ + "grant_type", + "client_assertion_type", + "client_assertion", + "scope", +]) + +LTI_1P3_ACCESS_TOKEN_SCOPES = [] diff --git a/lti_consumer/lti_1p3/consumer.py b/lti_consumer/lti_1p3/consumer.py index f5fb01c0ce277d7e3b4deb7f3ec5e9b94661235b..11444ee40cb69809ce418745ced35f279be34a16 100644 --- a/lti_consumer/lti_1p3/consumer.py +++ b/lti_consumer/lti_1p3/consumer.py @@ -1,18 +1,17 @@ """ LTI 1.3 Consumer implementation """ -import json -import time - # Quality checks failing due to know pylint bug from six.moves.urllib.parse import urlencode -from Crypto.PublicKey import RSA -from jwkest.jwk import RSAKey -from jwkest.jws import JWS -from jwkest import jwk - -from .constants import LTI_1P3_ROLE_MAP, LTI_BASE_MESSAGE +from . import exceptions +from .constants import ( + LTI_1P3_ROLE_MAP, + LTI_BASE_MESSAGE, + LTI_1P3_ACCESS_TOKEN_REQUIRED_CLAIMS, + LTI_1P3_ACCESS_TOKEN_SCOPES, +) +from .key_handlers import ToolKeyHandler, PlatformKeyHandler class LtiConsumer1p3: @@ -27,7 +26,9 @@ class LtiConsumer1p3: client_id, deployment_id, rsa_key, - rsa_key_id + rsa_key_id, + tool_key=None, + tool_keyset_url=None, ): """ Initialize LTI 1.3 Consumer class @@ -38,14 +39,13 @@ class LtiConsumer1p3: self.client_id = client_id self.deployment_id = deployment_id - # Generate JWK from RSA key - self.jwk = RSAKey( - # Using the same key ID as client id - # This way we can easily serve multiple public - # keys on teh same endpoint and keep all - # LTI 1.3 blocks working - kid=rsa_key_id, - key=RSA.import_key(rsa_key) + # Set up platform message signature class + self.key_handler = PlatformKeyHandler(rsa_key, rsa_key_id) + + # Set up tool public key verification class + self.tool_jwt = ToolKeyHandler( + public_key=tool_key, + keyset_url=tool_keyset_url ) # IMS LTI Claim data @@ -53,17 +53,6 @@ class LtiConsumer1p3: self.lti_claim_launch_presentation = None self.lti_claim_custom_parameters = None - def _encode_and_sign(self, message): - """ - Encode and sign JSON with RSA key - """ - # The class instance that sets up the signing operation - # An RS 256 key is required for LTI 1.3 - _jws = JWS(message, alg="RS256", cty="JWT") - - # Encode and sign LTI message - return _jws.sign_compact([self.jwk]) - @staticmethod def _get_user_roles(role): """ @@ -256,15 +245,12 @@ class LtiConsumer1p3: if self.lti_claim_custom_parameters: lti_message.update(self.lti_claim_custom_parameters) - # Add `exp` and `iat` JWT attributes - lti_message.update({ - "iat": int(round(time.time())), - "exp": int(round(time.time()) + 3600) - }) - return { "state": preflight_response.get("state"), - "id_token": self._encode_and_sign(lti_message) + "id_token": self.key_handler.encode_and_sign( + message=lti_message, + expiration=300 + ) } def get_public_keyset(self): @@ -290,3 +276,72 @@ class LtiConsumer1p3: assert response.get("redirect_uri") == self.launch_url except AssertionError: raise ValueError("Preflight reponse failed validation") + + return self.key_handler.get_public_jwk() + + def access_token(self, token_request_data): + """ + Validate request and return JWT access token. + + This complies to IMS Security Framework and accepts a JWT + as a secret for the client credentials grant. + See this section: + https://www.imsglobal.org/spec/security/v1p0/#securing_web_services + + Full spec reference: + https://www.imsglobal.org/spec/security/v1p0/ + + Parameters: + token_request_data: Dict of parameters sent by LTI tool as form_data. + + Returns: + A dict containing the JSON response containing a JWT and some extra + parameters required by LTI tools. This token gives access to all + supported LTI Scopes from this tool. + """ + # Check if all required claims are present + for required_claim in LTI_1P3_ACCESS_TOKEN_REQUIRED_CLAIMS: + if required_claim not in token_request_data.keys(): + raise exceptions.MissingRequiredClaim() + + # Check that grant type is `client_credentials` + if token_request_data['grant_type'] != 'client_credentials': + raise exceptions.UnsupportedGrantType() + + # Validate JWT token + self.tool_jwt.validate_and_decode( + token_request_data['client_assertion'] + ) + + # Check scopes and only return valid and supported ones + valid_scopes = [] + requested_scopes = token_request_data['scope'].split(' ') + + for scope in requested_scopes: + # TODO: Add additional checks for permitted scopes + # Currently there are no scopes, because there is no use for + # these access tokens until a tool needs to access the LMS. + # LTI Advantage extensions make use of this. + if scope in LTI_1P3_ACCESS_TOKEN_SCOPES: + valid_scopes.append(scope) + + # Scopes are space separated as described in + # https://tools.ietf.org/html/rfc6749 + scopes_str = " ".join(valid_scopes) + + # This response is compliant with RFC 6749 + # https://tools.ietf.org/html/rfc6749#section-4.4.3 + return { + "access_token": self.key_handler.encode_and_sign( + { + "sub": self.client_id, + "scopes": scopes_str + }, + # Create token valid for 3600 seconds (1h) as per specification + # https://www.imsglobal.org/spec/security/v1p0/#expires_in-values-and-renewing-the-access-token + expiration=3600 + ), + "token_type": "bearer", + "expires_in": 3600, + "scope": scopes_str + } diff --git a/lti_consumer/lti_1p3/exceptions.py b/lti_consumer/lti_1p3/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..5f55bac84290c4327718146a86080622bf232827 --- /dev/null +++ b/lti_consumer/lti_1p3/exceptions.py @@ -0,0 +1,38 @@ +""" +Custom exceptions for LTI 1.3 consumer + +# TODO: Improve exception documentation and output. +""" +# pylint: disable=missing-docstring + + +class TokenSignatureExpired(Exception): + pass + + +class NoSuitableKeys(Exception): + pass + + +class UnknownClientId(Exception): + pass + + +class MalformedJwtToken(Exception): + pass + + +class MissingRequiredClaim(Exception): + pass + + +class UnsupportedGrantType(Exception): + pass + + +class InvalidRsaKey(Exception): + pass + + +class RsaKeyNotSet(Exception): + pass diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb582ef0c46c528c339dc9f5528b356598dee69 --- /dev/null +++ b/lti_consumer/lti_1p3/key_handlers.py @@ -0,0 +1,184 @@ +""" +LTI 1.3 - Access token library + +This handles validating messages sent by the tool and generating +access token with LTI scopes. +""" +import codecs +import copy +import time +import json + +from Crypto.PublicKey import RSA +from jwkest import BadSyntax, WrongNumberOfParts, jwk +from jwkest.jwk import RSAKey, load_jwks_from_url +from jwkest.jws import JWS, NoSuitableSigningKeys +from jwkest.jwt import JWT + +from . import exceptions + + +class ToolKeyHandler(object): + """ + LTI 1.3 Tool Jwt Handler. + + Uses a tool public keys or keysets URL to retrieve + a key and validate a message sent by the tool. + + This is primarily used by the Access Token endpoint + in order to validate the JWT Signature of messages + signed with the tools signature. + """ + def __init__(self, public_key=None, keyset_url=None): + """ + Instance message validator + + Import a public key from the tool by either using a keyset url + or a combination of public key + key id. + + Keyset URL takes precedence because it makes key rotation easier to do. + """ + # Only store keyset URL to avoid blocking the class + # instancing on an external url, which is only used + # when validating a token. + self.keyset_url = keyset_url + self.public_key = None + + # Import from public key + if public_key: + try: + new_key = RSAKey(use='sig') + + # Unescape key before importing it + raw_key = codecs.decode(public_key, 'unicode_escape') + + # Import Key and save to internal state + new_key.load_key(RSA.import_key(raw_key)) + self.public_key = new_key + except ValueError: + raise exceptions.InvalidRsaKey() + + def _get_keyset(self, kid=None): + """ + Get keyset from available sources. + + If using a RSA key, forcefully set the key id + to match the one from the JWT token. + """ + keyset = [] + + if self.keyset_url: + # TODO: Improve support for keyset handling, handle errors. + keyset.extend(load_jwks_from_url(self.keyset_url)) + + if self.public_key and kid: + # Fill in key id of stored key. + # This is needed because if the JWS is signed with a + # key with a kid, pyjwkest doesn't match them with + # keys without kid (kid=None) and fails verification + self.public_key.kid = kid + + # Add to keyset + keyset.append(self.public_key) + + return keyset + + def validate_and_decode(self, token): + """ + Check if a message sent by the tool is valid. + + From https://www.imsglobal.org/spec/security/v1p0/#using-oauth-2-0-client-credentials-grant: + + The authorization server decodes the JWT and MUST validate the values for the + iss, sub, exp, aud and jti claims. + """ + try: + # Get KID from JWT header + jwt = JWT().unpack(token) + + # Verify message signature + message = JWS().verify_compact( + token, + keys=self._get_keyset( + jwt.headers.get('kid') + ) + ) + + # If message is valid, check expiration from JWT + if 'exp' in message and message['exp'] < time.time(): + raise exceptions.TokenSignatureExpired() + + # TODO: Validate other JWT claims + + # Else returns decoded message + return message + + except NoSuitableSigningKeys: + raise exceptions.NoSuitableKeys() + except BadSyntax: + raise exceptions.MalformedJwtToken() + except WrongNumberOfParts: + raise exceptions.MalformedJwtToken() + + +class PlatformKeyHandler(object): + """ + Platform RSA Key handler. + + This class loads the platform key and is responsible for + encoding JWT messages and exporting public keys. + """ + def __init__(self, key_pem, kid=None): + """ + Import Key when instancing class if a key is present. + """ + self.key = None + + if key_pem: + # Import JWK from RSA key + try: + self.key = RSAKey( + # Using the same key ID as client id + # This way we can easily serve multiple public + # keys on teh same endpoint and keep all + # LTI 1.3 blocks working + kid=kid, + key=RSA.import_key(key_pem) + ) + except ValueError: + raise exceptions.InvalidRsaKey() + + def encode_and_sign(self, message, expiration=None): + """ + Encode and sign JSON with RSA key + """ + if not self.key: + raise exceptions.RsaKeyNotSet() + + _message = copy.deepcopy(message) + + # Set iat and exp if expiration is set + if expiration: + _message.update({ + "iat": int(round(time.time())), + "exp": int(round(time.time()) + expiration), + }) + + # The class instance that sets up the signing operation + # An RS 256 key is required for LTI 1.3 + _jws = JWS(_message, alg="RS256", cty="JWT") + + # Encode and sign LTI message + return _jws.sign_compact([self.key]) + + def get_public_jwk(self): + """ + Export Public JWK + """ + public_keys = jwk.KEYS() + + # Only append to keyset if a key exists + if self.key: + public_keys.append(self.key) + + return json.loads(public_keys.dump_jwks()) diff --git a/lti_consumer/lti_1p3/tests/__init__.py b/lti_consumer/lti_1p3/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lti_consumer/tests/unit/test_lti_1p3_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py similarity index 82% rename from lti_consumer/tests/unit/test_lti_1p3_consumer.py rename to lti_consumer/lti_1p3/tests/test_consumer.py index 747daae0972b88b8e6388d9c8db734a2fa9bebea..e3b0e7e9b55ebb799ee72cdfe8d0cfaf5687166f 100644 --- a/lti_consumer/tests/unit/test_lti_1p3_consumer.py +++ b/lti_consumer/lti_1p3/tests/test_consumer.py @@ -15,6 +15,7 @@ from jwkest.jwk import load_jwks from jwkest.jws import JWS from lti_consumer.lti_1p3.consumer import LtiConsumer1p3 +from lti_consumer.lti_1p3 import exceptions # Variables required for testing and verification @@ -47,7 +48,9 @@ class TestLti1p3Consumer(TestCase): client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, rsa_key=RSA_KEY, - rsa_key_id=RSA_KEY_ID + rsa_key_id=RSA_KEY_ID, + # Use the same key for testing purposes + tool_key=RSA_KEY ) def _setup_lti_user(self): @@ -321,3 +324,63 @@ class TestLti1p3Consumer(TestCase): """ with self.assertRaises(ValueError): self.lti_consumer.set_custom_parameters("invalid") + + def test_access_token_missing_params(self): + """ + Check if access token with missing request data raises. + """ + with self.assertRaises(exceptions.MissingRequiredClaim): + self.lti_consumer.access_token({}) + + def test_access_token_invalid_jwt(self): + """ + Check if access token with invalid request data raises. + """ + request_data = { + "grant_type": "client_credentials", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + # This should be a valid JWT + "client_assertion": "invalid-jwt", + # Scope can be empty + "scope": "", + } + + with self.assertRaises(exceptions.MalformedJwtToken): + self.lti_consumer.access_token(request_data) + + def test_access_token(self): + """ + Check if a valid access token is returned. + + Since we're using the same key for both tool and + platform here, we can make use of the internal + _decode_token to check validity. + """ + # Generate a dummy, but valid JWT + token = self.lti_consumer.key_handler.encode_and_sign( + { + "test": "test" + }, + expiration=1000 + ) + + request_data = { + # We don't actually care about these 2 first values + "grant_type": "client_credentials", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + # This should be a valid JWT + "client_assertion": token, + # Scope can be empty + "scope": "", + } + + response = self.lti_consumer.access_token(request_data) + + # Check response contents + self.assertIn('access_token', response) + self.assertEqual(response.get('token_type'), 'bearer') + self.assertEqual(response.get('expires_in'), 3600) + self.assertEqual(response.get('scope'), '') + + # Check if token is valid + self._decode_token(response.get('access_token')) diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..86d696c766f5da4ef3fa2206f8b3cd53d6d883e4 --- /dev/null +++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py @@ -0,0 +1,229 @@ +""" +Unit tests for LTI 1.3 consumer implementation +""" +from __future__ import absolute_import, unicode_literals + +import json +import ddt + +from mock import patch +from django.test.testcases import TestCase + +from Crypto.PublicKey import RSA +from jwkest.jwk import RSAKey, load_jwks +from jwkest.jws import JWS + +from lti_consumer.lti_1p3.key_handlers import PlatformKeyHandler, ToolKeyHandler +from lti_consumer.lti_1p3 import exceptions +from .utils import create_jwt + + +@ddt.ddt +class TestPlatformKeyHandler(TestCase): + """ + Unit tests for PlatformKeyHandler + """ + def setUp(self): + super(TestPlatformKeyHandler, self).setUp() + + self.rsa_key_id = "1" + self.rsa_key = RSA.generate(2048).export_key('PEM') + + # Set up consumer + self.key_handler = PlatformKeyHandler( + key_pem=self.rsa_key, + kid=self.rsa_key_id + ) + + def _decode_token(self, token): + """ + Checks for a valid signarute and decodes JWT signed LTI message + + This also touches the public keyset method. + """ + public_keyset = self.key_handler.get_public_jwk() + key_set = load_jwks(json.dumps(public_keyset)) + + return JWS().verify_compact(token, keys=key_set) + + def test_encode_and_sign(self): + """ + Test if a message was correctly signed with RSA key. + """ + message = { + "test": "test" + } + signed_token = self.key_handler.encode_and_sign(message) + self.assertEqual( + self._decode_token(signed_token), + message + ) + + # pylint: disable=unused-argument + @patch('time.time', return_value=1000) + def test_encode_and_sign_with_exp(self, mock_time): + """ + Test if a message was correctly signed and has exp and iat parameters. + """ + message = { + "test": "test" + } + + signed_token = self.key_handler.encode_and_sign( + message, + expiration=1000 + ) + + self.assertEqual( + self._decode_token(signed_token), + { + "test": "test", + "iat": 1000, + "exp": 2000 + } + ) + + def test_invalid_rsa_key(self): + """ + Check that class raises when trying to import invalid RSA Key. + """ + with self.assertRaises(exceptions.InvalidRsaKey): + PlatformKeyHandler(key_pem="invalid PEM input") + + def test_empty_rsa_key(self): + """ + Check that class doesn't fail instancing when not using a key. + """ + empty_key_handler = PlatformKeyHandler(key_pem='') + + # Trying to encode a message should fail + with self.assertRaises(exceptions.RsaKeyNotSet): + empty_key_handler.encode_and_sign({}) + + # Public JWK should return an empty value + self.assertEqual( + empty_key_handler.get_public_jwk(), + {'keys': []} + ) + + +@ddt.ddt +class TestToolKeyHandler(TestCase): + """ + Unit tests for ToolKeyHandler + """ + def setUp(self): + super(TestToolKeyHandler, self).setUp() + + self.rsa_key_id = "1" + + # Generate RSA and save exports + rsa_key = RSA.generate(2048) + self.key = RSAKey( + key=rsa_key, + kid=self.rsa_key_id + ) + self.public_key = rsa_key.publickey().export_key() + + # Key handler + self.key_handler = None + + def _setup_key_handler(self): + """ + Set up a instance of the key handler. + """ + self.key_handler = ToolKeyHandler(public_key=self.public_key) + + def test_import_rsa_key(self): + """ + Check if the class is correctly instanced using a valid RSA key. + """ + self._setup_key_handler() + + def test_import_invalid_rsa_key(self): + """ + Check if the class errors out when using a invalid RSA key. + """ + with self.assertRaises(exceptions.InvalidRsaKey): + ToolKeyHandler(public_key="invalid-key") + + def test_get_empty_keyset(self): + """ + Test getting an empty keyset. + """ + key_handler = ToolKeyHandler() + + self.assertEqual( + # pylint: disable=protected-access + key_handler._get_keyset(), + [] + ) + + def test_get_keyset_with_pub_key(self): + """ + Check that getting a keyset from a RSA key. + """ + self._setup_key_handler() + + # pylint: disable=protected-access + keyset = self.key_handler._get_keyset(kid=self.rsa_key_id) + self.assertEqual(len(keyset), 1) + self.assertEqual( + keyset[0].kid, + self.rsa_key_id + ) + + # pylint: disable=unused-argument + @patch('time.time', return_value=1000) + def test_validate_and_decode(self, mock_time): + """ + Check that the validate and decode works. + """ + self._setup_key_handler() + + message = { + "test": "test_message", + "iat": 1000, + "exp": 1200, + } + signed = create_jwt(self.key, message) + + # Decode and check results + decoded_message = self.key_handler.validate_and_decode(signed) + self.assertEqual(decoded_message, message) + + # pylint: disable=unused-argument + @patch('time.time', return_value=1000) + def test_validate_and_decode_expired(self, mock_time): + """ + Check that the validate and decode raises when signature expires. + """ + self._setup_key_handler() + + message = { + "test": "test_message", + "iat": 900, + "exp": 910, + } + signed = create_jwt(self.key, message) + + # Decode and check results + with self.assertRaises(exceptions.TokenSignatureExpired): + self.key_handler.validate_and_decode(signed) + + def test_validate_and_decode_no_keys(self): + """ + Check that the validate and decode raises when no keys are found. + """ + key_handler = ToolKeyHandler() + + message = { + "test": "test_message", + "iat": 900, + "exp": 910, + } + signed = create_jwt(self.key, message) + + # Decode and check results + with self.assertRaises(exceptions.NoSuitableKeys): + key_handler.validate_and_decode(signed) diff --git a/lti_consumer/lti_1p3/tests/utils.py b/lti_consumer/lti_1p3/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a76d162b27a4689045db48096b62d9f332a9a4d --- /dev/null +++ b/lti_consumer/lti_1p3/tests/utils.py @@ -0,0 +1,12 @@ +""" +Test utils +""" +from jwkest.jws import JWS + + +def create_jwt(key, message): + """ + Uses private key to create a JWS from a dict. + """ + jws = JWS(message, alg="RS256", cty="JWT") + return jws.sign_compact([key]) diff --git a/lti_consumer/lti_consumer.py b/lti_consumer/lti_consumer.py index 3f6f855039278ebc82cb7d16890940e4eba67771..442ad3e11098d1faf1ada747d67c0b7c23fbcfa3 100644 --- a/lti_consumer/lti_consumer.py +++ b/lti_consumer/lti_consumer.py @@ -58,10 +58,9 @@ import uuid from collections import namedtuple from importlib import import_module -import six.moves.urllib.error -import six.moves.urllib.parse -import six import bleach +import six +from six.moves.urllib import parse from Crypto.PublicKey import RSA from django.utils import timezone from webob import Response @@ -74,12 +73,21 @@ from xblockutils.studio_editable import StudioEditableXBlockMixin from .exceptions import LtiError from .lti import LtiConsumer +from .lti_1p3.exceptions import ( + UnsupportedGrantType, + MalformedJwtToken, + MissingRequiredClaim, + NoSuitableKeys, + TokenSignatureExpired, + UnknownClientId, +) from .lti_1p3.consumer import LtiConsumer1p3 from .oauth import log_authorization_header from .outcomes import OutcomeService from .utils import ( _, get_lms_base, + get_lms_lti_access_token_link, get_lms_lti_keyset_link, get_lms_lti_launch_link, ) @@ -306,6 +314,11 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock): default='', scope=Scope.settings ) + lti_1p3_tool_public_key_id = String( + display_name=_("LTI 1.3 Tool Public Key ID"), + default='', + scope=Scope.settings + ) # Client ID and block key lti_1p3_client_id = String( display_name=_("LTI 1.3 Block Client ID"), @@ -492,6 +505,7 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock): 'display_name', 'description', # LTI 1.3 variables 'lti_version', 'lti_1p3_launch_url', 'lti_1p3_oidc_url', 'lti_1p3_tool_public_key', + 'lti_1p3_tool_public_key_id', # LTI 1.1 variables 'lti_id', 'launch_url', # Other parameters @@ -801,8 +815,12 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock): lti_launch_url=self.lti_1p3_launch_url, client_id=self.lti_1p3_client_id, deployment_id="1", + # XBlock Private RSA Key rsa_key=self.lti_1p3_block_key, - rsa_key_id=self.lti_1p3_client_id + rsa_key_id=self.lti_1p3_client_id, + # LTI 1.3 Tool key/keyset url + tool_key=self.lti_1p3_tool_public_key, + tool_keyset_url=None, ) def studio_view(self, context): @@ -857,6 +875,7 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock): "deployment_id": "1", "keyset_url": get_lms_lti_keyset_link(self.location), # pylint: disable=no-member "oidc_callback": get_lms_lti_launch_link(), + "token_url": get_lms_lti_access_token_link(self.location), # pylint: disable=no-member "launch_url": self.lti_1p3_launch_url, } fragment.add_content(loader.render_mako_template('/templates/html/lti_1p3_studio.html', context)) @@ -997,6 +1016,66 @@ class LtiConsumerXBlock(StudioEditableXBlockMixin, XBlock): ) return Response(status=404) + @XBlock.handler + def lti_1p3_access_token(self, request, suffix=''): # pylint: disable=unused-argument + """ + XBlock handler for creating access tokens for the LTI 1.3 tool. + + This endpoint is only valid when a LTI 1.3 tool is being used. + + Returns: + webob.response: + Either an access token or error message detailing the failure. + All responses are RFC 6749 compliant. + + References: + Sucess: https://tools.ietf.org/html/rfc6749#section-4.4.3 + Failure: https://tools.ietf.org/html/rfc6749#section-5.2 + """ + if self.lti_version != "lti_1p3": + return Response(status=404) + elif request.method != "POST": + return Response(status=405) + + lti_consumer = self._get_lti1p3_consumer() + try: + token = lti_consumer.access_token( + dict(parse.parse_qsl( + request.body.decode('utf-8'), + keep_blank_values=True + )) + ) + # The returned `token` is compliant with RFC 6749 so we just + # need to return a 200 OK response with the token as Json body + return Response(json_body=token, content_type="application/json") + + # Handle errors and return a proper response + # pylint: disable=bare-except + except MissingRequiredClaim: + # Missing request attibutes + return Response( + json_body={"error": "invalid_request"}, + status=400 + ) + except (MalformedJwtToken, TokenSignatureExpired): + # Triggered when a invalid grant token is used + return Response( + json_body={"error": "invalid_grant"}, + status=400, + ) + except (NoSuitableKeys, UnknownClientId): + # Client ID is not registered in the block or + # isn't possible to validate token using available keys. + return Response( + json_body={"error": "invalid_client"}, + status=400, + ) + except UnsupportedGrantType: + return Response( + json_body={"error": "unsupported_grant_type"}, + status=400, + ) + @XBlock.handler def outcome_service_handler(self, request, suffix=''): # pylint: disable=unused-argument """ diff --git a/lti_consumer/plugin/urls.py b/lti_consumer/plugin/urls.py index 8bfc8789e861e63d09aaec1d5ba4560ca5f972b4..7b209e512d365c5dda702e978aaaf8b7d2eaf507 100644 --- a/lti_consumer/plugin/urls.py +++ b/lti_consumer/plugin/urls.py @@ -10,6 +10,7 @@ from django.conf.urls import url from .views import ( public_keyset_endpoint, launch_gate_endpoint, + access_token_endpoint ) @@ -23,5 +24,10 @@ urlpatterns = [ 'lti_consumer/v1/launch/(?:/(?P<suffix>.*))?$', launch_gate_endpoint, name='lti_consumer.launch_gate' + ), + url( + 'lti_consumer/v1/token/{}$'.format(settings.USAGE_ID_PATTERN), + access_token_endpoint, + name='lti_consumer.access_token' ) ] diff --git a/lti_consumer/plugin/views.py b/lti_consumer/plugin/views.py index 14aa4c977cd1cd3d01087a880a28011c9fde1458..5d9312f35a253a272e4f683e7e15ef93b2e9925b 100644 --- a/lti_consumer/plugin/views.py +++ b/lti_consumer/plugin/views.py @@ -3,6 +3,8 @@ LTI consumer plugin passthrough views """ from django.http import HttpResponse +from django.views.decorators.csrf import csrf_exempt +from django.views.decorators.http import require_http_methods from opaque_keys.edx.keys import UsageKey # pylint: disable=import-error from lms.djangoapps.courseware.module_render import ( # pylint: disable=import-error @@ -11,6 +13,7 @@ from lms.djangoapps.courseware.module_render import ( # pylint: disable=import- ) +@require_http_methods(["GET"]) def public_keyset_endpoint(request, usage_id=None): """ Gate endpoint to fetch public keysets from a problem @@ -32,6 +35,7 @@ def public_keyset_endpoint(request, usage_id=None): return HttpResponse(status=404) +@require_http_methods(["GET", "POST"]) def launch_gate_endpoint(request, suffix): """ Gate endpoint that triggers LTI launch endpoint XBlock handler @@ -54,3 +58,22 @@ def launch_gate_endpoint(request, suffix): ) except: # pylint: disable=bare-except return HttpResponse(status=404) + + +@csrf_exempt +@require_http_methods(["POST"]) +def access_token_endpoint(request, usage_id=None): + """ + Gate endpoint to enable tools to retrieve access tokens + """ + try: + usage_key = UsageKey.from_string(usage_id) + + return handle_xblock_callback_noauth( + request=request, + course_id=str(usage_key.course_key), + usage_id=str(usage_key), + handler='lti_1p3_access_token' + ) + except: # pylint: disable=bare-except + return HttpResponse(status=404) diff --git a/lti_consumer/templates/html/lti_1p3_studio.html b/lti_consumer/templates/html/lti_1p3_studio.html index e16c28ad63d6f19080d048282f4ba4d1a37a2729..2fe1f2fb43e99353e8d300e24e6ef1e46295a9ff 100644 --- a/lti_consumer/templates/html/lti_1p3_studio.html +++ b/lti_consumer/templates/html/lti_1p3_studio.html @@ -29,8 +29,8 @@ </p> <p> - <b>OAuth URL: </b> - N/A + <b>OAuth Token URL: </b> + ${token_url} </p> <p> diff --git a/lti_consumer/tests/unit/test_lti_consumer.py b/lti_consumer/tests/unit/test_lti_consumer.py index 39a05317b963d34e48f53de97ec8dd99d59373b0..504838f80e15c1cd62001ff992951e5889dd513b 100644 --- a/lti_consumer/tests/unit/test_lti_consumer.py +++ b/lti_consumer/tests/unit/test_lti_consumer.py @@ -5,13 +5,16 @@ Unit tests for LtiConsumerXBlock from __future__ import absolute_import from datetime import timedelta +import json import uuid import ddt import six +from six.moves.urllib import parse from Crypto.PublicKey import RSA from django.test.testcases import TestCase from django.utils import timezone +from jwkest.jwk import RSAKey from mock import Mock, PropertyMock, patch from lti_consumer.exceptions import LtiError @@ -19,6 +22,8 @@ from lti_consumer.lti_consumer import LtiConsumerXBlock, parse_handler_suffix from lti_consumer.tests.unit import test_utils from lti_consumer.tests.unit.test_utils import (FAKE_USER_ID, make_request, make_xblock) +from lti_consumer.lti_1p3.tests.utils import create_jwt + HTML_PROBLEM_PROGRESS = '<div class="problem-progress">' HTML_ERROR_MESSAGE = '<h3 class="error_message">' @@ -435,6 +440,15 @@ class TestStudentView(TestLtiConsumerXBlock): self.assertNotIn(HTML_ERROR_MESSAGE, fragment.content) + def test_author_view(self): + """ + Test that the `author_view` is the same as student view when using LTI 1.1. + """ + self.assertEqual( + self.xblock.student_view({}).content, + self.xblock.author_view({}).content + ) + class TestLtiLaunchHandler(TestLtiConsumerXBlock): """ @@ -931,7 +945,6 @@ class TestLtiConsumer1p3XBlock(TestCase): 'lti_1p3_client_id': '1', 'lti_1p3_launch_url': 'http://tool.example/launch', 'lti_1p3_oidc_url': 'http://tool.example/oidc', - 'lti_1p3_tool_public_key': '', # We need to set the values below because they are not automatically # generated until the user selects `lti_version == 'lti_1p3'` on the # Studio configuration view. @@ -1054,3 +1067,150 @@ class TestLtiConsumer1p3XBlock(TestCase): response = self.xblock.author_view({}) self.assertIn(self.xblock.lti_1p3_client_id, response.content) self.assertIn("https://example.com", response.content) + + +# pylint: disable=unused-argument +@patch('lti_consumer.utils.get_lms_base', return_value="https://example.com") +@patch('lti_consumer.lti_consumer.get_lms_base', return_value="https://example.com") +class TestLti1p3AccessTokenEndpoint(TestCase): + """ + Unit tests for LtiConsumerXBlock Access Token endpoint when using an LTI 1.3. + """ + def setUp(self): + super(TestLti1p3AccessTokenEndpoint, self).setUp() + + self.rsa_key_id = "1" + # Generate RSA and save exports + rsa_key = RSA.generate(2048) + self.key = RSAKey( + key=rsa_key, + kid=self.rsa_key_id + ) + self.public_key = rsa_key.publickey().export_key() + + self.xblock_attributes = { + 'lti_version': 'lti_1p3', + 'lti_1p3_launch_url': 'http://tool.example/launch', + 'lti_1p3_oidc_url': 'http://tool.example/oidc', + # We need to set the values below because they are not automatically + # generated until the user selects `lti_version == 'lti_1p3'` on the + # Studio configuration view. + 'lti_1p3_client_id': self.rsa_key_id, + 'lti_1p3_block_key': rsa_key.export_key('PEM'), + # Use same key for tool key to make testing easier + 'lti_1p3_tool_public_key': self.public_key, + } + self.xblock = make_xblock('lti_consumer', LtiConsumerXBlock, self.xblock_attributes) + + def test_access_token_endpoint_when_using_lti_1p1(self, *args, **kwargs): + """ + Test that the LTI 1.3 access token endpoind is unavailable when using 1.1. + """ + self.xblock.lti_version = 'lti_1p1' + self.xblock.save() + + request = make_request(json.dumps({}), 'POST') + request.content_type = 'application/json' + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 404) + + def test_access_token_endpoint_no_post(self, *args, **kwargs): + """ + Test that the LTI 1.3 access token endpoind is unavailable when using 1.1. + """ + request = make_request('', 'GET') + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 405) + + def test_access_token_missing_claims(self, *args, **kwargs): + """ + Test request with missing parameters. + """ + request = make_request(json.dumps({}), 'POST') + request.content_type = 'application/json' + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json_body, {'error': 'invalid_request'}) + + def test_access_token_malformed(self, *args, **kwargs): + """ + Test request with invalid JWT. + """ + request = make_request( + parse.urlencode({ + "grant_type": "client_credentials", + "client_assertion_type": "something", + "client_assertion": "invalid-jwt", + "scope": "", + }), + 'POST' + ) + request.content_type = 'application/x-www-form-urlencoded' + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json_body, {'error': 'invalid_grant'}) + + def test_access_token_invalid_grant(self, *args, **kwargs): + """ + Test request with invalid grant. + """ + request = make_request( + parse.urlencode({ + "grant_type": "password", + "client_assertion_type": "something", + "client_assertion": "invalit-jwt", + "scope": "", + }), + 'POST' + ) + request.content_type = 'application/x-www-form-urlencoded' + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json_body, {'error': 'unsupported_grant_type'}) + + def test_access_token_invalid_client(self, *args, **kwargs): + """ + Test request with valid JWT but no matching key to check signature. + """ + self.xblock.lti_1p3_tool_public_key = '' + self.xblock.save() + + jwt = create_jwt(self.key, {}) + request = make_request( + parse.urlencode({ + "grant_type": "client_credentials", + "client_assertion_type": "something", + "client_assertion": jwt, + "scope": "", + }), + 'POST' + ) + request.content_type = 'application/x-www-form-urlencoded' + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json_body, {'error': 'invalid_client'}) + + def test_access_token(self, *args, **kwargs): + """ + Test request with valid JWT. + """ + jwt = create_jwt(self.key, {}) + request = make_request( + parse.urlencode({ + "grant_type": "client_credentials", + "client_assertion_type": "something", + "client_assertion": jwt, + "scope": "", + }), + 'POST' + ) + request.content_type = 'application/x-www-form-urlencoded' + + response = self.xblock.lti_1p3_access_token(request) + self.assertEqual(response.status_code, 200) diff --git a/lti_consumer/utils.py b/lti_consumer/utils.py index 9e61613e58d4a56efd67d238ab3e2f0f65878630..3db2d5bd61a71c69f4cd326b81a13a5208723e0d 100644 --- a/lti_consumer/utils.py +++ b/lti_consumer/utils.py @@ -46,3 +46,15 @@ def get_lms_lti_launch_link(): return u"{lms_base}/api/lti_consumer/v1/launch/".format( lms_base=get_lms_base(), ) + + +def get_lms_lti_access_token_link(location): + """ + Returns an LMS link to LTI Launch endpoint + + :param location: the location of the block + """ + return u"{lms_base}/api/lti_consumer/v1/token/{location}".format( + lms_base=get_lms_base(), + location=text_type(location), + )