|  | @@ -6,11 +6,12 @@ Author: dhb52 (https://gitee.com/dhb52)
 | 
											
												
													
														|  |  pip install simple-ddl-parser
 |  |  pip install simple-ddl-parser
 | 
											
												
													
														|  |  """
 |  |  """
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +import argparse
 | 
											
												
													
														|  |  import pathlib
 |  |  import pathlib
 | 
											
												
													
														|  |  import re
 |  |  import re
 | 
											
												
													
														|  |  import time
 |  |  import time
 | 
											
												
													
														|  |  from abc import ABC, abstractmethod
 |  |  from abc import ABC, abstractmethod
 | 
											
												
													
														|  | -from typing import Dict, Tuple
 |  | 
 | 
											
												
													
														|  | 
 |  | +from typing import Dict, Generator, Optional, Tuple, Union
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  from simple_ddl_parser import DDLParser
 |  |  from simple_ddl_parser import DDLParser
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -60,12 +61,12 @@ class Convertor(ABC):
 | 
											
												
													
														|  |          self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)
 |  |          self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      @abstractmethod
 |  |      @abstractmethod
 | 
											
												
													
														|  | -    def translate_type(self, type: str, size: None | int | Tuple[int]) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]) -> str:
 | 
											
												
													
														|  |          """字段类型转换
 |  |          """字段类型转换
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          Args:
 |  |          Args:
 | 
											
												
													
														|  |              type (str): 字段类型
 |  |              type (str): 字段类型
 | 
											
												
													
														|  | -            size (None | int | Tuple[int]): 字段长度描述, 如varchar(255), decimal(10,2)
 |  | 
 | 
											
												
													
														|  | 
 |  | +            size (Optional[Union[int, Tuple[int]]]): 字段长度描述, 如varchar(255), decimal(10,2)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          Returns:
 |  |          Returns:
 | 
											
												
													
														|  |              str: 类型定义
 |  |              str: 类型定义
 | 
											
										
											
												
													
														|  | @@ -97,7 +98,7 @@ class Convertor(ABC):
 | 
											
												
													
														|  |          pass
 |  |          pass
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      @abstractmethod
 |  |      @abstractmethod
 | 
											
												
													
														|  | -    def gen_index(self, table_ddl: Dict) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_index(self, ddl: Dict) -> str:
 | 
											
												
													
														|  |          """生成索引定义
 |  |          """生成索引定义
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          Args:
 |  |          Args:
 | 
											
										
											
												
													
														|  | @@ -133,6 +134,55 @@ class Convertor(ABC):
 | 
											
												
													
														|  |          """
 |  |          """
 | 
											
												
													
														|  |          pass
 |  |          pass
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +    @staticmethod
 | 
											
												
													
														|  | 
 |  | +    def inserts(table_name: str, script_content: str) -> Generator:
 | 
											
												
													
														|  | 
 |  | +        PREFIX = f"INSERT INTO `{table_name}`"
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # 收集 `table_name` 对应的 insert 语句
 | 
											
												
													
														|  | 
 |  | +        for line in script_content.split("\n"):
 | 
											
												
													
														|  | 
 |  | +            if line.startswith(PREFIX):
 | 
											
												
													
														|  | 
 |  | +                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
 | 
											
												
													
														|  | 
 |  | +                head = head.strip().replace("`", "").lower()
 | 
											
												
													
														|  | 
 |  | +                tail = tail.strip().replace(r"\"", '"')
 | 
											
												
													
														|  | 
 |  | +                # tail = tail.replace("b'0'", "'0'").replace("b'1'", "'1'")
 | 
											
												
													
														|  | 
 |  | +                yield f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    @staticmethod
 | 
											
												
													
														|  | 
 |  | +    def index(ddl: Dict) -> Generator:
 | 
											
												
													
														|  | 
 |  | +        """生成索引定义
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        Args:
 | 
											
												
													
														|  | 
 |  | +            ddl (Dict): 表DDL
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        Yields:
 | 
											
												
													
														|  | 
 |  | +            Generator[str]: create index 语句
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        def generate_columns(columns):
 | 
											
												
													
														|  | 
 |  | +            keys = [
 | 
											
												
													
														|  | 
 |  | +                f"{col['name'].lower()}{' ' + col['order'].lower() if col['order'] != 'ASC' else ''}"
 | 
											
												
													
														|  | 
 |  | +                for col in columns[0]
 | 
											
												
													
														|  | 
 |  | +            ]
 | 
											
												
													
														|  | 
 |  | +            return ", ".join(keys)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        for no, index in enumerate(ddl["index"], 1):
 | 
											
												
													
														|  | 
 |  | +            columns = generate_columns(index["columns"])
 | 
											
												
													
														|  | 
 |  | +            table_name = ddl["table_name"].lower()
 | 
											
												
													
														|  | 
 |  | +            yield f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})"
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    @staticmethod
 | 
											
												
													
														|  | 
 |  | +    def filed_comments(table_sql: str) -> Generator:
 | 
											
												
													
														|  | 
 |  | +        for line in table_sql.split("\n"):
 | 
											
												
													
														|  | 
 |  | +            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
 | 
											
												
													
														|  | 
 |  | +            if match:
 | 
											
												
													
														|  | 
 |  | +                field = match.group(1)
 | 
											
												
													
														|  | 
 |  | +                comment_string = match.group(2).replace("\\n", "\n")
 | 
											
												
													
														|  | 
 |  | +                yield field, comment_string
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def table_comment(self, table_sql: str) -> str:
 | 
											
												
													
														|  | 
 |  | +        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
 | 
											
												
													
														|  | 
 |  | +        return match.group(1) if match else None
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      def print(self):
 |  |      def print(self):
 | 
											
												
													
														|  |          """打印转换后的sql脚本到终端"""
 |  |          """打印转换后的sql脚本到终端"""
 | 
											
												
													
														|  |          print(
 |  |          print(
 | 
											
										
											
												
													
														|  | @@ -192,7 +242,7 @@ class PostgreSQLConvertor(Convertor):
 | 
											
												
													
														|  |      def __init__(self, src):
 |  |      def __init__(self, src):
 | 
											
												
													
														|  |          super().__init__(src, "PostgreSQL")
 |  |          super().__init__(src, "PostgreSQL")
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def translate_type(self, type, size):
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
 | 
											
												
													
														|  |          """类型转换"""
 |  |          """类型转换"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          type = type.lower()
 |  |          type = type.lower()
 | 
											
										
											
												
													
														|  | @@ -234,27 +284,30 @@ class PostgreSQLConvertor(Convertor):
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          table_name = ddl["table_name"].lower()
 |  |          table_name = ddl["table_name"].lower()
 | 
											
												
													
														|  |          columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
 |  |          columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
 | 
											
												
													
														|  | 
 |  | +        filed_def_list = ",\n  ".join(columns)
 | 
											
												
													
														|  |          script = f"""-- ----------------------------
 |  |          script = f"""-- ----------------------------
 | 
											
												
													
														|  |  -- Table structure for {table_name}
 |  |  -- Table structure for {table_name}
 | 
											
												
													
														|  |  -- ----------------------------
 |  |  -- ----------------------------
 | 
											
												
													
														|  |  DROP TABLE IF EXISTS {table_name};
 |  |  DROP TABLE IF EXISTS {table_name};
 | 
											
												
													
														|  |  CREATE TABLE {table_name} (
 |  |  CREATE TABLE {table_name} (
 | 
											
												
													
														|  | -    {',\n  '.join(columns)}
 |  | 
 | 
											
												
													
														|  | 
 |  | +    {filed_def_list}
 | 
											
												
													
														|  |  );"""
 |  |  );"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          return script
 |  |          return script
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_comment(self, table_sql, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_index(self, ddl: Dict) -> str:
 | 
											
												
													
														|  | 
 |  | +        return "\n".join(f"{script};" for script in self.index(ddl))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def gen_comment(self, table_sql: str, table_name: str) -> str:
 | 
											
												
													
														|  |          """生成字段及表的注释"""
 |  |          """生成字段及表的注释"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          script = ""
 |  |          script = ""
 | 
											
												
													
														|  | -        for line in table_sql.split("\n"):
 |  | 
 | 
											
												
													
														|  | -            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
 |  | 
 | 
											
												
													
														|  | -            if match:
 |  | 
 | 
											
												
													
														|  | -                script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n"
 |  | 
 | 
											
												
													
														|  | 
 |  | +        for field, comment_string in self.filed_comments(table_sql):
 | 
											
												
													
														|  | 
 |  | +            script += (
 | 
											
												
													
														|  | 
 |  | +                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
 | 
											
												
													
														|  | 
 |  | +            )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
 |  | 
 | 
											
												
													
														|  | -        table_comment = match.group(1) if match else None
 |  | 
 | 
											
												
													
														|  | 
 |  | +        table_comment = self.table_comment(table_sql)
 | 
											
												
													
														|  |          if table_comment:
 |  |          if table_comment:
 | 
											
												
													
														|  |              script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
 |  |              script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -264,53 +317,21 @@ CREATE TABLE {table_name} (
 | 
											
												
													
														|  |          """生成主键定义"""
 |  |          """生成主键定义"""
 | 
											
												
													
														|  |          return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
 |  |          return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_index(self, ddl) -> str:
 |  | 
 | 
											
												
													
														|  | -        """生成 index"""
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        def generate_columns(columns):
 |  | 
 | 
											
												
													
														|  | -            keys = [
 |  | 
 | 
											
												
													
														|  | -                f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
 |  | 
 | 
											
												
													
														|  | -                for col in columns[0]
 |  | 
 | 
											
												
													
														|  | -            ]
 |  | 
 | 
											
												
													
														|  | -            return ", ".join(keys)
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        script = ""
 |  | 
 | 
											
												
													
														|  | -        for no, index in enumerate(ddl["index"], 1):
 |  | 
 | 
											
												
													
														|  | -            columns = generate_columns(index["columns"])
 |  | 
 | 
											
												
													
														|  | -            table_name = ddl["table_name"].lower()
 |  | 
 | 
											
												
													
														|  | -            script += (
 |  | 
 | 
											
												
													
														|  | -                f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
 |  | 
 | 
											
												
													
														|  | -            )
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        return script
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -    def gen_insert(self, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_insert(self, table_name: str) -> str:
 | 
											
												
													
														|  |          """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
 |  |          """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        PREFIX = f"INSERT INTO `{table_name}`"
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        # 收集 `table_name` 对应的 insert 语句
 |  | 
 | 
											
												
													
														|  | -        inserts = []
 |  | 
 | 
											
												
													
														|  | -        for line in self.content.split("\n"):
 |  | 
 | 
											
												
													
														|  | -            if line.startswith(PREFIX):
 |  | 
 | 
											
												
													
														|  | -                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
 |  | 
 | 
											
												
													
														|  | -                head = head.strip().replace("`", "").lower()
 |  | 
 | 
											
												
													
														|  | -                tail = tail.strip().replace(r"\"", '"')
 |  | 
 | 
											
												
													
														|  | -                script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
 |  | 
 | 
											
												
													
														|  | -                # bit(1)数据转换
 |  | 
 | 
											
												
													
														|  | -                script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
 |  | 
 | 
											
												
													
														|  | -                inserts.append(script)
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | 
 |  | +        inserts = list(Convertor.inserts(table_name, self.content))
 | 
											
												
													
														|  |          ## 生成 insert 脚本
 |  |          ## 生成 insert 脚本
 | 
											
												
													
														|  |          script = ""
 |  |          script = ""
 | 
											
												
													
														|  |          last_id = 0
 |  |          last_id = 0
 | 
											
												
													
														|  |          if inserts:
 |  |          if inserts:
 | 
											
												
													
														|  | 
 |  | +            inserts_lines = "\n".join(inserts)
 | 
											
												
													
														|  |              script += f"""\n\n-- ----------------------------
 |  |              script += f"""\n\n-- ----------------------------
 | 
											
												
													
														|  |  -- Records of {table_name.lower()}
 |  |  -- Records of {table_name.lower()}
 | 
											
												
													
														|  |  -- ----------------------------
 |  |  -- ----------------------------
 | 
											
												
													
														|  |  -- @formatter:off
 |  |  -- @formatter:off
 | 
											
												
													
														|  |  BEGIN;
 |  |  BEGIN;
 | 
											
												
													
														|  | -{'\n'.join(inserts)}
 |  | 
 | 
											
												
													
														|  | 
 |  | +{inserts_lines}
 | 
											
												
													
														|  |  COMMIT;
 |  |  COMMIT;
 | 
											
												
													
														|  |  -- @formatter:on"""
 |  |  -- @formatter:on"""
 | 
											
												
													
														|  |              match = re.search(r"VALUES \((\d+),", inserts[-1])
 |  |              match = re.search(r"VALUES \((\d+),", inserts[-1])
 | 
											
										
											
												
													
														|  | @@ -332,7 +353,7 @@ class OracleConvertor(Convertor):
 | 
											
												
													
														|  |      def __init__(self, src):
 |  |      def __init__(self, src):
 | 
											
												
													
														|  |          super().__init__(src, "Oracle")
 |  |          super().__init__(src, "Oracle")
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def translate_type(self, type, size: None | int | Tuple[int]):
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
 | 
											
												
													
														|  |          """类型转换"""
 |  |          """类型转换"""
 | 
											
												
													
														|  |          type = type.lower()
 |  |          type = type.lower()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -369,15 +390,19 @@ class OracleConvertor(Convertor):
 | 
											
												
													
														|  |              full_type = self.translate_type(type, col["size"])
 |  |              full_type = self.translate_type(type, col["size"])
 | 
											
												
													
														|  |              nullable = "NULL" if col["nullable"] else "NOT NULL"
 |  |              nullable = "NULL" if col["nullable"] else "NOT NULL"
 | 
											
												
													
														|  |              default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
 |  |              default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
 | 
											
												
													
														|  | -            return f"{'\"size\"' if name == "size" else name } {full_type} {default} {nullable}"
 |  | 
 | 
											
												
													
														|  | 
 |  | +            # Oracle 中 size 不能作为字段名
 | 
											
												
													
														|  | 
 |  | +            field_name = '"size"' if name == "size" else name
 | 
											
												
													
														|  | 
 |  | +            # Oracle DEFAULT 定义在 NULLABLE 之前
 | 
											
												
													
														|  | 
 |  | +            return f"{field_name} {full_type} {default} {nullable}"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          table_name = ddl["table_name"].lower()
 |  |          table_name = ddl["table_name"].lower()
 | 
											
												
													
														|  |          columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
 |  |          columns = [f"{generate_column(col).strip()}" for col in ddl["columns"]]
 | 
											
												
													
														|  | 
 |  | +        field_def_list = ",\n    ".join(columns)
 | 
											
												
													
														|  |          script = f"""-- ----------------------------
 |  |          script = f"""-- ----------------------------
 | 
											
												
													
														|  |  -- Table structure for {table_name}
 |  |  -- Table structure for {table_name}
 | 
											
												
													
														|  |  -- ----------------------------
 |  |  -- ----------------------------
 | 
											
												
													
														|  | -CREATE TABLE {ddl['table_name'].lower()} (
 |  | 
 | 
											
												
													
														|  | -    {',\n    '.join(columns)}
 |  | 
 | 
											
												
													
														|  | 
 |  | +CREATE TABLE {table_name} (
 | 
											
												
													
														|  | 
 |  | +    {field_def_list}
 | 
											
												
													
														|  |  );"""
 |  |  );"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          # oracle INSERT '' 不能通过 NOT NULL 校验
 |  |          # oracle INSERT '' 不能通过 NOT NULL 校验
 | 
											
										
											
												
													
														|  | @@ -385,72 +410,51 @@ CREATE TABLE {ddl['table_name'].lower()} (
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          return script
 |  |          return script
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_comment(self, table_sql, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_index(self, ddl: Dict) -> str:
 | 
											
												
													
														|  | 
 |  | +        return "\n".join(f"{script};" for script in self.index(ddl))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def gen_comment(self, table_sql: str, table_name: str) -> str:
 | 
											
												
													
														|  |          script = ""
 |  |          script = ""
 | 
											
												
													
														|  | -        for line in table_sql.split("\n"):
 |  | 
 | 
											
												
													
														|  | -            match = re.search(r"`([^`]+)`.* COMMENT '([^']+)'", line)
 |  | 
 | 
											
												
													
														|  | -            if match:
 |  | 
 | 
											
												
													
														|  | -                script += f"COMMENT ON COLUMN {table_name}.{match.group(1)} IS '{match.group(2).replace('\\n', '\n')}';\n"
 |  | 
 | 
											
												
													
														|  | 
 |  | +        for field, comment_string in self.filed_comments(table_sql):
 | 
											
												
													
														|  | 
 |  | +            script += (
 | 
											
												
													
														|  | 
 |  | +                f"COMMENT ON COLUMN {table_name}.{field} IS '{comment_string}';" + "\n"
 | 
											
												
													
														|  | 
 |  | +            )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
 |  | 
 | 
											
												
													
														|  | -        table_comment = match.group(1) if match else None
 |  | 
 | 
											
												
													
														|  | 
 |  | +        table_comment = self.table_comment(table_sql)
 | 
											
												
													
														|  |          if table_comment:
 |  |          if table_comment:
 | 
											
												
													
														|  | -            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';"
 |  | 
 | 
											
												
													
														|  | 
 |  | +            script += f"COMMENT ON TABLE {table_name} IS '{table_comment}';\n"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          return script
 |  |          return script
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_pk(self, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_pk(self, table_name: str) -> str:
 | 
											
												
													
														|  |          """生成主键定义"""
 |  |          """生成主键定义"""
 | 
											
												
													
														|  |          return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
 |  |          return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY (id);\n"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_index(self, table_ddl) -> str:
 |  | 
 | 
											
												
													
														|  | -        """生成 INDEX 定义"""
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        def generate_columns(columns):
 |  | 
 | 
											
												
													
														|  | -            keys = [
 |  | 
 | 
											
												
													
														|  | -                f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
 |  | 
 | 
											
												
													
														|  | -                for col in columns[0]
 |  | 
 | 
											
												
													
														|  | -            ]
 |  | 
 | 
											
												
													
														|  | -            return ", ".join(keys)
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        script = ""
 |  | 
 | 
											
												
													
														|  | -        for no, index in enumerate(table_ddl["index"], 1):
 |  | 
 | 
											
												
													
														|  | -            columns = generate_columns(index["columns"])
 |  | 
 | 
											
												
													
														|  | -            table_name = table_ddl["table_name"].lower()
 |  | 
 | 
											
												
													
														|  | -            script += (
 |  | 
 | 
											
												
													
														|  | -                f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns});\n"
 |  | 
 | 
											
												
													
														|  | -            )
 |  | 
 | 
											
												
													
														|  | -        return script
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_index(self, ddl: Dict) -> str:
 | 
											
												
													
														|  | 
 |  | +        return "\n".join(f"{script};" for script in self.index(ddl))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_insert(self, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_insert(self, table_name: str) -> str:
 | 
											
												
													
														|  |          """拷贝 INSERT 语句"""
 |  |          """拷贝 INSERT 语句"""
 | 
											
												
													
														|  | -        PREFIX = f"INSERT INTO `{table_name}`"
 |  | 
 | 
											
												
													
														|  |          inserts = []
 |  |          inserts = []
 | 
											
												
													
														|  | -        for line in self.content.split("\n"):
 |  | 
 | 
											
												
													
														|  | -            if line.startswith(PREFIX):
 |  | 
 | 
											
												
													
														|  | -                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
 |  | 
 | 
											
												
													
														|  | -                head = head.strip().replace("`", "").lower()
 |  | 
 | 
											
												
													
														|  | -                tail = tail.strip().replace(r"\"", '"')
 |  | 
 | 
											
												
													
														|  | -                script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
 |  | 
 | 
											
												
													
														|  | -                # bit(1)数据转换
 |  | 
 | 
											
												
													
														|  | -                script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
 |  | 
 | 
											
												
													
														|  | -                # 对日期数据添加 TO_DATE 转换
 |  | 
 | 
											
												
													
														|  | -                script = re.sub(
 |  | 
 | 
											
												
													
														|  | -                    r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
 |  | 
 | 
											
												
													
														|  | -                    r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
 |  | 
 | 
											
												
													
														|  | -                    script,
 |  | 
 | 
											
												
													
														|  | -                )
 |  | 
 | 
											
												
													
														|  | -                inserts.append(script)
 |  | 
 | 
											
												
													
														|  | 
 |  | +        for insert_script in Convertor.inserts(table_name, self.content):
 | 
											
												
													
														|  | 
 |  | +            # 对日期数据添加 TO_DATE 转换
 | 
											
												
													
														|  | 
 |  | +            insert_script = re.sub(
 | 
											
												
													
														|  | 
 |  | +                r"('\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}')",
 | 
											
												
													
														|  | 
 |  | +                r"to_date(\g<1>, 'SYYYY-MM-DD HH24:MI:SS')",
 | 
											
												
													
														|  | 
 |  | +                insert_script,
 | 
											
												
													
														|  | 
 |  | +            )
 | 
											
												
													
														|  | 
 |  | +            inserts.append(insert_script)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          ## 生成 insert 脚本
 |  |          ## 生成 insert 脚本
 | 
											
												
													
														|  |          script = ""
 |  |          script = ""
 | 
											
												
													
														|  |          last_id = 0
 |  |          last_id = 0
 | 
											
												
													
														|  |          if inserts:
 |  |          if inserts:
 | 
											
												
													
														|  | 
 |  | +            inserts_lines = "\n".join(inserts)
 | 
											
												
													
														|  |              script += f"""\n\n-- ----------------------------
 |  |              script += f"""\n\n-- ----------------------------
 | 
											
												
													
														|  |  -- Records of {table_name.lower()}
 |  |  -- Records of {table_name.lower()}
 | 
											
												
													
														|  |  -- ----------------------------
 |  |  -- ----------------------------
 | 
											
												
													
														|  |  -- @formatter:off
 |  |  -- @formatter:off
 | 
											
												
													
														|  | -{'\n'.join(inserts)}
 |  | 
 | 
											
												
													
														|  | 
 |  | +{inserts_lines}
 | 
											
												
													
														|  |  COMMIT;
 |  |  COMMIT;
 | 
											
												
													
														|  |  -- @formatter:on"""
 |  |  -- @formatter:on"""
 | 
											
												
													
														|  |              match = re.search(r"VALUES \((\d+),", inserts[-1])
 |  |              match = re.search(r"VALUES \((\d+),", inserts[-1])
 | 
											
										
											
												
													
														|  | @@ -476,7 +480,7 @@ class SQLServerConvertor(Convertor):
 | 
											
												
													
														|  |      def __init__(self, src):
 |  |      def __init__(self, src):
 | 
											
												
													
														|  |          super().__init__(src, "Microsoft SQL Server")
 |  |          super().__init__(src, "Microsoft SQL Server")
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def translate_type(self, type, size):
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
 | 
											
												
													
														|  |          """类型转换"""
 |  |          """类型转换"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          type = type.lower()
 |  |          type = type.lower()
 | 
											
										
											
												
													
														|  | @@ -507,7 +511,7 @@ class SQLServerConvertor(Convertor):
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          def _generate_column(col):
 |  |          def _generate_column(col):
 | 
											
												
													
														|  |              name = col["name"].lower()
 |  |              name = col["name"].lower()
 | 
											
												
													
														|  | -            if name == 'id':
 |  | 
 | 
											
												
													
														|  | 
 |  | +            if name == "id":
 | 
											
												
													
														|  |                  return "id bigint NOT NULL PRIMARY KEY IDENTITY"
 |  |                  return "id bigint NOT NULL PRIMARY KEY IDENTITY"
 | 
											
												
													
														|  |              if name == "deleted":
 |  |              if name == "deleted":
 | 
											
												
													
														|  |                  return "deleted bit DEFAULT 0 NOT NULL"
 |  |                  return "deleted bit DEFAULT 0 NOT NULL"
 | 
											
										
											
												
													
														|  | @@ -520,35 +524,34 @@ class SQLServerConvertor(Convertor):
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          table_name = ddl["table_name"].lower()
 |  |          table_name = ddl["table_name"].lower()
 | 
											
												
													
														|  |          columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
 |  |          columns = [f"{_generate_column(col).strip()}" for col in ddl["columns"]]
 | 
											
												
													
														|  | 
 |  | +        filed_def_list = ",\n    ".join(columns)
 | 
											
												
													
														|  |          script = f"""-- ----------------------------
 |  |          script = f"""-- ----------------------------
 | 
											
												
													
														|  |  -- Table structure for {table_name}
 |  |  -- Table structure for {table_name}
 | 
											
												
													
														|  |  -- ----------------------------
 |  |  -- ----------------------------
 | 
											
												
													
														|  |  DROP TABLE IF EXISTS {table_name};
 |  |  DROP TABLE IF EXISTS {table_name};
 | 
											
												
													
														|  |  CREATE TABLE {table_name} (
 |  |  CREATE TABLE {table_name} (
 | 
											
												
													
														|  | -    {',\n    '.join(columns)}
 |  | 
 | 
											
												
													
														|  | 
 |  | +    {filed_def_list}
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  GO"""
 |  |  GO"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          return script
 |  |          return script
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_comment(self, table_sql, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_comment(self, table_sql: str, table_name: str) -> str:
 | 
											
												
													
														|  |          """生成字段及表的注释"""
 |  |          """生成字段及表的注释"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          script = ""
 |  |          script = ""
 | 
											
												
													
														|  | -        for line in table_sql.split("\n"):
 |  | 
 | 
											
												
													
														|  | -            match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
 |  | 
 | 
											
												
													
														|  | -            if match:
 |  | 
 | 
											
												
													
														|  | -                script += f"""EXEC sp_addextendedproperty
 |  | 
 | 
											
												
													
														|  | -    'MS_Description', N'{match.group(2).replace('\\n', '\n')}',
 |  | 
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        for field, comment_string in self.filed_comments(table_sql):
 | 
											
												
													
														|  | 
 |  | +            script += f"""EXEC sp_addextendedproperty
 | 
											
												
													
														|  | 
 |  | +    'MS_Description', N'{comment_string}',
 | 
											
												
													
														|  |      'SCHEMA', N'dbo',
 |  |      'SCHEMA', N'dbo',
 | 
											
												
													
														|  |      'TABLE', N'{table_name}',
 |  |      'TABLE', N'{table_name}',
 | 
											
												
													
														|  | -    'COLUMN', N'{match.group(1)}'
 |  | 
 | 
											
												
													
														|  | 
 |  | +    'COLUMN', N'{field}'
 | 
											
												
													
														|  |  GO
 |  |  GO
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  """
 |  |  """
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        match = re.search(r"COMMENT \= '([^']+)';", table_sql)
 |  | 
 | 
											
												
													
														|  | -        table_comment = match.group(1) if match else None
 |  | 
 | 
											
												
													
														|  | 
 |  | +        table_comment = self.table_comment(table_sql)
 | 
											
												
													
														|  |          if table_comment:
 |  |          if table_comment:
 | 
											
												
													
														|  |              script += f"""EXEC sp_addextendedproperty
 |  |              script += f"""EXEC sp_addextendedproperty
 | 
											
												
													
														|  |      'MS_Description', N'{table_comment}',
 |  |      'MS_Description', N'{table_comment}',
 | 
											
										
											
												
													
														|  | @@ -557,55 +560,34 @@ GO
 | 
											
												
													
														|  |  GO
 |  |  GO
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  """
 |  |  """
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  |          return script
 |  |          return script
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_pk(self, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_pk(self, table_name: str) -> str:
 | 
											
												
													
														|  |          """生成主键定义"""
 |  |          """生成主键定义"""
 | 
											
												
													
														|  |          return ""
 |  |          return ""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def gen_index(self, ddl) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_index(self, ddl: Dict) -> str:
 | 
											
												
													
														|  |          """生成 index"""
 |  |          """生成 index"""
 | 
											
												
													
														|  | 
 |  | +        return "\n".join(f"{script}\nGO" for script in self.index(ddl))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        def generate_columns(columns):
 |  | 
 | 
											
												
													
														|  | -            keys = [
 |  | 
 | 
											
												
													
														|  | -                f"{col['name'].lower()}{" " + col['order'].lower() if col['order'] != 'ASC' else ''}"
 |  | 
 | 
											
												
													
														|  | -                for col in columns[0]
 |  | 
 | 
											
												
													
														|  | -            ]
 |  | 
 | 
											
												
													
														|  | -            return ", ".join(keys)
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        script = ""
 |  | 
 | 
											
												
													
														|  | -        for no, index in enumerate(ddl["index"], 1):
 |  | 
 | 
											
												
													
														|  | -            columns = generate_columns(index["columns"])
 |  | 
 | 
											
												
													
														|  | -            table_name = ddl["table_name"].lower()
 |  | 
 | 
											
												
													
														|  | -            script += f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})\nGO\n"
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        return script
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -    def gen_insert(self, table_name) -> str:
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def gen_insert(self, table_name: str) -> str:
 | 
											
												
													
														|  |          """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
 |  |          """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        PREFIX = f"INSERT INTO `{table_name}`"
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  |          # 收集 `table_name` 对应的 insert 语句
 |  |          # 收集 `table_name` 对应的 insert 语句
 | 
											
												
													
														|  |          inserts = []
 |  |          inserts = []
 | 
											
												
													
														|  | -        for line in self.content.split("\n"):
 |  | 
 | 
											
												
													
														|  | -            if line.startswith(PREFIX):
 |  | 
 | 
											
												
													
														|  | -                head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
 |  | 
 | 
											
												
													
														|  | -                head = head.strip().replace("`", "").lower()
 |  | 
 | 
											
												
													
														|  | -                tail = tail.strip().replace(r"\"", '"')
 |  | 
 | 
											
												
													
														|  | -                # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险
 |  | 
 | 
											
												
													
														|  | -                tail = tail.replace(", '", ", N'").replace("VALUES ('", "VALUES (N')")
 |  | 
 | 
											
												
													
														|  | -                script = f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
 |  | 
 | 
											
												
													
														|  | -                # bit(1)数据转换
 |  | 
 | 
											
												
													
														|  | -                script = script.replace("b'0'", "'0'").replace("b'1'", "'1'")
 |  | 
 | 
											
												
													
														|  | -                # 删除 insert 的结尾分号
 |  | 
 | 
											
												
													
														|  | -                script = re.sub(";$", r"\nGO", script)
 |  | 
 | 
											
												
													
														|  | -                inserts.append(script)
 |  | 
 | 
											
												
													
														|  | 
 |  | +        for insert_script in Convertor.inserts(table_name, self.content):
 | 
											
												
													
														|  | 
 |  | +            # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险
 | 
											
												
													
														|  | 
 |  | +            insert_script = insert_script.replace(", '", ", N'").replace(
 | 
											
												
													
														|  | 
 |  | +                "VALUES ('", "VALUES (N')"
 | 
											
												
													
														|  | 
 |  | +            )
 | 
											
												
													
														|  | 
 |  | +            # 删除 insert 的结尾分号
 | 
											
												
													
														|  | 
 |  | +            insert_script = re.sub(";$", r"\nGO", insert_script)
 | 
											
												
													
														|  | 
 |  | +            inserts.append(insert_script)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          ## 生成 insert 脚本
 |  |          ## 生成 insert 脚本
 | 
											
												
													
														|  |          script = ""
 |  |          script = ""
 | 
											
												
													
														|  |          if inserts:
 |  |          if inserts:
 | 
											
												
													
														|  | 
 |  | +            inserts_lines = "\n".join(inserts)
 | 
											
												
													
														|  |              script += f"""\n\n-- ----------------------------
 |  |              script += f"""\n\n-- ----------------------------
 | 
											
												
													
														|  |  -- Records of {table_name.lower()}
 |  |  -- Records of {table_name.lower()}
 | 
											
												
													
														|  |  -- ----------------------------
 |  |  -- ----------------------------
 | 
											
										
											
												
													
														|  | @@ -614,7 +596,7 @@ BEGIN TRANSACTION
 | 
											
												
													
														|  |  GO
 |  |  GO
 | 
											
												
													
														|  |  SET IDENTITY_INSERT {table_name.lower()} ON
 |  |  SET IDENTITY_INSERT {table_name.lower()} ON
 | 
											
												
													
														|  |  GO
 |  |  GO
 | 
											
												
													
														|  | -{'\n'.join(inserts)}
 |  | 
 | 
											
												
													
														|  | 
 |  | +{inserts_lines}
 | 
											
												
													
														|  |  SET IDENTITY_INSERT {table_name.lower()} OFF
 |  |  SET IDENTITY_INSERT {table_name.lower()} OFF
 | 
											
												
													
														|  |  GO
 |  |  GO
 | 
											
												
													
														|  |  COMMIT
 |  |  COMMIT
 | 
											
										
											
												
													
														|  | @@ -625,10 +607,26 @@ GO
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  def main():
 |  |  def main():
 | 
											
												
													
														|  | -    sql_file = pathlib.Path('../mysql/ruoyi-vue-pro.sql').resolve().as_posix()
 |  | 
 | 
											
												
													
														|  | -    # convertor = PostgreSQLConvertor(sql_file)
 |  | 
 | 
											
												
													
														|  | -    # convertor = OracleConvertor(sql_file)
 |  | 
 | 
											
												
													
														|  | -    convertor = SQLServerConvertor(sql_file)
 |  | 
 | 
											
												
													
														|  | 
 |  | +    parser = argparse.ArgumentParser(description="芋道系统数据库转换工具")
 | 
											
												
													
														|  | 
 |  | +    parser.add_argument(
 | 
											
												
													
														|  | 
 |  | +        "type",
 | 
											
												
													
														|  | 
 |  | +        type=str,
 | 
											
												
													
														|  | 
 |  | +        help="目标数据库类型",
 | 
											
												
													
														|  | 
 |  | +        choices=["postgres", "oracle", "sqlserver"],
 | 
											
												
													
														|  | 
 |  | +    )
 | 
											
												
													
														|  | 
 |  | +    args = parser.parse_args()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    sql_file = pathlib.Path("../mysql/ruoyi-vue-pro.sql").resolve().as_posix()
 | 
											
												
													
														|  | 
 |  | +    convertor = None
 | 
											
												
													
														|  | 
 |  | +    if args.type == "postgres":
 | 
											
												
													
														|  | 
 |  | +        convertor = PostgreSQLConvertor(sql_file)
 | 
											
												
													
														|  | 
 |  | +    elif args.type == "oracle":
 | 
											
												
													
														|  | 
 |  | +        convertor = OracleConvertor(sql_file)
 | 
											
												
													
														|  | 
 |  | +    elif args.type == "sqlserver":
 | 
											
												
													
														|  | 
 |  | +        convertor = SQLServerConvertor(sql_file)
 | 
											
												
													
														|  | 
 |  | +    else:
 | 
											
												
													
														|  | 
 |  | +        raise NotImplementedError(f"不支持目标数据库类型: {args.type}")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      convertor.print()
 |  |      convertor.print()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 |