#!/usr/bin/python

# CellPhoneProxy
# Copyright (c) 2007 Dan Lenski
#
# Based on: "Tiny HTTP Proxy in Python"
# Copyright (c) 2001 SUZUKI Hisao
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Version of this proxy; ProxyHandler.server_version is built from it
# ("CellPhoneProxy/" + __version__).
__version__ = "0.1"

import BaseHTTPServer, select, socket, SocketServer, urlparse
import httplib, base64

class ProxyHandler (BaseHTTPServer.BaseHTTPRequestHandler):
    __base = BaseHTTPServer.BaseHTTPRequestHandler

    server_version = "CellPhoneProxy/" + __version__
    protocol_version = "HTTP/1.1"       # want keep-alive to work
    rbufsize = 0                        # self.rfile Be unbuffered
    homepage = 'http://wap.google.com'
    
    # override parent to base keep-alive on Proxy-Connection header, if present
    def parse_request(self):
        s = self.__base.parse_request(self)
        if s:
            self.conntype = self.headers.get('Proxy-Connection') or self.headers.get('Connection', '')
            if self.conntype.lower() == 'close':
                self.close_connection = 1
            elif (self.conntype.lower() == 'keep-alive' and
                  self.protocol_version >= "HTTP/1.1"):
                self.close_connection = 0
        return s

    # override parent to allow extra headers, and include Content-Length
    def send_error(self, code, message=None, ext={}):
        try:
            short, long = self.responses[code]
        except KeyError:
            short, long = '???', '???'
        if message is None:
            message = short
        explain = long
        self.log_error("code %d, message %s", code, message)
        # using _quote_html to prevent Cross Site Scripting attacks (see bug #1100201)
        content = (self.error_message_format %
                   {'code': code, 'message': message, 'explain': explain})
        self.send_response(code, message)
        self.send_header("Content-Type", "text/html")
        self.send_header('Connection', 'close')
        for kv in ext.items(): self.send_header(*kv)
        if self.command != 'HEAD' and code >= 200 and code not in (204, 304):
            self.send_header('Content-Length', len(content))
            self.end_headers()
            self.wfile.write(content)
        else:
            self.end_headers()

    def _read_body(self):
        if self.headers.has_key('Content-Length'):
            self.body = self.rfile.read(int(self.headers['Content-Length']))
        else:
            self.body = ''

    def _check_proxy_auth(self):
        if not hasattr(self, 'users'):
            return True
        if self.headers.has_key('Proxy-Authorization'):
            method, cred = self.headers['Proxy-Authorization'].split(None, 1)
            del self.headers['Proxy-Authorization']
            if method=='Basic':
                if tuple(base64.b64decode(cred).split(':')) in self.users:
                    return True
        elif self.client_address[0] == '127.0.0.1':
            return True
        print self.client_address
        self.send_error(407, ext={'Proxy-Authenticate':'Basic realm="%s"' % self.version_string()})
        return False

    # add Content-Length header, remove Transfer-Encoding: chunked
    # (this is a workaround for the braindead non-compliant HTTP client on the
    # Motorola E815 phone)
    def _fix_headers_for_stupid_phones(self, headers, content):
        if 'chunked' in headers.getheaders('Transfer-Encoding'):
            self.log_message("removing chunked Transfer-Encoding (stupid phone hack)")
            del headers['Transfer-Encoding']
            
        if not headers.has_key('Content-Length'):
            self.log_message("adding missing Content-Length (stupid phone hack)")
            headers['Content-Length'] = str(len(content))
        
        #d = [k for k in headers if k.lower() not in
        #     ('location', 'content-type', 'content-length', 'date',
        #      'www-authenticate', 'proxy-authenticate', 'pragma',
        #      'connection', 'etag', 'set-cookie', 'content-encoding',
        #      'expires', 'server', 'host', 'last-modified', 'cache-control')]
        #if d: self.log_message("deleting headers %s (stupid phone hack)" % ', '.join(d))
        #for k in d: del headers[k]

    def _connect_to(self, netloc):
        soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        i = netloc.find(':')
        if i >= 0:
            host_port = netloc[:i], int(netloc[i+1:])
        else:
            host_port = netloc, 80
        try: soc.connect(host_port)
        except socket.error, arg:
            try: msg = arg[1]
            except: msg = arg
            self.send_error(404, msg)
            raise
        return soc

    def _read_write(self, soc, max_idling=20):
        iw = [self.connection, soc]
        ow = []
        count = 0
        while 1:
            count += 1
            (ins, _, exs) = select.select(iw, ow, iw, 3)
            if exs: break
            if ins:
                for i in ins:
                    if i is soc:
                        out = self.connection
                    else:
                        out = soc
                    data = i.recv(8192)
                    if data:
                        out.send(data)
                        count = 0
            else:
                self.log_message("idle %d" % count)
            if count == max_idling: break

    # in the case of a CONNECT request, simply connect the client socket to the target socket
    def do_CONNECT(self):
        if not self._check_proxy_auth(): return
        self.log_request()
        try:
            soc = self._connect_to(self.path)
            self.wfile.write("200 Connection established %s\r\n" % self.protocol_version)
            self.send_header("Proxy-agent", self.version_string())
            self.end_headers()
            self._read_write(soc, 300)
        except socket.error, e:
            self.log_message("socket error: %s", str(e))
            
    def do_GET(self):
        (scm, netloc, path, params, query, fragment) = urlparse.urlparse(self.path, 'http')
        if scm != 'http' or fragment or not netloc:
            self.send_error(400, "bad url %s" % self.path)
            return
        if not self._check_proxy_auth():  return
        if netloc == 'homepage':
            self.send_error(301, ext={'Location':self.homepage})
            return

        self.log_request()
        try:
            soc = self._connect_to(netloc)
            soc.send("%s %s %s\r\n" % (
                self.command,
                urlparse.urlunparse(('', '', path, params, query, '')),
                self.request_version))
            self.headers['Connection'] = 'close'
            del self.headers['Proxy-Connection']
            for key_val in self.headers.items():
                soc.send("%s: %s\r\n" % key_val)
            soc.send("\r\n")
        except socket.error, e:
            self.log_message("while talking to target: %s", str(e))
            return

        # in the case of an HTTP/0.9 request, simply connect the client socket to the target socket
        if self.request_version < 'HTTP/1.0':
            try: self._read_write(soc)
            except socket.error, arg: self.log_message("while responding to HTTP/0.9 client: %s", str(arg))
            return
        
        # in the case of a HTTP/1.0+ request, we'll read the requested page completely so that
        # we can transform it before passing it on to the client
        try:
            # read and forward request body
            self._read_body()
            soc.send(self.body)
            # read response from target server
            response = httplib.HTTPResponse(soc)
            response.begin()
            headers, content = response.msg, response.read()
            # adjust headers
            self._fix_headers_for_stupid_phones(headers, content)
        except socket.error, arg:
            try: msg = arg[1]
            except: msg = arg
            self.send_error(404, msg)
            return
        finally:
            soc.close()
        
        # write the response to our client
        try:
            self.wfile.write("%s %s %s\r\n" % (self.protocol_version, response.status, response.reason))
            if self.conntype: headers['Connection'] = self.conntype
            else: del headers['Connection']
            for k in headers: self.send_header(k, headers[k])
            self.end_headers()
            self.wfile.write(content)
        except socket.error, arg:
            self.log_message("while responding to %s client: %s", self.request_version, str(arg))
            self.close_connection = 1
        else:
            self.log_message("request completed succesfully")
           
    do_HEAD = do_GET
    do_POST = do_GET
    do_PUT  = do_GET
    do_DELETE=do_GET

    def handle_one_request(self):
        try: self.requests += 1
        except AttributeError: self.requests = 1
        self.__base.handle_one_request(self)

    def finish(self):
        self.__base.finish(self)
        self.log_message("bye! (handled %d requests)" % self.requests)
        self.requests = 0

class ThreadingHTTPServer (SocketServer.ThreadingMixIn,
                           BaseHTTPServer.HTTPServer): pass

if __name__ == '__main__':
    from sys import argv
    if argv[1:] and argv[1] in ('-h', '--help'):
        print argv[0], "[port] [homepage]"
    else:
        if argv[2:]: ProxyHandler.homepage = argv[2]
        if argv[1:]: port=int(argv[1])
        else: port=8080
        print "Serving proxy on port %d..." % port
        ThreadingHTTPServer(('',port), ProxyHandler).serve_forever()
