(************************************************************************)
(* This file is part of SKS. SKS is free software; you can
   redistribute it and/or modify it under the terms of the GNU General
   Public License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
   General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
   USA *)
(***********************************************************************)

(** simple web server code *)

open StdLabels
open MoreLabels
open Printf
open Common
module Unix = UnixLabels
open Unix

module Map = PMap.Map
module Set = PSet.Set

(** Raised by a request handler for unsupported functionality; when raised
    while handling a request, [accept_connection] answers with HTTP 501. *)
exception Not_implemented of string

(** Generic error.  When raised by a request handler, [accept_connection]
    answers with HTTP 500; [parse_post] also raises it for bad POST sizes. *)
exception Misc_error of string

(** Raised by a request handler for unknown pages; [accept_connection]
    answers with HTTP 404. *)
exception Page_not_found of string

(** [map |= key] is [Map.find key map] with the arguments flipped.
    Presumably raises [Not_found] for a missing key ([parse_post] relies on
    that). *)
let ( |= ) map key = Map.find key map

(** [map |< (key, data)] returns [map] extended with [key] bound to [data]. *)
let ( |< ) map (key, data) = Map.add ~key ~data map

(** Upper-case hexadecimal digit for a value in 0..15 (no range check). *)
let hexa_digit x =
  if x >= 10 then Char.chr (Char.code 'A' + x - 10)
  else Char.chr (Char.code '0' + x)

(** Numeric value of a hexadecimal digit; any other character yields 0. *)
let hexa_val conf =
  match conf with
  | '0'..'9' -> Char.code conf - Char.code '0'
  | 'a'..'f' -> Char.code conf - Char.code 'a' + 10
  | 'A'..'F' -> Char.code conf - Char.code 'A' + 10
  | _ -> 0

(** URL-decode [s]: "%XY" becomes the byte 0xXY and '+' becomes ' '.  A '%'
    not followed by two more characters is copied literally.  When any
    decoding happened, leading and trailing spaces are removed from the
    result; otherwise [s] is returned unchanged. *)
let decode s =
  let len = String.length s in
  let rec need_decode i =
    if i >= len then false
    else
      match s.[i] with
      | '%' | '+' -> true
      | _ -> need_decode (succ i)
  in
  (* Length of the decoded string: every complete "%XY" escape shrinks to
     one character, everything else is one character. *)
  let rec compute_len i acc =
    if i >= len then acc
    else
      let next =
        match s.[i] with
        | '%' when i + 2 < len -> i + 3
        | _ -> succ i
      in
      compute_len next (succ acc)
  in
  let rec copy_decode_in dst i j =
    if i >= len then dst
    else
      let next =
        match s.[i] with
        | '%' when i + 2 < len ->
            let v = hexa_val s.[i + 1] * 16 + hexa_val s.[i + 2] in
            dst.[j] <- Char.chr v;
            i + 3
        | '+' -> dst.[j] <- ' '; succ i
        | c -> dst.[j] <- c; succ i
      in
      copy_decode_in dst next (succ j)
  in
  (* Trim leading/trailing ' ' with a single [String.sub]; the previous
     version copied the string once per removed space (quadratic). *)
  let strip_spaces str =
    let n = String.length str in
    let first = ref 0 in
    while !first < n && str.[!first] = ' ' do incr first done;
    let last = ref (n - 1) in
    while !last >= !first && str.[!last] = ' ' do decr last done;
    if !first = 0 && !last = n - 1 then str
    else String.sub str ~pos:!first ~len:(!last - !first + 1)
  in
  if need_decode 0 then
    strip_spaces (copy_decode_in (String.create (compute_len 0 0)) 0 0)
  else s

(** Characters that [encode] escapes as "%XY".  '%' is included so that the
    encoded output decodes back to the original string. *)
let special x = List.mem x ['='; '&'; '"'; '\r'; '\n'; '+'; '%']

(** URL-encode [s]: ' ' becomes '+', characters satisfying [special] become
    "%XY" (upper-case hex), everything else is copied.  Returns [s] itself
    when nothing needs encoding. *)
let encode s =
  let len = String.length s in
  let rec need_code i =
    if i >= len then false
    else
      let c = s.[i] in
      c = ' ' || special c || need_code (succ i)
  in
  let rec compute_len i acc =
    if i >= len then acc
    else compute_len (succ i) (if special s.[i] then acc + 3 else succ acc)
  in
  let rec copy_code_in dst i j =
    if i >= len then dst
    else begin
      let c = s.[i] in
      let j =
        if c = ' ' then (dst.[j] <- '+'; succ j)
        else if special c then begin
          let code = Char.code c in
          dst.[j] <- '%';
          dst.[j + 1] <- hexa_digit (code / 16);
          dst.[j + 2] <- hexa_digit (code mod 16);
          j + 3
        end
        else (dst.[j] <- c; succ j)
      in
      copy_code_in dst (succ i) j
    end
  in
  if need_code 0 then copy_code_in (String.create (compute_len 0 0)) 0 0
  else s

(** Whitespace characters removed by [strip]. *)
let stripchars = Set.of_list [ ' '; '\t'; '\n'; '\r' ]

(** Remove leading and trailing [stripchars] from [s]. *)
let strip s =
  let start = ref 0 in
  while (!start < String.length s && Set.mem s.[!start] stripchars) do
    incr start
  done;
  let stop = ref (String.length s - 1) in
  while (!stop >= 0 && Set.mem s.[!stop] stripchars) do
    decr stop
  done;
  if !stop >= !start then String.sub s ~pos:!start ~len:(!stop - !start + 1)
  else ""

(** A parsed HTTP request: the request path and the header map (keys
    lower-cased by [parse_headers]), plus the body for POST. *)
type 'a request =
  | GET of (string * (string,string) Map.t)
  | POST of (string * (string,string) Map.t * 'a)

let whitespace = Str.regexp "[ \t\n\r]+"
let eol = Str.regexp "\r?\n"

(** Read from [cin] into a string.
    NOTE(review): this reads at most 10000 characters; anything beyond is
    left unread.  Per the stdlib docs, [Buffer.add_channel] raises
    [End_of_file] on short input and only since OCaml 4.03 keeps the
    characters it already read — on older compilers a short input yields
    "".  Confirm the intended cap and compiler version. *)
let get_all cin =
  let buf = Buffer.create 0 in
  (try Buffer.add_channel buf cin 10000 with End_of_file -> ());
  Buffer.contents buf

(** [get_all cin] split into lines on "\n" or "\r\n". *)
let get_lines cin = Str.split eol (get_all cin)

let max_post_length = 5 * 1024 * 1024
(* posts restricted to 5 Megs or less *)

(** Read a POST body of exactly "content-length" bytes from [cin].
    @raise Failure if the header is missing or is not an integer.
    @raise Misc_error if the length is negative or exceeds
    [max_post_length]. *)
let parse_post headers cin =
  let lengthstr =
    try headers |= "content-length"
    with Not_found ->
      failwith "parse_post failed for lack of a content-length header"
  in
  let len = int_of_string lengthstr in
  (* A negative length would otherwise surface as Invalid_argument from
     String.create. *)
  if len < 0 then
    raise (Misc_error (sprintf "Invalid POST content-length: %d" len));
  if len > max_post_length then
    raise (Misc_error (sprintf "POST data too long: %f megs"
                         (float len /. 1024. /. 1024.)));
  let rest = String.create len in
  really_input cin rest 0 len;
  rest

(** True for an empty line or one starting with '\r' (end of headers). *)
let is_blank line =
  String.length line = 0 || line.[0] = '\r'

(** Read "Key: value" header lines from [cin] up to the first blank line,
    adding them to [map] with lower-cased keys and [strip]ped values.
    @raise Failure on a header line without a colon. *)
let rec parse_headers map cin =
  (* DOS attack: input_line is unsafe on sockets (unbounded line length) *)
  let line = input_line cin in
  if is_blank line then map
  else begin
    let colonpos =
      try String.index line ':'
      with Not_found -> failwith "Error parsing headers: no colon found"
    in
    let key = String.sub line ~pos:0 ~len:colonpos
    and data =
      String.sub line ~pos:(colonpos + 1)
        ~len:(String.length line - colonpos - 1)
    in
    parse_headers (map |< (String.lowercase key, strip data)) cin
  end

(** Parse one request (request line, headers and, for POST, the body) from
    [cin].  Only GET and POST are recognized.
    @raise Failure "Malformed header" for any other or truncated request
    line. *)
let parse_request cin =
  (* DOS attack: input_line is unsafe on sockets (unbounded line length) *)
  let line = input_line cin in
  let pieces = Str.split whitespace line in
  let headers = parse_headers Map.empty cin in
  match pieces with
  | "GET" :: path :: _ -> GET (path, headers)
  | "POST" :: path :: _ -> POST (path, headers, parse_post headers cin)
  | _ -> failwith "Malformed header"

(** Render a header map as "\nkey:value" lines, for logging. *)
let headers_to_string map =
  let pieces =
    List.map ~f:(fun (x,y) -> sprintf "%s:%s" x y) (Map.to_alist map)
  in
  "\n" ^ (String.concat "\n" pieces)

(** Verbose form of a request for logging: "(KIND,path,[headers])". *)
let request_to_string request =
  let (kind,req,headers) =
    match request with
    | GET (req,header_map) -> ("GET",req,headers_to_string header_map)
    | POST (req,header_map,_) -> ("POST",req,headers_to_string header_map)
  in
  sprintf "(%s,%s,[%s])" kind req headers

(** Short form of a request: "(KIND path)". *)
let request_to_string_short request =
  let (kind,request) =
    match request with
    | GET (req,_) -> ("GET",req)
    | POST (req,_,_) -> ("POST",req)
  in
  sprintf "(%s %s)" kind request

(** HTTP reason phrase for a status code.  Codes not listed get a generic
    phrase for their class. *)
let reason_phrase = function
  | 200 -> "OK"
  | 301 -> "Moved Permanently"
  | 302 -> "Moved Temporarily"
  | 304 -> "Not Modified"
  | 400 -> "Bad Request"
  | 403 -> "Forbidden"
  | 404 -> "Not Found"
  | 408 -> "Request Timeout"
  | 413 -> "Request Entity Too Large"
  | 500 -> "Internal Server Error"
  | 501 -> "Not Implemented"
  | 503 -> "Service Unavailable"
  | code when code < 300 -> "OK"
  | code when code < 400 -> "Redirection"
  | code when code < 500 -> "Client Error"
  | _ -> "Server Error"

(** Write a complete HTTP/1.0 response (status line, Server and
    Content-type headers, blank line, [body]) to [cout] and flush it.
    The status line now carries the reason phrase matching [error_code]
    instead of always "OK". *)
let send_result cout ?(error_code = 200)
    ?(content_type = "text/html; charset=UTF-8") body =
  fprintf cout "HTTP/1.0 %03d %s\r\n" error_code (reason_phrase error_code);
  fprintf cout "Server: sks_www/%s\r\n" version;
  fprintf cout "Content-type: %s\r\n\r\n" content_type;
  fprintf cout "%s\r\n" body;
  flush cout

(** Serve one connection: parse a request from [cin], call
    [f addr request out] (which writes the page into an in-memory channel
    and returns its content type), then send the result on [cout].
    Exceptions from [f] become error pages: [Eventloop.SigAlarm] -> 408,
    [Not_implemented] -> 501, [Page_not_found] -> 404, [Misc_error] and any
    other exception -> 500; [Sys.Break] is logged and re-raised.  Before
    sending an error page the alarm is re-armed with [recover_timeout]
    seconds (presumably so the error reply itself cannot hang).  Failures
    while parsing the request only get a 408 page on timeout; other parse
    errors are logged and no response is written.
    Returns [] (NOTE(review): presumably a list of follow-up events for the
    event loop — confirm against the caller). *)
let accept_connection f ~recover_timeout addr cin cout =
  (* Wrap [body] in an HTML page titled [title] and send it as [error_code]. *)
  let send_error_page ~error_code ~title body =
    send_result cout ~error_code (HtmlTemplates.page ~title ~body)
  in
  begin
    try
      let request = parse_request cin in
      let output_chan = Channel.new_buffer_outc 0 in
      begin
        try
          let content_type = f addr request output_chan#upcast in
          let output = output_chan#contents in
          send_result cout ~content_type output
        with
        | Eventloop.SigAlarm ->
            ignore (Unix.alarm recover_timeout);
            plerror 2 "request %s timed out" (request_to_string request);
            send_error_page ~error_code:408 ~title:"Time Out"
              (sprintf "Error handling request %s: Timed out after %d seconds"
                 (request_to_string_short request) !Settings.wserver_timeout)
        | Sys.Break as e ->
            plerror 1 "Break occured while processing HKP request %s"
              (request_to_string request);
            raise e
        | Not_implemented s ->
            ignore (Unix.alarm recover_timeout);
            plerror 2 "Error handling request %s: %s"
              (request_to_string request) ("Not implemented: " ^ s);
            send_error_page ~error_code:501 ~title:"Not implemented"
              (sprintf "Error handling request %s: %s not implemented."
                 (request_to_string request) s)
        | Page_not_found s ->
            ignore (Unix.alarm recover_timeout);
            plerror 2 "Page not found: %s" s;
            send_error_page ~error_code:404 ~title:"Page not found"
              (sprintf "Page not found: %s" s)
        | Misc_error s ->
            ignore (Unix.alarm recover_timeout);
            plerror 2 "Error handling request %s: %s"
              (request_to_string request) s;
            send_error_page ~error_code:500 ~title:"Error handling request"
              (sprintf "Error handling request: %s" s)
        | e ->
            ignore (Unix.alarm recover_timeout);
            plerror 2 "Error handling request %s: %s"
              (request_to_string request) (Common.err_to_string e);
            send_error_page ~error_code:500 ~title:"Error handling request"
              (sprintf "Error handling request. Exception raised: %s"
                 (Common.err_to_string e))
      end
    with
    | Sys.Break as e -> raise e
    | Eventloop.SigAlarm ->
        ignore (Unix.alarm recover_timeout);
        send_error_page ~error_code:408 ~title:"Timeout"
          (sprintf "Request timed out during request parsing after %d seconds"
             !Settings.wserver_timeout)
    | e -> eplerror 5 e "Miscellaneous error"
  end;
  []