Refactored, changed to httpx and async instead of requests and threads
This commit is contained in:
parent 56b850dc4f
commit 2d2afe7930
5 changed files with 84 additions and 90 deletions
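
As a rough, self-contained sketch of the concurrency change this commit makes (not code from this repository; fetch, fetch_async and the URLs are placeholders for the real per-URL work), the thread-pool fan-out is replaced by an asyncio.gather fan-out:

import asyncio
from concurrent.futures import ThreadPoolExecutor

urls = ["https://example.org", "https://example.com"]

def fetch(url):                      # placeholder for the old per-URL worker
    return url.upper()

async def fetch_async(url):          # placeholder for the new per-URL coroutine
    await asyncio.sleep(0)           # yield to the event loop, as a real request would
    return url.upper()

# before: run the workers on a thread pool
with ThreadPoolExecutor() as tpe:
    thread_results = list(tpe.map(fetch, urls))

# after: one event loop, one coroutine per URL, gathered concurrently
async def run_all():
    return await asyncio.gather(*(fetch_async(u) for u in urls))

async_results = asyncio.run(run_all())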
@@ -1,27 +1,30 @@
+import asyncio
-from concurrent.futures import ThreadPoolExecutor
 from html.parser import HTMLParser
 from random import sample
-from requests import get as request_get
 from sys import stderr
 from urllib.parse import urlparse
 
+from httpx import AsyncClient
+
 from .linkmap import LinkMap
 
 
-class _HTMLExternalLinkFinder(HTMLParser):
+class HTMLLinkFinder(HTMLParser):
 
     links = []
 
     def handle_starttag(self, tag, attrs):
-        if tag == "a":
-            for a in attrs:
-                attr, val = a
-                if attr == "href":
-                    if val.startswith("https://") or val.startswith("http://"):
-                        if not val in self.links:
-                            self.links.append(val)
+        if tag != "a":
+            return
+        for a in attrs:
+            attr, val = a
+            if attr != "href":
+                continue
+            if val.startswith("https://") or val.startswith("http://"):
+                if not val in self.links:
+                    self.links.append(val)
 
     def get_links(self, input_html:str):
         self.feed(input_html)
@@ -35,80 +38,67 @@ class LinkMapFromSitelinksGenerator:
     generated_linkmap = LinkMap()
     max_links_per_site = 3
     max_depth = 3
-    max_threads = 4
     enable_log = False
 
     def log(self, something):
         if self.enable_log:
             print(something, file=stderr)
 
-    def _get_html(self, url:str) -> str:
-        html_content = ""
-        # receive up to self.site_request_max_len bytes after a maximum of self.site_request_timeout seconds
-        self.log("----" + url)
-        response = request_get(url, stream=True, timeout=self.site_request_timeout)
-        response.raise_for_status()
+    async def _get_html(self, url:str, client:AsyncClient) -> str:
+        content = bytearray()
         content_size = 0
-        content_chunks = []
-        for chunk in response.iter_content(1024, decode_unicode=True):
-            content_size += len(chunk)
-            if content_size > self.site_request_max_len:
-                self.log(f"Maximum content length exceeded! received: {content_size} (maximum: {self.site_request_max_len})")
-                break
-            else:
-                content_chunks.append(chunk)
-        html_content = "".join(content_chunks)
+        # receive up to self.site_request_max_len bytes after
+        # a maximum of self.site_request_timeout seconds
+        self.log(f"Request: {url}")
+        async with client.stream(
+            "GET",
+            url,
+            timeout=self.site_request_timeout,
+            follow_redirects=True
+        ) as stream:
+            async for chunk in stream.aiter_bytes(1024):
+                content_size += len(chunk)
+                if content_size > self.site_request_max_len:
+                    self.log(f"Maximum content length exceeded! received: {content_size} (maximum: {self.site_request_max_len})")
+                    break
+                else:
+                    content.extend(chunk)
+        # decode
+        try:
+            html_content = content.decode()
+        except UnicodeDecodeError:
+            self.log(f"Couldn't decode {url}")
+            html_content = ""
         return html_content
 
-    def _get_linked_sites_thread(self, urls:list):
-        def _get_links(url:str):
-            sites = []
-            try:
-                html = self._get_html(url)
-                found_links = _HTMLExternalLinkFinder().get_links(html)
-                found_links = sample(found_links, min(self.max_links_per_site, len(found_links)))
-                self.log("\n".join(found_links))
-                for l in found_links:
-                    if l != None:
-                        sites.append(l)
-            except KeyboardInterrupt:
-                exit("KeyboardInterrupt")
-            except Exception as e:
-                self.log("An exception occcured while trying to get links from '" + url + "': ")
-                self.log(e)
-            return sites
-        links = {}
-        for url in urls:
-            links[url] = _get_links(url)
-        return links
+    async def _get_linked_sites_coro(self, url, client:AsyncClient):
+        linked_sites = []
+        try:
+            html = await self._get_html(url, client)
+            found_links = HTMLLinkFinder().get_links(html)
+            found_links = sample(found_links, min(self.max_links_per_site, len(found_links)))
+            for l in found_links:
+                self.log(f"Found {l}")
+                if l != None:
+                    linked_sites.append(l)
+        except KeyboardInterrupt:
+            exit("KeyboardInterrupt")
+        except Exception as e:
+            self.log("An exception occcured while trying to get links from '" + url + "': ")
+            self.log(e)
+        return url, linked_sites
 
-    def _get_linked_sites(self, urls:list):
-        # split urls into self.max_threads chunks
-        urlchunks = []
-        chunk_size = max(int(len(urls) / self.max_threads), 1)
-        for i in range(self.max_threads):
-            start = i*chunk_size
-            end = (i*chunk_size)+chunk_size
-            new_chunk = urls[start:end]
-            if len(new_chunk) > 0:
-                urlchunks.append(new_chunk)
-        results = []
-        # threads
-        with ThreadPoolExecutor() as tpe:
-            self.log(f"--Using {len(urlchunks)} concurrent connections...")
-            futures = [tpe.submit(self._get_linked_sites_thread, chunk) for chunk in urlchunks]
-            for f in futures:
-                # wait for results
-                results.append(f.result())
-        results_combined = {}
-        for result_chunk in results:
-            for url in result_chunk:
-                results_combined[url] = result_chunk[url]
-        return results_combined
+    async def _get_linked_sites(self, urls:list, client:AsyncClient):
+        # get results
+        results = await asyncio.gather(*[self._get_linked_sites_coro(url, client) for url in urls])
+        results_as_dict = {}
+        for url, links in results:
+            results_as_dict[url] = links
+        return results_as_dict
 
-    def _generate_linkmap(self, start_urls:list, _current_depth:int):
+    async def _generate_linkmap(self, start_urls:list, _current_depth:int, client:AsyncClient):
         linkdict = {}
-        linked_sites = self._get_linked_sites(start_urls)
+        linked_sites = await self._get_linked_sites(start_urls, client)
         for url in linked_sites:
             linkdict[url] = {}
             self.generated_linkmap.add_link(url)
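
The new _get_html streams the body and stops once a byte cap is exceeded. A self-contained sketch of that pattern, with the URL, cap and timeout as placeholder values standing in for the crawled URL, self.site_request_max_len and self.site_request_timeout:

import asyncio
from httpx import AsyncClient

async def fetch_capped(url: str, max_len: int = 65536, timeout: float = 5.0) -> str:
    content = bytearray()
    async with AsyncClient() as client:
        async with client.stream("GET", url, timeout=timeout, follow_redirects=True) as response:
            async for chunk in response.aiter_bytes(1024):
                if len(content) + len(chunk) > max_len:   # stop once the cap would be exceeded
                    break
                content.extend(chunk)
    try:
        return content.decode()
    except UnicodeDecodeError:                            # non-text or unexpected encoding
        return ""

print(asyncio.run(fetch_capped("https://example.org"))[:80])

Note that httpx does not follow redirects by default, so follow_redirects=True is passed explicitly, whereas the old requests-based version followed them implicitly.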
@@ -118,13 +108,14 @@ class LinkMapFromSitelinksGenerator:
                 self.generated_linkmap.add_link_connection(url, l)#
         if _current_depth < self.max_depth:
             for url in linkdict:
-                linkdict[url] = self._generate_linkmap(list(linkdict[url]), _current_depth + 1)
+                linkdict[url] = await self._generate_linkmap(list(linkdict[url]), _current_depth + 1, client)
 
-    def generate(self, start_url:str, max_depth:int=3, max_links_per_site:int=3):
+    async def generate(self, start_url:str, max_depth:int=3, max_links_per_site:int=3):
         self.generated_linkmap = LinkMap()
         self.max_links_per_site = max_links_per_site
         self.max_depth = max_depth
-        self._generate_linkmap([start_url], 1)
+        async with AsyncClient() as client:
+            await self._generate_linkmap([start_url], 1, client)
 
     def get_linkmap(self) -> LinkMap:
         return self.generated_linkmap