# Copyright (C) 2009, 2010 Canonical Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from unittest import TestCase

from apache_openid import get_action_path
from apache_openid.handlers.openid.handler import OpenIDLoginHandler
from apache_openid.request import Request
from apache_openid.response import Response
from apache_openid.utils import DONE, HTTP_FORBIDDEN, OK, SERVER_RETURN
from apache_openid.utils.mock import ApacheMockRequest, Options, Session


class HandlerTestCase(TestCase):

    def setUp(self):
        self.apache_request = ApacheMockRequest()
        self.overrides = {
            'action-path': "/openid/",
        }
        self.options = Options(self.overrides)
        self.action_path = get_action_path(self.options, self.apache_request)
        self.session = Session()
        self.request = Request(
            self.apache_request, self.action_path, self.session)
        self.response = Response(
            self.request, self.action_path, self.session)
        self.handler = OpenIDLoginHandler(
            self.request, self.response, self.options, None, None)

    def assertRaisesMessage(self, exc, call, msg):
        try:
            result = call()
        except Exception, e:
            self.assertEqual(isinstance(e, exc), True)
            self.assertEqual(e.message, msg)
        else:
            self.fail("No exception raised. Got a result instead: %s" % result)


class OpenIDLoginHandlerTest(HandlerTestCase):
    """Tests for OpenIDLoginHandler's cached attributes and ``protect()``."""

    def test_authenticator(self):
        # ``_auth`` must be unset until ``authenticator`` is first accessed,
        # and populated afterwards.
        self.assertEqual(getattr(self.handler, '_auth', None), None)
        self.handler.authenticator  # access only; populates ``_auth``
        self.assertNotEqual(getattr(self.handler, '_auth', None), None)

    def test_consumer(self):
        # ``_consumer`` must be unset until ``consumer`` is first accessed,
        # and populated afterwards.
        self.assertEqual(getattr(self.handler, '_consumer', None), None)
        self.handler.consumer  # access only; populates ``_consumer``
        self.assertNotEqual(getattr(self.handler, '_consumer', None), None)

    def test_known_actions(self):
        self.assertEqual(
            self.handler.known_actions,
            ['+logout', '+login', '+return'])

    def test_protect(self):
        # With the default mock request, protect() is expected to abort the
        # request with SERVER_RETURN(DONE).
        self.assertRaisesMessage(SERVER_RETURN, self.handler.protect, DONE)

    def test_protect_favicon(self):
        # A request for /favicon.ico is expected to be let through with OK.
        env = self.handler.request.request.subprocess_env
        env['REQUEST_URI'] = '/favicon.ico'
        self.assertEqual(self.handler.protect(), OK)

    def test_protect_post(self):
        # A POST to a protected resource is expected to be rejected.
        self.handler.request.request.method = 'POST'
        self.assertRaisesMessage(
            SERVER_RETURN, self.handler.protect, HTTP_FORBIDDEN)

    def test_protect_cookied_user(self):
        # NOTE(review): an empty list is falsy; confirm this exercises the
        # intended cookied-user branch rather than the default path that
        # test_protect already covers.
        self.handler.request.cookied_user = []
        self.assertRaisesMessage(
            SERVER_RETURN, self.handler.protect, DONE)
