149 lines
5.3 KiB
Python
149 lines
5.3 KiB
Python
import requests
|
||
import utils
|
||
import json
|
||
import logging as logger
|
||
#import InvalidParameterException
|
||
|
||
from flask import make_response
|
||
from flask import request
|
||
from ExceptionHandler import *
|
||
|
||
|
||
class NoopExceptionHandler(ExceptionHandler):
    """Fallback handler that turns any exception into a generic JSON 500 body.

    GenericGateway uses an instance of this class when no exception handler
    is passed to its constructor.
    """

    # Template response body; handle() copies it per call so the class-level
    # dict is never mutated. "error_message_chs" is the Chinese-language
    # message ("internal server error").
    DEFAULT_CONTENT = {
        "error_code": "InternalServerError",
        "error_message": "",
        "error_message_chs": "服务器内部错误"
    }

    def handle(self, e, *args, **kwargs):
        """Log ``e`` via ``logger.exception`` and build a JSON error payload.

        :param e: the exception that was raised
        :param args: ignored
        :param kwargs: ignored
        :returns: tuple ``(json_string, 500)``
        """
        logger.exception("Unknown exception occurred: %s", e)
        content = dict(self.DEFAULT_CONTENT)
        # ``BaseException.message`` only exists on Python 2 (or on exceptions
        # that define it themselves); fall back to str(e) so Python 3
        # exceptions do not raise AttributeError inside the error handler.
        message = getattr(e, 'message', None)
        if message is None:
            message = str(e)
        content['error_message'] = message
        return json.dumps(content), 500
|
||
|
||
|
||
class GenericGateway(object):
    """Registers Flask routes that proxy requests to upstream services.

    Each route's upstream service is resolved by name through
    ``url_supplier``; the request is forwarded with :mod:`requests` and the
    upstream response is returned to the client.
    """

    # Response headers describing how the upstream connection framed or
    # encoded its body. requests already decodes gzip/deflate and de-chunks
    # ``resp.content``, and Flask computes its own framing, so forwarding
    # these would make clients mis-read the body.
    _EXCLUDED_RESPONSE_HEADERS = frozenset([
        'content-encoding',
        'content-length',
        'transfer-encoding',
        'connection',
    ])

    def __init__(self, app, routes, url_supplier, exception_handler=None):
        """
        :param app: flask app
        :type app: flask.Flask
        :param routes: list of routes to register
        :type routes: list of Route
        :param url_supplier: resolves a service name to its base URL:
            ``url_supplier(service_name) -> service_url``
        :type url_supplier: function
        :param exception_handler: exception handle decorator; defaults to a
            ``NoopExceptionHandler`` instance
        :type exception_handler: ExceptionHandler
        :raises ValueError: if ``url_supplier`` is not callable
        """
        if not callable(url_supplier):
            raise ValueError("service url supplier (url_supplier) must be callable")

        self.app = app
        self.routes = routes
        self.url_supplier = url_supplier
        # NOTE(review): stored but not applied anywhere in this file; the
        # proxy views registered below let exceptions propagate to Flask.
        self.exception_handler = exception_handler or NoopExceptionHandler()
        # View functions returned by app.route(), in registration order.
        self.proxy_list = []

    def get_url(self, service_name):
        """Resolve ``service_name`` to its base URL via ``url_supplier``.

        :raises RuntimeError: if the supplier raises; the original error is
            chained as ``__cause__``.
        """
        try:
            return self.url_supplier(service_name)
        except Exception as e:
            logger.warning("Failed to get the url of service: %s", service_name, exc_info=True)
            raise RuntimeError("Failed to get the url of service {}: {}".format(service_name, e)) from e

    @staticmethod
    def _strip_stale_boundary(headers):
        """Remove a multipart ``Content-Type`` header (any casing).

        When requests re-encodes a multipart body it generates a fresh
        boundary, but the incoming ``Content-Type`` still names the old one,
        which the upstream service would then fail to match. Returns
        ``headers`` unchanged (same object) when there is nothing to strip,
        otherwise a new plain dict without the ``Content-Type`` entry.
        """
        if not headers:
            return headers
        content_type = next(
            (v for k, v in headers.items() if k.lower() == 'content-type'), None)
        if not content_type or 'boundary' not in content_type:
            return headers
        return {k: v for k, v in headers.items() if k.lower() != 'content-type'}

    def create_request_proxy(self, route):
        """Build the Flask view function that forwards ``route`` upstream.

        The view renders ``route.upstream_path`` from request parameters,
        lets ``route.inbound_pipes()`` rewrite the outgoing request, sends it
        with ``requests.request``, lets ``route.outbound_pipes()`` rewrite the
        response, and returns ``(flask_response, status)``.

        :param Route route: route to proxy
        :returns: view function whose ``__name__`` is ``route.name``
        """
        logger.info("Register route: %s", route.path)

        def proxy_request(*args, **kwargs):
            upstream_path = route.upstream_path
            for p_name in route.upstream_params:
                v = get_params_from_context(p_name, kwargs)
                upstream_path = utils.replace_placeholder(upstream_path, p_name, v)

            endpoint = self.get_url(route.service_name)
            upstream_req_info = dict(
                url="{}{}".format(endpoint, upstream_path),
                method=request.method,
                params=request.args,
                data=request.data,
                files=request.files,
                headers=request.headers,
            )
            # Form/multipart bodies are parsed into request.form and
            # request.files, leaving request.data empty, so re-send the parsed
            # fields. requests also rejects a bytes ``data`` combined with
            # ``files``, so a file-only upload must use the (empty) form too.
            if request.form or request.files:
                upstream_req_info['data'] = request.form

            for pipe_handler in route.inbound_pipes():
                upstream_req_info, args, kwargs = pipe_handler(upstream_req_info, *args, **kwargs)

            upstream_req_info['headers'] = self._strip_stale_boundary(upstream_req_info['headers'])

            resp = requests.request(**upstream_req_info)
            status = resp.status_code
            content = None
            headers = resp.headers

            for pipe_handler in route.outbound_pipes():
                resp, headers, status, content = pipe_handler(resp, headers, status, content, *args, **kwargs)

            # Fall back to the upstream body only when no pipe supplied one,
            # so a pipe can still return an intentionally empty body.
            response = make_response(resp.content if content is None else content)
            for h_name, h_value in (headers or {}).items():
                if h_name.lower() not in self._EXCLUDED_RESPONSE_HEADERS:
                    response.headers[h_name] = h_value
            return response, status

        # Flask uses the view function's name as its endpoint, which must be unique.
        proxy_request.__name__ = route.name
        return proxy_request

    def _register_route(self, route):
        """Register a proxy view for ``route`` on the Flask app.

        :param Route route: its ``path`` and ``methods`` are passed to
            ``app.route``.
        """
        proxy = self.create_request_proxy(route)
        self.proxy_list.append(self.app.route(route.path, methods=route.methods)(proxy))

    def register_routes(self):
        """Register a proxy view for every route given to the constructor."""
        for route in self.routes:
            self._register_route(route)
|
||
|
||
|
||
def _load_json_object(raw):
    """Parse ``raw`` as JSON and return it only if it is a JSON object.

    :param raw: request body (``bytes`` or ``str``)
    :returns: the decoded ``dict``, or an empty ``dict`` when the body is not
        valid JSON or decodes to something other than an object
    """
    try:
        body = json.loads(raw)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return {}
    return body if isinstance(body, dict) else {}


def get_params_from_context(param_name, kwargs):
    """Look up ``param_name`` in the current request context.

    Sources are tried in order: URL view arguments (``kwargs``), request
    headers, query string, form fields, then a JSON-object request body. The
    first source holding a truthy value for the name wins.

    :param param_name: name of the parameter to look up
    :param kwargs: view arguments passed to the proxy view function
    :returns: the parameter value, or ``None`` if no source provides it
    """
    if kwargs and kwargs.get(param_name):
        return kwargs[param_name]
    if request.headers.get(param_name):
        return request.headers[param_name]
    if request.args and request.args.get(param_name):
        return request.args[param_name]
    if request.form and request.form.get(param_name):
        return request.form[param_name]
    if request.data:
        # A non-JSON or non-object body simply does not provide the param,
        # instead of failing the whole request with a parse error.
        body = _load_json_object(request.data)
        if body.get(param_name):
            return body[param_name]

    # NOTE(review): the original intent (see history) was to raise
    # InvalidParameterException("Missing the required params: ...") here;
    # kept as a logged None return to preserve current caller behavior.
    logger.warning("Missing the required param: %s", param_name)
    return None