package com.icetech.web.wrapper;

import org.apache.commons.collections.MapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import org.springframework.util.StreamUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

public class HttpServletRequestDecorator extends HttpServletRequestWrapper {

    private static final Logger logger = LoggerFactory.getLogger(HttpServletRequestDecorator.class);

    /** Shared empty body, used when no body has been cached (e.g. after {@code setBody(null)}). */
    private static final byte[] EMPTY_BODY = new byte[0];

    /**
     * Mutable copy of the request parameters; may be extended via {@link #setParameters(Map)}.
     */
    protected Map<String, String[]> paramMap;
    /**
     * Cached request body (e.g. json/xml), replayable any number of times.
     */
    protected byte[] body;

    /**
     * Wraps {@code request}, snapshotting its parameters and fully reading its body into memory.
     *
     * <p>The parameter map is copied BEFORE the body is read: once the underlying input stream
     * has been consumed, the container can no longer parse form parameters from it.
     *
     * @param request the request to wrap
     * @throws IOException if reading the request body fails
     */
    public HttpServletRequestDecorator(HttpServletRequest request) throws IOException {
        super(request);
        paramMap = new HashMap<>(request.getParameterMap());
        body = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Returns the cached body bytes.
     *
     * @return the internal body array (not a copy; may be {@code null} after {@code setBody(null)})
     */
    public byte[] getCacheBody() {
        return body;
    }

    /**
     * Returns the first value of the named parameter, as held in the cached parameter map.
     *
     * @param arg0 the parameter name
     * @return the first value, or {@code null} if the parameter is absent or has no values
     */
    @Override
    public String getParameter(String arg0) {
        String[] paramValues = getParameterValues(arg0);
        if (paramValues != null && paramValues.length > 0) {
            return paramValues[0];
        } else {
            return null;
        }
    }

    /**
     * @return the cached parameter map. This is the live internal map, not a copy.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        return paramMap;
    }

    /**
     * @return an enumeration over the names in the cached parameter map, including any names
     *         added through {@link #setParameters(Map)}
     */
    @Override
    public Enumeration<String> getParameterNames() {
        return new IteratorWrapper<>(paramMap.keySet().iterator());
    }

    /**
     * @param arg0 the parameter name
     * @return all values of the named parameter, or {@code null} if it is absent
     */
    @Override
    public String[] getParameterValues(String arg0) {
        return paramMap.get(arg0);
    }

    /**
     * Returns a fresh stream over the cached body, so the body can be read repeatedly.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new ServletInputStreamAdapter(bodyOrEmpty());
    }

    /**
     * Returns a fresh reader over the cached body.
     *
     * <p>The reader decodes with the request's declared character encoding. If the request
     * declares none, it uses UTF-8 (the JSON default) rather than the JVM platform charset.
     *
     * @throws UnsupportedEncodingException if the declared encoding is not supported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(
                new InputStreamReader(new ByteArrayInputStream(bodyOrEmpty()), resolveCharset()));
    }

    /**
     * Caches the body of the underlying request so it can be re-read, but only for JSON or
     * multipart content types. Other request-body types could also be checked here.
     *
     * @return {@code true} if a body is (now) cached, {@code false} if the content type is not one
     *         that is cached
     * @throws IOException if reading the underlying input stream fails
     */
    protected boolean handleInputStream() throws IOException {
        if (body != null) {
            return true;
        }
        MediaType mt = getContentType(getRequest());
        boolean cacheable = mt != null
                && (MediaType.APPLICATION_JSON.isCompatibleWith(mt)
                || MediaType.MULTIPART_FORM_DATA.isCompatibleWith(mt));
        if (!cacheable) {
            return false;
        }
        // Read the whole stream into memory so interceptors can read the body again later.
        try (InputStream in = getRequest().getInputStream()) {
            body = StreamUtils.copyToByteArray(in);
        }
        return true;
    }

    /**
     * Parses the request's Content-Type header.
     *
     * @return the parsed media type, or {@code null} if the header is absent or malformed
     */
    private MediaType getContentType(ServletRequest request) {
        String value = request.getContentType();
        if (value == null) {
            return null;
        }
        try {
            return MediaType.parseMediaType(value);
        } catch (IllegalArgumentException e) {
            // A client-supplied header must not abort request handling; treat it as unknown.
            logger.debug("Ignoring malformed Content-Type '{}'", value, e);
            return null;
        }
    }

    /**
     * Lazily initializes the parameter map from the underlying request if it is not yet set.
     */
    protected void handleParam() {
        if (paramMap == null) {
            paramMap = new HashMap<>(getRequest().getParameterMap());
        }
    }

    /**
     * Merges the given parameters into the cached parameter map. Existing keys are overwritten.
     *
     * @param paramMap the parameters to add; {@code null} or empty is a no-op
     */
    public void setParameters(Map<String, String[]> paramMap) {
        logger.debug("Cus set request parameter {}", paramMap == null ? "nil" : paramMap.size());
        if (MapUtils.isNotEmpty(paramMap)) {
            this.paramMap.putAll(paramMap);
        }
    }

    /**
     * @param body the body to cache; {@code null} is read back as an empty body
     */
    public void setBody(byte[] body) {
        this.body = body;
    }

    @Override
    public String toString() {
        return "Common Request;" + getRequest().toString();
    }

    /** Returns the cached body, or an empty array when none is cached. */
    private byte[] bodyOrEmpty() {
        return body != null ? body : EMPTY_BODY;
    }

    /** Resolves the charset used to decode the body for {@link #getReader()}. */
    private Charset resolveCharset() throws UnsupportedEncodingException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            return StandardCharsets.UTF_8;
        }
        try {
            return Charset.forName(encoding);
        } catch (IllegalArgumentException e) {
            // IllegalCharsetNameException / UnsupportedCharsetException -> the servlet-API exception.
            UnsupportedEncodingException ex = new UnsupportedEncodingException(encoding);
            ex.initCause(e);
            throw ex;
        }
    }

    /**
     * Adapts an {@link Iterator} to the legacy {@link Enumeration} interface.
     */
    private static class IteratorWrapper<E> implements Enumeration<E> {
        private final Iterator<E> it;

        public IteratorWrapper(Iterator<E> it) {
            this.it = it;
        }

        @Override
        public boolean hasMoreElements() {
            return it.hasNext();
        }

        @Override
        public E nextElement() {
            return it.next();
        }
    }

    /**
     * A {@link ServletInputStream} backed by an in-memory byte array.
     */
    private static class ServletInputStreamAdapter extends ServletInputStream {
        private final ByteArrayInputStream inputStream;

        public ServletInputStreamAdapter(byte[] bytes) {
            inputStream = new ByteArrayInputStream(bytes);
        }

        @Override
        public int read() throws IOException {
            return inputStream.read();
        }

        /** Bulk read delegated directly, avoiding InputStream's byte-by-byte default. */
        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return inputStream.read(b, off, len);
        }

        @Override
        public int available() throws IOException {
            return inputStream.available();
        }

        /** The data is fully in memory, so reads never block. */
        @Override
        public boolean isReady() {
            return true;
        }

        @Override
        public boolean isFinished() {
            return inputStream.available() == 0;
        }

        /**
         * No-op: non-blocking (async) reads are not supported by this adapter, and the listener
         * is never invoked.
         */
        @Override
        public void setReadListener(ReadListener listener) {
            // intentionally empty
        }
    }
}
