| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  | import logging | 
					
						
							|  |  |  | import collections | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from .case import _BaseTestCaseContext | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | _LoggingWatcher = collections.namedtuple("_LoggingWatcher", | 
					
						
							|  |  |  |                                          ["records", "output"]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class _CapturingHandler(logging.Handler): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     A logging handler capturing all (raw and formatted) logging output. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         logging.Handler.__init__(self) | 
					
						
							|  |  |  |         self.watcher = _LoggingWatcher([], []) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def flush(self): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def emit(self, record): | 
					
						
							|  |  |  |         self.watcher.records.append(record) | 
					
						
							|  |  |  |         msg = self.format(record) | 
					
						
							|  |  |  |         self.watcher.output.append(msg) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class _AssertLogsContext(_BaseTestCaseContext): | 
					
						
							| 
									
										
										
										
											2020-07-01 22:08:38 +01:00
										 |  |  |     """A context manager for assertLogs() and assertNoLogs() """ | 
					
						
							| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-01 22:08:38 +01:00
										 |  |  |     def __init__(self, test_case, logger_name, level, no_logs): | 
					
						
							| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  |         _BaseTestCaseContext.__init__(self, test_case) | 
					
						
							|  |  |  |         self.logger_name = logger_name | 
					
						
							|  |  |  |         if level: | 
					
						
							|  |  |  |             self.level = logging._nameToLevel.get(level, level) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.level = logging.INFO | 
					
						
							|  |  |  |         self.msg = None | 
					
						
							| 
									
										
										
										
											2020-07-01 22:08:38 +01:00
										 |  |  |         self.no_logs = no_logs | 
					
						
							| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __enter__(self): | 
					
						
							|  |  |  |         if isinstance(self.logger_name, logging.Logger): | 
					
						
							|  |  |  |             logger = self.logger = self.logger_name | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             logger = self.logger = logging.getLogger(self.logger_name) | 
					
						
							|  |  |  |         formatter = logging.Formatter(self.LOGGING_FORMAT) | 
					
						
							|  |  |  |         handler = _CapturingHandler() | 
					
						
							| 
									
										
										
										
											2020-11-02 19:25:29 +00:00
										 |  |  |         handler.setLevel(self.level) | 
					
						
							| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  |         handler.setFormatter(formatter) | 
					
						
							|  |  |  |         self.watcher = handler.watcher | 
					
						
							|  |  |  |         self.old_handlers = logger.handlers[:] | 
					
						
							|  |  |  |         self.old_level = logger.level | 
					
						
							|  |  |  |         self.old_propagate = logger.propagate | 
					
						
							|  |  |  |         logger.handlers = [handler] | 
					
						
							|  |  |  |         logger.setLevel(self.level) | 
					
						
							|  |  |  |         logger.propagate = False | 
					
						
							| 
									
										
										
										
											2020-07-01 22:08:38 +01:00
										 |  |  |         if self.no_logs: | 
					
						
							|  |  |  |             return | 
					
						
							| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  |         return handler.watcher | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __exit__(self, exc_type, exc_value, tb): | 
					
						
							|  |  |  |         self.logger.handlers = self.old_handlers | 
					
						
							|  |  |  |         self.logger.propagate = self.old_propagate | 
					
						
							|  |  |  |         self.logger.setLevel(self.old_level) | 
					
						
							| 
									
										
										
										
											2020-07-01 22:08:38 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-25 11:35:18 +03:00
										 |  |  |         if exc_type is not None: | 
					
						
							|  |  |  |             # let unexpected exceptions pass through | 
					
						
							|  |  |  |             return False | 
					
						
							| 
									
										
										
										
											2020-07-01 22:08:38 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if self.no_logs: | 
					
						
							|  |  |  |             # assertNoLogs | 
					
						
							|  |  |  |             if len(self.watcher.records) > 0: | 
					
						
							|  |  |  |                 self._raiseFailure( | 
					
						
							|  |  |  |                     "Unexpected logs found: {!r}".format( | 
					
						
							|  |  |  |                         self.watcher.output | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             # assertLogs | 
					
						
							|  |  |  |             if len(self.watcher.records) == 0: | 
					
						
							|  |  |  |                 self._raiseFailure( | 
					
						
							|  |  |  |                     "no logs of level {} or higher triggered on {}" | 
					
						
							|  |  |  |                     .format(logging.getLevelName(self.level), self.logger.name)) |