From 209cf4597ab52a3e840905666d011ddea005b029 Mon Sep 17 00:00:00 2001
From: Paulo Viadanna <paulo@opencraft.com>
Date: Wed, 29 Apr 2020 11:20:59 -0300
Subject: [PATCH] BB-2332: Add preflight response validation

---
 lti_consumer/lti_1p3/consumer.py              | 23 +++++++++++++---
 .../tests/unit/test_lti_1p3_consumer.py       | 27 ++++++++++++++++++-
 lti_consumer/tests/unit/test_lti_consumer.py  |  4 ++-
 3 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/lti_consumer/lti_1p3/consumer.py b/lti_consumer/lti_1p3/consumer.py
index 3ea9784..f5fb01c 100644
--- a/lti_consumer/lti_1p3/consumer.py
+++ b/lti_consumer/lti_1p3/consumer.py
@@ -5,7 +5,6 @@ import json
 import time
 
 # Quality checks failing due to know pylint bug
-# pylint: disable=relative-import
 from six.moves.urllib.parse import urlencode
 
 from Crypto.PublicKey import RSA
@@ -16,7 +15,7 @@ from jwkest import jwk
 from .constants import LTI_1P3_ROLE_MAP, LTI_BASE_MESSAGE
 
 
-class LtiConsumer1p3(object):
+class LtiConsumer1p3:
     """
     LTI 1.3 Consumer Implementation
     """
@@ -199,10 +198,12 @@ class LtiConsumer1p3(object):
         This will add all required parameters from the LTI 1.3 spec and any additional ones set in
         the configuration and JTW encode the message using the provided key.
         """
+        # Validate preflight response
+        self._validate_preflight_response(preflight_response)
+
         # Start from base message
         lti_message = LTI_BASE_MESSAGE.copy()
 
-        # TODO: Validate preflight response
         # Add base parameters
         lti_message.update({
             # Issuer
@@ -273,3 +274,19 @@ class LtiConsumer1p3(object):
         public_keys = jwk.KEYS()
         public_keys.append(self.jwk)
         return json.loads(public_keys.dump_jwks())
+
+    def _validate_preflight_response(self, response):
+        """
+        Validates a preflight response to be used in a launch request
+
+        Raises ValueError in case of validation failure
+
+        :param response: the preflight response to be validated
+        """
+        try:
+            assert response.get("nonce")
+            assert response.get("state")
+            assert response.get("client_id") == self.client_id
+            assert response.get("redirect_uri") == self.launch_url
+        except AssertionError:
+            raise ValueError("Preflight reponse failed validation")
diff --git a/lti_consumer/tests/unit/test_lti_1p3_consumer.py b/lti_consumer/tests/unit/test_lti_1p3_consumer.py
index d0d89bb..747daae 100644
--- a/lti_consumer/tests/unit/test_lti_1p3_consumer.py
+++ b/lti_consumer/tests/unit/test_lti_1p3_consumer.py
@@ -23,6 +23,8 @@ OIDC_URL = "http://test-platform/oidc"
 LAUNCH_URL = "http://test-platform/launch"
 CLIENT_ID = "1"
 DEPLOYMENT_ID = "1"
+NONCE = "1234"
+STATE = "ABCD"
 # Consider storing a fixed key
 RSA_KEY_ID = "1"
 RSA_KEY = RSA.generate(2048).export_key('PEM')
@@ -73,7 +75,12 @@ class TestLti1p3Consumer(TestCase):
         parameters, but allows overriding them.
         """
         if preflight_response is None:
-            preflight_response = {"nonce": "", "state": ""}
+            preflight_response = {
+                "client_id": CLIENT_ID,
+                "redirect_uri": LAUNCH_URL,
+                "nonce": NONCE,
+                "state": STATE
+            }
 
         return self.lti_consumer.generate_launch_request(
             preflight_response,
@@ -91,6 +98,22 @@ class TestLti1p3Consumer(TestCase):
 
         return JWS().verify_compact(token, keys=key_set)
 
+    @ddt.data(
+        ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
+        ({"client_id": "2", "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, False),
+        ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL[::-1], "nonce": STATE, "state": STATE}, False),
+        ({"redirect_uri": LAUNCH_URL, "nonce": NONCE, "state": STATE}, False),
+        ({"client_id": CLIENT_ID, "nonce": NONCE, "state": STATE}, False),
+        ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "state": STATE}, False),
+        ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": NONCE}, False),
+    )
+    @ddt.unpack
+    def test_preflight_validation(self, preflight_response, success):
+        if success:
+            return self.lti_consumer._validate_preflight_response(preflight_response)  # pylint: disable=protected-access
+        with self.assertRaises(ValueError):
+            return self.lti_consumer._validate_preflight_response(preflight_response)  # pylint: disable=protected-access
+
     @ddt.data(
         (
             'student',
@@ -253,6 +276,8 @@ class TestLti1p3Consumer(TestCase):
         self._setup_lti_user()
         launch_request = self._get_lti_message(
             preflight_response={
+                "client_id": "1",
+                "redirect_uri": "http://test-platform/launch",
                 "nonce": "test",
                 "state": "state"
             },
diff --git a/lti_consumer/tests/unit/test_lti_consumer.py b/lti_consumer/tests/unit/test_lti_consumer.py
index a56613f..23b67bb 100644
--- a/lti_consumer/tests/unit/test_lti_consumer.py
+++ b/lti_consumer/tests/unit/test_lti_consumer.py
@@ -926,6 +926,7 @@ class TestLtiConsumer1p3XBlock(TestCase):
 
         self.xblock_attributes = {
             'lti_version': 'lti_1p3',
+            'lti_1p3_client_id': '1',
             'lti_1p3_launch_url': 'http://tool.example/launch',
             'lti_1p3_oidc_url': 'http://tool.example/oidc',
             'lti_1p3_tool_public_key': ''
@@ -957,7 +958,8 @@ class TestLtiConsumer1p3XBlock(TestCase):
 
         # Craft request sent back by LTI tool
         request = make_request('', 'GET')
-        request.query_string = "state=state_test_123&nonce=nonce&login_hint=oidchint&lti_message_hint=ltihint"
+        request.query_string = "client_id=1&redirect_uri=http://tool.example/launch&state=state_test_123&nonce=nonce\
+&login_hint=oidchint&lti_message_hint=ltihint"
 
         response = self.xblock.lti_1p3_launch_callback(request)
 
-- 
GitLab