# pylint: disable=too-many-arguments, import-error
"""
:mod:`nyx_schema_compare` -- compare mysql schemas
------------------------------------------------------

Compare the base tables of two MySQL schemas. Each schema is named by an
alias that :class:`NyxGateway` resolves to a connection profile. The report
contains:

* the connection details (host, database, user) of both schemas,
* the tables common to both, missing from the primary, and additional in
  the primary,
* a unified diff of ``SHOW CREATE TABLE`` output for every common table.

Todo: Add more
"""
import sys
from difflib import unified_diff

import click
import pymysql
from sshtunnel import SSHTunnelForwarder
from tabulate import tabulate

from .nyxgateway import NyxGateway


class MySqlConnection:
    """Connection to one MySQL schema, optionally through an SSH tunnel.

    Any failure while connecting or querying is fatal: the error is printed
    to stderr and the process exits with status 1 (see :meth:`clean_up`).
    """

    def __init__(self, host="localhost", user="test", passwd="test",
                 database_name="test1", is_vagrant=False, port=3306):
        """Open the connection, tunnelling through the Vagrant VM if asked.

        :param host: MySQL host name.
        :param user: MySQL user name.
        :param passwd: MySQL password.
        :param database_name: schema to connect to and inspect.
        :param is_vagrant: when true, SSH into the VM at ``localhost:2222``
            (user/password ``vagrant``) and reach MySQL on the VM's
            ``127.0.0.1:3306`` through that tunnel.
        :param port: MySQL port used when not tunnelling.
        """
        self.database_name = database_name
        # Set up front so clean_up()/__del__ can run even if connecting fails.
        self.connection = None
        self.cursor = None
        self.tunnel = None
        if is_vagrant:
            self.tunnel = SSHTunnelForwarder(
                ('localhost', 2222),
                ssh_username="vagrant",
                ssh_password="vagrant",
                remote_bind_address=('127.0.0.1', 3306),
            )
        try:
            # NOTE(review): not read anywhere in this module; kept because it
            # is a public attribute external code may rely on.
            self.gateway = NyxGateway()
            connect_port = port
            if self.tunnel is not None:
                self.tunnel.start()
                # Connect to the local end of the tunnel; connecting straight
                # to host:3306 would bypass the tunnel entirely.
                connect_port = self.tunnel.local_bind_port
            self.connection = pymysql.connect(host=host, user=user,
                                              passwd=passwd, db=database_name,
                                              port=connect_port)
            self.cursor = self.connection.cursor()
        # If anything fails, we exit out immediately
        except Exception as error:  # pylint: disable=broad-except
            self.clean_up(error)

    def __del__(self):
        """Stop the SSH tunnel if it is still running."""
        self._close_tunnel()

    def _close_tunnel(self):
        """Close the SSH tunnel if one was created and is still alive."""
        # getattr: __del__ may run on an object whose __init__ raised early.
        tunnel = getattr(self, "tunnel", None)
        if tunnel is not None and tunnel.is_alive:
            tunnel.close()

    def get_table_names(self):
        """Return the names of the base tables (views excluded) in the schema."""
        try:
            self.cursor.execute("""
                SELECT TABLE_NAME
                FROM INFORMATION_SCHEMA.TABLES
                WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_SCHEMA=%s
            """, (self.database_name,))
            # Each row is a 1-tuple (TABLE_NAME,).
            return [row[0] for row in self.cursor.fetchall()]
        except Exception as error:  # pylint: disable=broad-except
            self.clean_up(error)

    def get_create_table_query(self, table_name):
        """Return the ``SHOW CREATE TABLE`` statement of a table as lines.

        :param table_name: unquoted table name.
        :returns: list of DDL lines without trailing newlines; an empty list
            when the server returns no DDL text.
        """
        # Identifiers cannot be bound as query parameters, so backtick-quote
        # the name (doubling embedded backticks) rather than splicing it raw.
        quoted_name = "`{}`".format(str(table_name).replace("`", "``"))
        try:
            self.cursor.execute("SHOW CREATE TABLE {}".format(quoted_name))
            row = self.cursor.fetchone()
        except Exception as error:  # pylint: disable=broad-except
            # pymysql errors carry (errno, message): report just the message,
            # but do not assume that shape for other exception types.
            self.clean_up(error.args[1] if len(error.args) > 1 else error)
        # Row layout is (table name, CREATE TABLE statement).
        ddl = row[1] if row else None
        return ddl.split("\n") if isinstance(ddl, str) and ddl else []

    def clean_up(self, error):
        """Report ``error`` on stderr, release resources and exit with status 1.

        Safe to call before the connection or cursor exist.
        """
        print(error, file=sys.stderr)
        for resource in (getattr(self, "cursor", None),
                         getattr(self, "connection", None)):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:  # pylint: disable=broad-except
                # Best effort: we are exiting because of ``error`` already;
                # a failure while closing must not mask it.
                pass
        self._close_tunnel()
        sys.exit(1)


class Utilities:
    """Set-style helpers over lists of table names."""

    @classmethod
    def get_diff_from_lists(cls, primary, comparator):
        """Return the items of ``comparator`` absent from ``primary``, sorted.

        Sorting keeps the report order stable between runs, since set
        iteration order of strings varies with hash randomization.
        """
        return sorted(set(comparator) - set(primary))

    @classmethod
    def get_common_from_lists(cls, comparator_a, comparator_b):
        """Return the items present in both lists, sorted."""
        return sorted(set(comparator_a) & set(comparator_b))


class TableInfo:
    """Connection profile, open connection and base-table list for an alias."""

    def __init__(self, alias_name, gateway=None):
        """Resolve ``alias_name`` through ``gateway`` and load its tables.

        :param alias_name: schema alias understood by the gateway.
        :param gateway: gateway used to look up the profile; a new
            :class:`NyxGateway` is created when omitted.
        """
        # A ``NyxGateway()`` default would be built once at import time and
        # shared by every instance; create one per call instead.
        self.gateway = gateway if gateway is not None else NyxGateway()
        self.profile = self.gateway.get_schema_profile(alias_name)
        self.connection = MySqlConnection(**self.profile)
        self.tables = self.connection.get_table_names()


class Main:
    """Compare the schema behind a primary alias with a comparator alias."""

    # Keys of ``table_comparison_data``, relative to the primary schema.
    ADDITIONAL = 'additional'  # tables only in the primary
    MISSING = 'missing'        # tables only in the comparator
    COMMON = 'common'          # tables in both

    def __init__(self, primary_alias, comparator_alias, gateway=None):
        """Connect to both schemas and compute their table differences.

        :param primary_alias: alias of the reference schema.
        :param comparator_alias: alias of the schema compared against it.
        :param gateway: gateway used to resolve both aliases; a new
            :class:`NyxGateway` is created when omitted.
        """
        self.gateway = gateway if gateway is not None else NyxGateway()
        self.primary_alias = primary_alias
        self.comparator_alias = comparator_alias
        self.primary_table = TableInfo(primary_alias, self.gateway)
        self.comparator_table = TableInfo(comparator_alias, self.gateway)
        self.table_comparison_data = self.get_differences_in_schema_tables(
            self.primary_table.tables, self.comparator_table.tables)

    @staticmethod
    def _profile_summary(profile):
        """Return the host, database and user fields of a connection profile."""
        return [profile["host"], profile["database_name"], profile["user"]]

    def print_results(self):
        """Print connection details, the table comparison and DDL diffs."""
        print(tabulate({
            "Primary": self._profile_summary(self.primary_table.profile),
            "Comparator": self._profile_summary(self.comparator_table.profile),
        }, headers="keys", tablefmt="grid"), "\n")
        try:
            # Print table common, missing, and additional
            print(tabulate(self.table_comparison_data, headers="keys",
                           tablefmt="grid"), "\n")
            # Print diffs for each table create sql query
            for table in self.table_comparison_data[self.COMMON]:
                diff = unified_diff(
                    self.primary_table.connection.get_create_table_query(table),
                    self.comparator_table.connection.get_create_table_query(
                        table),
                    "Primary: {}:{}".format(self.primary_alias, table),
                    "Comparator: {}:{}".format(self.comparator_alias, table),
                    # The DDL lines carry no newline; keep the ---/+++/@@
                    # header lines newline-free too so print() spaces evenly.
                    lineterm="")
                for line in diff:
                    print(line)
        except Exception as error:  # pylint: disable=broad-except
            print(error, file=sys.stderr)

    def get_differences_in_schema_tables(self, primary, comparator):
        """Classify table names as common, missing or additional.

        :param primary: table names of the primary schema.
        :param comparator: table names of the comparator schema.
        :returns: dict keyed by :attr:`COMMON`, :attr:`MISSING` (only in
            ``comparator``) and :attr:`ADDITIONAL` (only in ``primary``).
        """
        return {
            self.COMMON: Utilities.get_common_from_lists(primary, comparator),
            self.MISSING: Utilities.get_diff_from_lists(primary, comparator),
            self.ADDITIONAL: Utilities.get_diff_from_lists(comparator, primary),
        }


@click.command()
@click.argument('primary_alias')
@click.argument('comparator_alias')
def nyx_schema_compare(primary_alias, comparator_alias):
    """Compare the MySQL schemas behind PRIMARY_ALIAS and COMPARATOR_ALIAS."""
    main = Main(primary_alias, comparator_alias)
    main.print_results()


if __name__ == "__main__":
    nyx_schema_compare()