Files
mosec-pip-plugin/mosec/pip_resolve.py
2020-07-28 15:42:47 +08:00

266 lines
8.9 KiB
Python

import argparse
import json
import logging
import os
import re
import ssl
import sys
import urllib.error
import urllib.request
from operator import attrgetter

try:
    from collections import OrderedDict
except ImportError:
    from ordereddict import OrderedDict

try:
    import pkg_resources
except ImportError:
    # try using the version vendored by pip
    try:
        import pip._vendor.pkg_resources as pkg_resources
    except ImportError:
        raise ImportError(
            "Could not import pkg_resources; please install setuptools or pip.")

from mosec import mosec_log_helper
from mosec import setup_file
from mosec import utils
from mosec.requirement_file_parser import get_requirements_list
from mosec.requirement_dist import ReqDist
# Module-level logger used by every function in this file; main() switches it
# to logging.DEBUG when --debug is passed.
log = mosec_log_helper.Logger(name="mosec")
def create_deps_tree(
        dist_tree,
        top_level_requirements,
        req_file_path,
        allow_missing=False,
        only_provenance=False
):
    """Create dist dependencies tree

    Top-level requirements are matched to dists by their normalized names
    (PEP 503: case-insensitive, runs of ``-``, ``_`` and ``.`` are
    equivalent), so ``Foo_Bar``, ``foo.bar`` and ``foo-bar`` all select the
    same dist.

    :param dict dist_tree: the installed dists tree; maps each dist (attributes
        read: ``key``, ``project_name``, ``version``) to the list of dists it
        requires (same attributes)
    :param list top_level_requirements: list of required dists; only their
        ``name`` attribute is read
    :param str req_file_path: path to the dependencies file (e.g. requirements.txt).
        If it is a ``setup.py`` the root name/version are parsed from it,
        otherwise the root is named after the file's parent directory.
    :param bool allow_missing: log (instead of exiting on) dependencies that
        have no entry in ``dist_tree``
    :param bool only_provenance: only care provenance (top-level) dependencies;
        their nodes get no children
    :return: root node ``{'name', 'version', 'from', 'dependencies'}``, where
        ``dependencies`` maps a project name to a child node of the same shape
        and ``from`` is the list of ``name@version`` entries from the root
    :rtype: dict
    :raises SystemExit: if a dependency is missing and ``allow_missing`` is False
    """
    DEPENDENCIES = 'dependencies'
    VERSION = 'version'
    NAME = 'name'
    DIR_VERSION = '1.0.0'
    FROM = 'from'

    def _normalize_name(name):
        # PEP 503 normalization, so requirement spellings match dist keys
        # (pkg_resources keys replace "_" with "-", requirement names may not).
        return re.sub(r"[-_.]+", "-", name).lower()

    # Dists sorted by key, each with its requirements sorted by key, so the
    # generated tree has a deterministic order.
    tree = OrderedDict(
        sorted(
            [(k, sorted(v, key=attrgetter('key'))) for k, v in dist_tree.items()],
            key=lambda kv: kv[0].key
        )
    )
    key_tree = {k.key: v for k, v in tree.items()}
    top_level_req_names = {_normalize_name(p.name) for p in top_level_requirements}
    top_level_req_dists = [p for p in tree if _normalize_name(p.key) in top_level_req_names]

    def _create_children_recursive(root_package, ancestors):
        # Depth-first: attach the requirements of root_package as child nodes.
        # `ancestors` holds the lower-cased names on the current path and is
        # used to break dependency cycles.
        root_name = root_package[NAME]
        if root_name.lower() not in key_tree:
            msg = 'Required packages missing: ' + root_name
            if allow_missing:
                log.error(msg)
                return
            else:
                sys.exit(msg)
        # Copy so sibling branches do not see each other's path.
        ancestors = ancestors.copy()
        ancestors.add(root_name.lower())
        children_dists = key_tree[root_name.lower()]
        for child_dist in children_dists:
            if child_dist.key in ancestors:
                continue
            child_node = _create_tree_node(child_dist, root_package)
            _create_children_recursive(child_node, ancestors)
            root_package[DEPENDENCIES][child_node[NAME]] = child_node
        return root_package

    def _create_root():
        # Root node for the scanned project itself.
        name, version = None, None
        if os.path.basename(req_file_path) == 'setup.py':
            # Python source files default to UTF-8; do not depend on the locale.
            with open(req_file_path, "r", encoding="utf-8") as setup_py_file:
                name, version = setup_file.parse_name_and_version(setup_py_file.read())
        root = {
            NAME: name or os.path.basename(os.path.dirname(os.path.abspath(req_file_path))),
            VERSION: version or DIR_VERSION,
            DEPENDENCIES: {}
        }
        root[FROM] = [root[NAME] + '@' + root[VERSION]]
        return root

    def _create_tree_node(dist_node, parent):
        # Child node whose `from` path extends the parent's path by one entry.
        version = dist_node.version
        if isinstance(version, tuple):
            version = '.'.join(map(str, version))
        return {
            NAME: dist_node.project_name,
            VERSION: version,
            FROM: parent[FROM] + [dist_node.project_name + '@' + version],
            DEPENDENCIES: {}
        }

    tree_root = _create_root()
    for dist in top_level_req_dists:
        tree_node = _create_tree_node(dist, tree_root)
        if only_provenance:
            tree_root[DEPENDENCIES][tree_node[NAME]] = tree_node
        else:
            tree_root[DEPENDENCIES][tree_node[NAME]] = _create_children_recursive(tree_node, set())
    return tree_root
def create_dependencies_tree_by_req_file(
        requirements_file,
        allow_missing=False,
        only_provenance=False
):
    """Create dist dependencies tree from file

    Reads the requirements listed in ``requirements_file``, checks them
    against the distributions installed in the current environment
    (``pkg_resources.working_set``) and builds the tree with
    ``create_deps_tree``.

    :param str requirements_file: path to the dependencies file (e.g. requirements.txt)
    :param bool allow_missing: log requirements that are not installed
        instead of exiting
    :param bool only_provenance: only care provenance dependencies
    :rtype: dict
    :raises SystemExit: if a requirement is not installed and
        ``allow_missing`` is False
    """
    # Every distribution in the current environment, keyed by its pkg_resources key.
    dists_dict = {dist.key: dist for dist in pkg_resources.working_set}
    # Each installed dist -> its direct requirements, each paired with the
    # installed dist of the same key (None when that one is not installed).
    dists_tree = {
        dist: [ReqDist(req, dists_dict.get(req.key)) for req in dist.requires()]
        for dist in dists_dict.values()
    }
    required = get_requirements_list(requirements_file)
    # A set keeps each membership test below O(1) instead of a list scan.
    installed = {utils.canonicalize_dist_name(key) for key in dists_dict}
    top_level_requirements = []
    missing_package_names = []
    for r in required:
        if utils.canonicalize_dist_name(r.name) in installed:
            top_level_requirements.append(r)
        else:
            missing_package_names.append(r.name)
    if missing_package_names:
        msg = 'Required packages missing: ' + (', '.join(missing_package_names))
        if allow_missing:
            log.error(msg)
        else:
            sys.exit(msg)
    return create_deps_tree(
        dists_tree, top_level_requirements, requirements_file, allow_missing, only_provenance)
def render_response(response_json):
    """Print the vulnerability report returned by the API.

    If ``ok`` is truthy, logs a success line. Otherwise, if the response
    lists ``vulnerabilities``, prints each one followed by a summary line.
    Prints nothing when neither applies.

    :param dict response_json: decoded API response; keys read: ``ok``,
        ``dependencyCount``, ``vulnerabilities`` (list of dicts read via
        ``severity``, ``title``, ``cve``, ``packageName``, ``version``,
        ``from``, ``target_version``)
    """
    def _print_single_vuln(vuln):
        # One vulnerability: headline, dependency path (if any), fix version (if any).
        log.error("{} severity vulnerability ({} - {}) found on {}@{}".format(
            vuln.get('severity'),
            vuln.get('title', ''),
            vuln.get('cve', ''),
            vuln.get('packageName'),
            vuln.get('version'))
        )
        from_arr = vuln.get('from', None)
        if from_arr:
            # Dependency path joined with " > " (presumably the name@version
            # chain sent in the request tree; confirm against the API).
            print("- from: {}".format(" > ".join(from_arr)))
        if vuln.get('target_version', []):
            log.info("! Fix version {}".format(vuln.get('target_version')))
        print("")

    if response_json.get('ok', False):
        log.info("✓ Tested {} dependencies for known vulnerabilities, no vulnerable paths found."
                 .format(response_json.get('dependencyCount', 0)))
    elif response_json.get('vulnerabilities', None):
        vulns = response_json.get('vulnerabilities')
        for vuln in vulns:
            _print_single_vuln(vuln)
        log.warn("Tested {} dependencies for known vulnerabilities, found {} vulnerable paths."
                 .format(response_json.get('dependencyCount', 0), len(vulns)))
def run(args):
    """Build the dependency tree, POST it to the API and render the verdict.

    :param argparse.Namespace args: parsed CLI arguments; reads
        ``requirements``, ``allow_missing``, ``only_provenance``, ``level``
        and ``endpoint``
    :return: 1 when the API response is not ``ok``, otherwise None
    :raises Exception: "Network Error: ..." if the request fails (HTTP error,
        connection/DNS/TLS failure or timeout), or "API return data format
        error." if the body is not a UTF-8 encoded JSON object
    """
    deps_tree = create_dependencies_tree_by_req_file(
        args.requirements,
        allow_missing=args.allow_missing,
        only_provenance=args.only_provenance,
    )
    deps_tree['severityLevel'] = args.level
    deps_tree['type'] = 'pip'
    deps_tree['language'] = 'python'
    log.debug(json.dumps(deps_tree, indent=2))

    # NOTE(review): TLS certificate and hostname verification are disabled,
    # so anyone on the network path can read or forge the report, including
    # the verdict that decides the exit status. Kept for compatibility
    # (presumably for self-signed endpoints); consider making it opt-in.
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    req = urllib.request.Request(
        method='POST',
        url=args.endpoint,
        headers={
            'Content-Type': 'application/json'
        },
        data=json.dumps(deps_tree).encode('utf-8')
    )
    try:
        # `with` closes the connection even if reading the body fails.
        with urllib.request.urlopen(req, timeout=15, context=ctx) as response:
            body = response.read()
    except OSError as e:
        # HTTPError and URLError subclass OSError, as do socket timeouts and
        # TLS errors, so every transport failure is reported the same way.
        raise Exception("Network Error: {}".format(e)) from e

    try:
        response_json = json.loads(body.decode('utf-8'))
    except ValueError as e:
        # Covers both json.JSONDecodeError and UnicodeDecodeError.
        raise Exception("API return data format error.") from e
    if not isinstance(response_json, dict):
        raise Exception("API return data format error.")

    render_response(response_json)
    if not response_json.get('ok', False):
        return 1
    return None
def main():
    """Command-line entry point: parse arguments, run the check, set exit status.

    Exits with status 1 and the message "Found Vulnerable!" when the API
    response is not ok, unless ``--no-except`` is given.
    """
    parser = argparse.ArgumentParser()
    # Dependency file (requirements.txt or Pipfile).
    parser.add_argument("requirements",
                        help="依赖文件 (requirements.txt 或 Pipfile)")
    # Reporting API URL the dependency tree is POSTed to.
    parser.add_argument("--endpoint",
                        action="store",
                        required=True,
                        help="上报API")
    # Ignore dependencies that are not installed.
    parser.add_argument("--allow-missing",
                        action="store_true",
                        help="忽略未安装的依赖")
    # Only check direct dependencies.
    parser.add_argument("--only-provenance",
                        action="store_true",
                        help="仅检查直接依赖")
    # Severity level [High|Medium|Low], sent to the API as `severityLevel`.
    parser.add_argument("--level",
                        action="store",
                        default="High",
                        help="威胁等级 [High|Medium|Low]. default: High")
    # Do not fail when vulnerabilities are found.
    parser.add_argument("--no-except",
                        action="store_true",
                        default=False,
                        help="发现漏洞不抛出异常")
    parser.add_argument("--debug",
                        action="store_true",
                        default=False)
    args = parser.parse_args()
    if args.debug:
        log.set_log_level(logging.DEBUG)
    status = run(args)
    if status == 1 and not args.no_except:
        # SystemExit, like the module's other sys.exit() calls: exit status 1
        # with the message on stderr and no traceback. It is still a
        # BaseException subclass, so `except BaseException` callers keep working.
        sys.exit("Found Vulnerable!")