package pkg

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httputil"
	"strconv"
	"strings"
	"sync"

	"moul.io/http2curl"
)

// Marker header used for HTTP response splitting detection.
// getRespSplit embeds the pair into its payload string, and
// checkPoisoningIndicators looks for the same pair among response headers.
const (
	// respSplitHeader is the name of the marker header.
	respSplitHeader = "Web_Cache"
	// respSplitValue is the value expected for the marker header.
	respSplitValue  = "Vulnerability_Scanner"
)

// requestParams bundles everything firstRequest, secondRequest and
// issueRequest need to send one probe and evaluate its outcome.
//
// Field usage as seen in this file:
//   - repResult: report that issueRequest/checkPoisoningIndicators write
//     errors and findings to.
//   - headers, values: header names and their values, paired by index in
//     firstRequest. values[0] == "2ndrequest" marks the follow-up request,
//     for which no headers are set.
//   - parameters: query parameters appended to url by addParameters.
//   - technique: secondRequest only keeps "=foobar" parameters when this is
//     "pollution".
//   - name: compared with Config.Website.Cache.CBName in issueRequest.
//   - identifier: label used as prefix in log and error messages.
//   - poison: value whose reflection in a response is searched for.
//   - url: target URL of the request.
//   - cb: handed to setRequest and searched for in the first response when
//     name equals the configured CBName (presumably the cache buster value;
//     confirm against setRequest).
//   - success: message printed when a finding is reported.
//   - bodyString: request body; for POST requests Config.Body is used when
//     it is empty.
//   - prependCB, newCookie: passed through to setRequest.
//   - forcePost: sends the request as POST regardless of Config.DoPost.
//   - duplicateHeaders: keep an already present header value and send the
//     new value in addition instead of overwriting it.
//   - m: optional mutex held by issueRequest while repResult is modified.
type requestParams struct {
	repResult  *reportResult
	headers    []string
	values     []string
	parameters []string
	//cookie           oldCookie
	technique        string
	name             string
	identifier       string
	poison           string
	url              string
	cb               string
	success          string
	bodyString       string
	prependCB        bool
	forcePost        bool
	duplicateHeaders bool
	newCookie        http.Cookie
	m                *sync.Mutex
}

/*type oldCookie struct {
	position int
	oldValue string
}*/

// init is currently empty; it performs no package initialisation.
func init() {
}

// getRespSplit returns the response splitting payload: a literal
// backslash-escaped CRLF sequence (the four characters \r\n written out, not
// real control characters) followed by the marker header line
// "<respSplitHeader>: <respSplitValue>".
func getRespSplit() string {
	return fmt.Sprintf("\\r\\n%s: %s", respSplitHeader, respSplitValue)
}

// checkPoisoningIndicators decides whether the response to the second
// (follow-up) request shows that the cache was poisoned and, if so, records
// the finding in repResult.
//
// Indicators are evaluated in this order:
//  1. the response splitting marker header is present,
//  2. the body contains poison (skipped for "", "http", "https", "nothttps"),
//  3. a response header contains poison,
//  4. both requests returned the same status code and it differs from the
//     default status code of the website,
//  5. Config.CLDiff is set and the body length differs from the default body
//     by more than Config.CLDiff bytes.
//
// For 4 and 5 the default website is re-fetched once (up to three attempts)
// and the check is repeated with recursive=true, so that a changed default
// is not reported as a finding.
//
// Parameters:
//   - repResult: report that receives findings and verification errors.
//   - request: report entry of the first request; a copy with Reason set is
//     appended to repResult.Requests when a finding is made.
//   - success: message printed ahead of a finding.
//   - body, header: body and headers of the second response.
//   - poison: value whose reflection is searched for.
//   - statusCode1, statusCode2: status codes of the first and second response.
//   - sameBodyLength: whether both response bodies had the same length.
//   - recursive: true when called again after the default was re-fetched.
//
// It returns the name of a response header whose value contains poison, or
// "" if there is none. When several headers match, which one is returned
// depends on map iteration order.
func checkPoisoningIndicators(repResult *reportResult, request reportRequest, success string, body string, poison string, statusCode1 int, statusCode2 int, sameBodyLength bool, header http.Header, recursive bool) string {
	headerWithPoison := ""
	if header != nil && poison != "" {
		for x := range header {
			// net/http canonicalises response header names, which turns
			// "Web_Cache" into "Web_cache". Compare case-insensitively,
			// otherwise the marker header is never recognised.
			if strings.EqualFold(x, respSplitHeader) && header.Get(x) == respSplitValue {
				request.Reason = "HTTP Response Splitting"
			}
			if strings.Contains(header.Get(x), poison) {
				headerWithPoison = x
			}
		}
	}

	if request.Reason == "" {
		switch {
		case poison != "" && poison != "http" && poison != "https" && poison != "nothttps" && strings.Contains(body, poison):
			// Reflection of http/https/nothttps (used by forwarded headers)
			// or of an empty poison is not checked.
			request.Reason = "Response Body contained " + poison

		case headerWithPoison != "":
			request.Reason = fmt.Sprintf("%s header contains poison value %s", headerWithPoison, poison)

		case statusCode1 >= 0 && statusCode1 != Config.Website.StatusCode && statusCode1 == statusCode2:
			// Check if the status code should be ignored.
			for _, status := range Config.IgnoreStatus {
				if statusCode1 == status || Config.Website.StatusCode == status {
					PrintVerbose("Skipped Status Code "+strconv.Itoa(status)+"\n", Cyan, 1)
					_, err := GetWebsite(Config.Website.Url.String(), true, true)
					if err != nil {
						Print(fmt.Sprintln("Error while checking whether the default status code changed: ", err.Error()), Yellow)
					}
					return headerWithPoison
				}
			}
			if !recursive {
				var tmpWebsite WebsiteStruct
				var err error

				// Try up to count times to fetch the current default.
				count := 3
				for i := 0; i < count; i++ {
					Print(fmt.Sprintln("Status Code", statusCode1, "differed from the default", Config.Website.StatusCode, ", sending verification request", i+1, "from up to", count), Yellow)
					tmpWebsite, err = GetWebsite(Config.Website.Url.String(), true, true)
					if err == nil {
						Print(fmt.Sprintln("The verification request returned the Status Code", tmpWebsite.StatusCode), Yellow)
						break
					}
				}
				if err != nil {
					repResult.HasError = true
					msg := fmt.Sprintf("%s: couldn't verify if status code %d is the new default status code, because the verification encountered the following error %d times: %s", request.URL, statusCode1, count, err.Error())
					repResult.ErrorMessages = append(repResult.ErrorMessages, msg)
				} else {
					Config.Website = tmpWebsite
				}
				// Re-evaluate against the (possibly refreshed) default.
				return checkPoisoningIndicators(repResult, request, success, body, poison, statusCode1, statusCode2, sameBodyLength, header, true)
			}
			request.Reason = fmt.Sprintf("Status Code %d differed from %d", statusCode1, Config.Website.StatusCode)

		case Config.CLDiff != 0 && success != "" && sameBodyLength && len(body) > 0 && compareLengths(len(body), len(Config.Website.Body), Config.CLDiff):
			if !recursive {
				var tmpWebsite WebsiteStruct
				var err error

				// Try up to count times to fetch the current default.
				count := 3
				for i := 0; i < count; i++ {
					tmpWebsite, err = GetWebsite(Config.Website.Url.String(), true, true)
					if err == nil {
						break
					}
				}
				if err != nil {
					repResult.HasError = true
					// Report the body length here, not the status code.
					msg := fmt.Sprintf("%s: couldn't verify if body length %d is the new default body length, because the verification request encountered the following error %d times: %s", request.URL, len(body), count, err.Error())
					repResult.ErrorMessages = append(repResult.ErrorMessages, msg)
				} else {
					Config.Website = tmpWebsite
				}
				// Re-evaluate against the (possibly refreshed) default.
				return checkPoisoningIndicators(repResult, request, success, body, poison, statusCode1, statusCode2, sameBodyLength, header, true)
			}
			request.Reason = fmt.Sprintf("Length %d differed more than %d bytes from normal length %d", len(body), Config.CLDiff, len(Config.Website.Body))

		default:
			// No indicator matched.
			return headerWithPoison
		}
	}

	PrintNewLine()
	Print(success, Green)
	msg := "URL: " + request.URL + "\n"
	Print(msg, Green)
	msg = "Reason: " + request.Reason + "\n"
	Print(msg, Green)
	msg = "Curl: " + request.CurlCommand + "\n\n"
	Print(msg, Green)
	repResult.Vulnerable = true
	repResult.Requests = append(repResult.Requests, request)
	return headerWithPoison
}

// compareLengths reports whether len1 and len2 are more than limit apart,
// i.e. whether |len1 - len2| > limit.
func compareLengths(len1 int, len2 int, limit int) bool {
	larger, smaller := len1, len2
	if len1 < len2 {
		larger, smaller = len2, len1
	}
	return larger-smaller > limit
}

// stopContinuation reports whether a response is indistinguishable from the
// stored default response of the website, in which case the caller stops
// processing it. A response counts as indistinguishable when its body and
// status code equal the defaults, it has the same number of headers, and
// every header carries the same number of values as in the default response.
// Header values themselves are not compared.
func stopContinuation(body []byte, statusCode int, headers http.Header) bool {
	differs := string(body) != Config.Website.Body ||
		statusCode != Config.Website.StatusCode ||
		len(headers) != len(Config.Website.Headers)
	if differs {
		return false
	}

	for name := range headers {
		// Only the number of values per header has to match.
		if len(headers[name]) != len(Config.Website.Headers.Values(name)) {
			return false
		}
	}
	return true
}

// addParameters appends every non-empty entry of parameters to the URL that
// urlStr points to. An entry is introduced with "?" as long as the URL
// contains no "?" yet, and with Config.QuerySeparator afterwards. The URL is
// modified in place.
func addParameters(urlStr *string, parameters []string) {
	for _, param := range parameters {
		switch {
		case param == "":
			continue
		case strings.Contains(*urlStr, "?"):
			*urlStr += Config.QuerySeparator
		default:
			*urlStr += "?"
		}
		*urlStr += param
	}
}

// firstRequest builds and sends a single HTTP request described by rp and
// returns the response together with a report entry for it.
//
// Despite its name it is also used for the follow-up request: secondRequest
// calls it with rp.values[0] == "2ndrequest", in which case no custom headers
// are applied and the cache buster is never used as HTTP method.
//
// Request construction:
//   - rp.parameters are appended to rp.url (see addParameters).
//   - The method is Config.Website.Cache.CBName when the cache buster is an
//     HTTP method (and rp.forcePost is false), POST when Config.DoPost or
//     rp.forcePost is set, and GET otherwise.
//   - rp.headers[i] is set to rp.values[i]. A "Host" header appends the value
//     to the URL host unless rp.duplicateHeaders is set. Header names are
//     written to the header map directly so their spelling is kept.
//
// Returns, in order: the response body, the status code (-1 if no response
// was obtained), the report entry (request dump, response dump, curl command,
// URL), a copy of the response headers, and an error. The error message is
// exactly "stop" when stopContinuation reports that the response equals the
// default response; body, status code and headers are still returned then.
func firstRequest(rp requestParams) ([]byte, int, reportRequest, http.Header, error) {
	var req *http.Request
	var resp *http.Response
	var err error
	var msg string
	var body []byte
	var repRequest reportRequest

	// Treat empty slices like nil ones so that rp.values[0] below can never
	// be out of range.
	if len(rp.headers) == 0 {
		rp.headers = []string{""}
	}
	if len(rp.values) == 0 {
		rp.values = []string{""}
	}
	if len(rp.parameters) == 0 {
		rp.parameters = []string{""}
	}

	isSecondRequest := rp.values[0] == "2ndrequest"
	if isSecondRequest {
		rp.identifier = fmt.Sprintf("2nd request of %s", rp.identifier)
	} else {
		rp.identifier = fmt.Sprintf("1st request of %s", rp.identifier)
	}

	// Headers and values are paired by index and must have the same length.
	if len(rp.headers) != len(rp.values) && !isSecondRequest {
		msg = fmt.Sprintf("%s: len(header) %s %d != len(value) %s %d\n", rp.identifier, rp.headers, len(rp.headers), rp.values, len(rp.values))
		Print(msg, Red)
		return body, -1, repRequest, nil, errors.New(msg)
	}

	addParameters(&rp.url, rp.parameters)

	switch {
	case !rp.forcePost && Config.Website.Cache.CBisHTTPMethod && !isSecondRequest:
		req, err = http.NewRequest(Config.Website.Cache.CBName, rp.url, bytes.NewBufferString(rp.bodyString))
	case Config.DoPost || rp.forcePost:
		if rp.bodyString == "" {
			rp.bodyString = Config.Body
		}
		req, err = http.NewRequest("POST", rp.url, bytes.NewBufferString(rp.bodyString))
	case rp.bodyString != "":
		req, err = http.NewRequest("GET", rp.url, bytes.NewBufferString(rp.bodyString))
	default:
		req, err = http.NewRequest("GET", rp.url, nil)
	}
	if err != nil {
		msg = fmt.Sprintf("%s: http.NewRequest: %s\n", rp.identifier, err.Error())
		Print(msg, Red)
		return body, -1, repRequest, nil, errors.New(msg)
	}

	newClient := http.Client{
		CheckRedirect: http.DefaultClient.CheckRedirect,
		Timeout:       http.DefaultClient.Timeout,
	}

	setRequest(req, Config.DoPost, rp.cb, rp.newCookie, rp.prependCB)

	for i, name := range rp.headers {
		if name == "" {
			continue
		}
		if isSecondRequest {
			msg = rp.identifier + "2nd request doesnt allow headers to be set\n"
			Print(msg, Red)
			break
		}
		value := rp.values[i]

		if strings.EqualFold(name, "Host") && !rp.duplicateHeaders {
			newHost := req.URL.Host + value
			msg := fmt.Sprintf("Overwriting Host:%s with Host:%s\n", req.URL.Host, newHost)
			PrintVerbose(msg, NoColor, 2)
			req.Host = newHost
			continue
		}

		// Writing to the map directly doesn't canonicalise the header name
		// (relevant for HTTP/1).
		// NOTE(review): Get canonicalises name, the map write does not, so
		// for a non-canonical name an existing canonical entry stays in place.
		existing := req.Header.Get(name)
		switch {
		case existing != "" && rp.duplicateHeaders:
			req.Header[name] = []string{existing, value}
		case existing != "":
			msg := fmt.Sprintf("Overwriting %s:%s with %s:%s\n", name, existing, name, value)
			PrintVerbose(msg, NoColor, 2)
			req.Header[name] = []string{value}
		default:
			req.Header[name] = []string{value}
		}
	}

	waitLimiter(rp.identifier)
	var dumpReqBytes []byte
	var bodyBackup []byte
	if req.Body != nil {
		// Back up the request body, because it can only be read once.
		bodyBackup, err = io.ReadAll(req.Body)
		if err != nil {
			msg = fmt.Sprintf("%s: io.ReadAll: %s\n", rp.identifier, err.Error())
			Print(msg, Red)
			return body, -1, repRequest, nil, errors.New(msg)
		}
		// Restore the request body for the actual request.
		req.Body = io.NopCloser(bytes.NewReader(bodyBackup))
		// Dump the request including the body; the dump is best effort and
		// only used for the report.
		dumpReqBytes, _ = httputil.DumpRequest(req, true)
	} else {
		// Dump the request without the body (best effort, report only).
		dumpReqBytes, _ = httputil.DumpRequest(req, false)
	}
	repRequest.Request = string(dumpReqBytes)

	// Do request
	resp, err = newClient.Do(req)
	if err != nil {
		msg = fmt.Sprintf("%s: newClient.Do: %s\n", rp.identifier, err.Error())
		Print(msg, Red)
		return body, -1, repRequest, nil, errors.New(msg)
	}
	// The receiver is evaluated now, so this closes the network body even
	// though resp.Body is replaced further down.
	defer resp.Body.Close()

	body, err = io.ReadAll(resp.Body)
	if err != nil {
		msg = fmt.Sprintf("%s: io.ReadAll: %s\n", rp.identifier, err.Error())
		Print(msg, Red)
		return body, -1, repRequest, nil, errors.New(msg)
	}

	if resp.StatusCode != Config.Website.StatusCode {
		msg = fmt.Sprintf("Unexpected Status Code %d for %s\n", resp.StatusCode, rp.identifier)
		Print(msg, Yellow)
	}

	respHeader := resp.Header.Clone()
	if stopContinuation(body, resp.StatusCode, respHeader) {
		return body, resp.StatusCode, repRequest, respHeader, errors.New("stop")
	}

	// Add the request as curl command to the report. On error no command is
	// returned, so it must not be dereferenced.
	command, err := http2curl.GetCurlCommand(req)
	if err != nil {
		PrintVerbose("Error: firstRequest: "+err.Error()+"\n", Yellow, 1)
	} else {
		// The request body was consumed by the request above, so the
		// generated command carries an empty -d argument; fill in the backup.
		commandFixed := strings.Replace(command.String(), "-d ''", "-d '"+string(bodyBackup)+"'", 1)
		if !strings.Contains(rp.url, req.Host) {
			commandFixed = strings.Replace(commandFixed, "'http", "-H 'Host: "+req.Host+"' 'http", 1)
		}
		repRequest.CurlCommand = commandFixed
		PrintVerbose("Curl command: "+repRequest.CurlCommand+"\n", NoColor, 2)
	}

	// The response body was already drained by io.ReadAll. Hand DumpResponse
	// a fresh reader over the same bytes, otherwise the dump has no body (or
	// fails because of the Content-Length mismatch).
	resp.Body = io.NopCloser(bytes.NewReader(body))
	responseBytes, err := httputil.DumpResponse(resp, true)
	if err != nil {
		PrintVerbose("Error: firstRequest: "+err.Error()+"\n", Yellow, 1)
	}
	repRequest.Response = string(responseBytes)

	repRequest.URL = req.URL.String()

	//TODO: Also use dumped request/response of 2nd request

	return body, resp.StatusCode, repRequest, respHeader, nil
}

// secondRequest sends the follow-up request that belongs to rpFirst. It
// reuses url, cb and identifier of the first request, sets no custom headers
// (values is the "2ndrequest" marker) and, only for the "pollution"
// technique, keeps those parameters of the first request that contain
// "=foobar". It returns body, status code, headers and error of firstRequest.
func secondRequest(rpFirst requestParams) ([]byte, int, http.Header, error) {
	// Must stay nil when nothing is kept; firstRequest handles that case.
	var kept []string
	if rpFirst.technique == "pollution" {
		for i := range rpFirst.parameters {
			if p := rpFirst.parameters[i]; strings.Contains(p, "=foobar") {
				kept = append(kept, p)
			}
		}
	}

	respBody, status, _, respHeader, err := firstRequest(requestParams{
		url:        rpFirst.url,
		cb:         rpFirst.cb,
		identifier: rpFirst.identifier,
		values:     []string{"2ndrequest"},
		parameters: kept,
	})
	return respBody, status, respHeader, err
}

// TODO: response splitting method

// issueRequest sends the first request described by rp, then the follow-up
// request, and evaluates both responses.
//
// Return values: the string is the result of checkPoisoningIndicators (the
// name of a response header containing the poison value, needed for response
// splitting); the bool is the result of firstRequestPoisoningIndicator for
// the first response (only needed for ScanParameters).
//
// Request errors other than "stop" are recorded in rp.repResult. If the first
// request fails, ("", false) is returned; if the second one fails, the string
// is "" and the bool is kept. rp.m, when set, is held while rp.repResult is
// modified.
func issueRequest(rp requestParams) (string, bool) {
	// recordFailure notes a request error in the report; the "stop" signal
	// of firstRequest is not an error worth reporting.
	recordFailure := func(reqErr error) {
		if reqErr.Error() == "stop" {
			return
		}
		if rp.m != nil {
			rp.m.Lock()
			defer rp.m.Unlock()
		}
		rp.repResult.HasError = true
		rp.repResult.ErrorMessages = append(rp.repResult.ErrorMessages, reqErr.Error())
	}

	firstBody, firstStatus, request, firstHeader, err := firstRequest(rp)
	if err != nil {
		recordFailure(err)
		return "", false
	}

	cbIsIdentifier := Config.Website.Cache.CBName == rp.name
	impactful := firstRequestPoisoningIndicator(rp.identifier, firstBody, rp.poison, firstHeader, cbIsIdentifier, rp.cb) // TODO add different status code as indicator?

	secondBody, secondStatus, secondHeader, err := secondRequest(rp)
	if err != nil {
		recordFailure(err)
		return "", impactful
	}
	sameBodyLength := len(firstBody) == len(secondBody)

	// Lock here, to prevent false positives and too many GetWebsite requests
	if rp.m != nil {
		rp.m.Lock()
		defer rp.m.Unlock()
	}
	return checkPoisoningIndicators(rp.repResult, request, rp.success, string(secondBody), rp.poison, firstStatus, secondStatus, sameBodyLength, secondHeader, false), impactful
}

// firstRequestPoisoningIndicator reports whether the response to the first
// request already shows an effect of the sent value, and prints the reason
// prefixed with identifier if so.
//
// A response counts as affected when
//   - poison (or cb, when identifierIsCB is set) is found in the body or in
//     the first value of any response header; this is skipped for the poison
//     values "", "http", "https" and "nothttps" (used by forwarded headers),
//   - or, failing that, Config.CLDiff is set, the body is non-empty and its
//     length differs from the default body by more than Config.CLDiff bytes.
//
// When several headers reflect the value, the header named in the message
// depends on map iteration order.
func firstRequestPoisoningIndicator(identifier string, body []byte, poison string, header http.Header, identifierIsCB bool, cb string) bool {
	var reason string
	if poison != "" && poison != "http" && poison != "https" && poison != "nothttps" {
		reflects := func(s string) bool {
			return strings.Contains(s, poison) || (identifierIsCB && strings.Contains(s, cb))
		}

		if reflects(string(body)) {
			reason = "Response Body contained " + poison
		} else {
			for name := range header {
				if reflects(header.Get(name)) {
					// Name the header instead of blaming the body.
					reason = "Response Header " + name + " contained " + poison
					break
				}
			}
		}
	}
	if Config.CLDiff != 0 && reason == "" && len(body) > 0 && compareLengths(len(body), len(Config.Website.Body), Config.CLDiff) {
		reason = fmt.Sprintf("Length %d differed more than %d bytes from normal length %d", len(body), Config.CLDiff, len(Config.Website.Body))
	}

	if reason == "" {
		return false
	}
	Print(identifier+": "+reason+"\n", Green)
	return true
}
