add_supplier_and_settlement.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
"""
Database migration script for the Supplier Management feature.

This script adds:
1. A suppliers table with a unique index on suppliers.name
2. A unique index on persons.name
3. A unique index on items.name
4. A supplier_id foreign-key column on the items table
5. An is_settled column on the work_records table

Compatible with both SQLite and PostgreSQL.
Requirements: 10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7
"""
  12. import os
  13. import sys
  14. from datetime import datetime, timezone
  15. # Add parent directory to path for imports
  16. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  17. from sqlalchemy import create_engine, text, inspect
  18. from sqlalchemy.exc import OperationalError, IntegrityError
  19. def get_database_url():
  20. """Get database URL from environment or use default SQLite."""
  21. return os.environ.get('DATABASE_URL') or \
  22. 'sqlite:///' + os.path.join(os.path.dirname(os.path.dirname(__file__)), 'dev.db')
  23. def is_sqlite(engine):
  24. """Check if the database is SQLite."""
  25. return 'sqlite' in engine.dialect.name.lower()
  26. def is_postgresql(engine):
  27. """Check if the database is PostgreSQL."""
  28. return 'postgresql' in engine.dialect.name.lower()
  29. def index_exists(engine, table_name, index_name):
  30. """Check if an index exists on a table."""
  31. inspector = inspect(engine)
  32. indexes = inspector.get_indexes(table_name)
  33. return any(idx['name'] == index_name for idx in indexes)
  34. def column_exists(engine, table_name, column_name):
  35. """Check if a column exists in a table."""
  36. inspector = inspect(engine)
  37. columns = inspector.get_columns(table_name)
  38. return any(col['name'] == column_name for col in columns)
  39. def table_exists(engine, table_name):
  40. """Check if a table exists in the database."""
  41. inspector = inspect(engine)
  42. return table_name in inspector.get_table_names()
  43. def upgrade(engine):
  44. """
  45. Upgrade the database schema.
  46. Operations:
  47. 1. Create suppliers table with unique name index
  48. 2. Add unique index on persons.name
  49. 3. Add unique index on items.name
  50. 4. Add supplier_id column to items table
  51. 5. Add is_settled column to work_records table
  52. """
  53. with engine.connect() as conn:
  54. # 1. Create suppliers table
  55. if not table_exists(engine, 'suppliers'):
  56. print("Creating suppliers table...")
  57. conn.execute(text("""
  58. CREATE TABLE suppliers (
  59. id INTEGER PRIMARY KEY AUTOINCREMENT,
  60. name VARCHAR(100) NOT NULL,
  61. created_at DATETIME,
  62. updated_at DATETIME
  63. )
  64. """) if is_sqlite(engine) else text("""
  65. CREATE TABLE suppliers (
  66. id SERIAL PRIMARY KEY,
  67. name VARCHAR(100) NOT NULL,
  68. created_at TIMESTAMP,
  69. updated_at TIMESTAMP
  70. )
  71. """))
  72. conn.commit()
  73. print(" - suppliers table created")
  74. else:
  75. print(" - suppliers table already exists, skipping")
  76. # 2. Add unique index on suppliers.name
  77. if not index_exists(engine, 'suppliers', 'ix_suppliers_name'):
  78. print("Adding unique index on suppliers.name...")
  79. conn.execute(text(
  80. "CREATE UNIQUE INDEX ix_suppliers_name ON suppliers (name)"
  81. ))
  82. conn.commit()
  83. print(" - ix_suppliers_name index created")
  84. else:
  85. print(" - ix_suppliers_name index already exists, skipping")
  86. # 3. Add unique index on persons.name
  87. if not index_exists(engine, 'persons', 'ix_persons_name_unique'):
  88. print("Adding unique index on persons.name...")
  89. try:
  90. conn.execute(text(
  91. "CREATE UNIQUE INDEX ix_persons_name_unique ON persons (name)"
  92. ))
  93. conn.commit()
  94. print(" - ix_persons_name_unique index created")
  95. except (OperationalError, IntegrityError) as e:
  96. print(f" - Warning: Could not create unique index on persons.name: {e}")
  97. print(" This may be due to duplicate names in existing data.")
  98. conn.rollback()
  99. else:
  100. print(" - ix_persons_name_unique index already exists, skipping")
  101. # 4. Add unique index on items.name
  102. if not index_exists(engine, 'items', 'ix_items_name_unique'):
  103. print("Adding unique index on items.name...")
  104. try:
  105. conn.execute(text(
  106. "CREATE UNIQUE INDEX ix_items_name_unique ON items (name)"
  107. ))
  108. conn.commit()
  109. print(" - ix_items_name_unique index created")
  110. except (OperationalError, IntegrityError) as e:
  111. print(f" - Warning: Could not create unique index on items.name: {e}")
  112. print(" This may be due to duplicate names in existing data.")
  113. conn.rollback()
  114. else:
  115. print(" - ix_items_name_unique index already exists, skipping")
  116. # 5. Add supplier_id column to items table
  117. if not column_exists(engine, 'items', 'supplier_id'):
  118. print("Adding supplier_id column to items table...")
  119. conn.execute(text(
  120. "ALTER TABLE items ADD COLUMN supplier_id INTEGER REFERENCES suppliers(id)"
  121. ))
  122. conn.commit()
  123. print(" - supplier_id column added to items")
  124. else:
  125. print(" - supplier_id column already exists in items, skipping")
  126. # 6. Add is_settled column to work_records table
  127. if not column_exists(engine, 'work_records', 'is_settled'):
  128. print("Adding is_settled column to work_records table...")
  129. if is_sqlite(engine):
  130. # SQLite doesn't support DEFAULT in ALTER TABLE well,
  131. # so we add the column and then update existing rows
  132. conn.execute(text(
  133. "ALTER TABLE work_records ADD COLUMN is_settled BOOLEAN DEFAULT 0"
  134. ))
  135. conn.execute(text(
  136. "UPDATE work_records SET is_settled = 0 WHERE is_settled IS NULL"
  137. ))
  138. else:
  139. # PostgreSQL supports DEFAULT in ALTER TABLE
  140. conn.execute(text(
  141. "ALTER TABLE work_records ADD COLUMN is_settled BOOLEAN NOT NULL DEFAULT FALSE"
  142. ))
  143. conn.commit()
  144. print(" - is_settled column added to work_records")
  145. else:
  146. print(" - is_settled column already exists in work_records, skipping")
  147. print("\nMigration upgrade completed successfully!")
  148. def downgrade(engine):
  149. """
  150. Downgrade the database schema (reverse the migration).
  151. Operations:
  152. 1. Remove is_settled column from work_records
  153. 2. Remove supplier_id column from items
  154. 3. Remove unique index from items.name
  155. 4. Remove unique index from persons.name
  156. 5. Remove unique index from suppliers.name
  157. 6. Drop suppliers table
  158. """
  159. with engine.connect() as conn:
  160. # For SQLite, we need to recreate tables to remove columns
  161. # For PostgreSQL, we can use ALTER TABLE DROP COLUMN
  162. if is_sqlite(engine):
  163. print("SQLite detected - column removal requires table recreation")
  164. print("Downgrade for SQLite is not fully supported.")
  165. print("Please manually recreate the database if needed.")
  166. # We can still drop indexes and the suppliers table
  167. print("Dropping indexes...")
  168. try:
  169. conn.execute(text("DROP INDEX IF EXISTS ix_items_name_unique"))
  170. conn.execute(text("DROP INDEX IF EXISTS ix_persons_name_unique"))
  171. conn.execute(text("DROP INDEX IF EXISTS ix_suppliers_name"))
  172. conn.commit()
  173. except Exception as e:
  174. print(f" - Warning: {e}")
  175. conn.rollback()
  176. print("Dropping suppliers table...")
  177. try:
  178. conn.execute(text("DROP TABLE IF EXISTS suppliers"))
  179. conn.commit()
  180. except Exception as e:
  181. print(f" - Warning: {e}")
  182. conn.rollback()
  183. else:
  184. # PostgreSQL downgrade
  185. print("Removing is_settled column from work_records...")
  186. try:
  187. conn.execute(text("ALTER TABLE work_records DROP COLUMN IF EXISTS is_settled"))
  188. conn.commit()
  189. except Exception as e:
  190. print(f" - Warning: {e}")
  191. conn.rollback()
  192. print("Removing supplier_id column from items...")
  193. try:
  194. conn.execute(text("ALTER TABLE items DROP COLUMN IF EXISTS supplier_id"))
  195. conn.commit()
  196. except Exception as e:
  197. print(f" - Warning: {e}")
  198. conn.rollback()
  199. print("Dropping indexes...")
  200. try:
  201. conn.execute(text("DROP INDEX IF EXISTS ix_items_name_unique"))
  202. conn.execute(text("DROP INDEX IF EXISTS ix_persons_name_unique"))
  203. conn.execute(text("DROP INDEX IF EXISTS ix_suppliers_name"))
  204. conn.commit()
  205. except Exception as e:
  206. print(f" - Warning: {e}")
  207. conn.rollback()
  208. print("Dropping suppliers table...")
  209. try:
  210. conn.execute(text("DROP TABLE IF EXISTS suppliers"))
  211. conn.commit()
  212. except Exception as e:
  213. print(f" - Warning: {e}")
  214. conn.rollback()
  215. print("\nMigration downgrade completed!")
  216. def main():
  217. """Main entry point for the migration script."""
  218. import argparse
  219. parser = argparse.ArgumentParser(
  220. description='Database migration for Supplier Management feature'
  221. )
  222. parser.add_argument(
  223. 'action',
  224. choices=['upgrade', 'downgrade'],
  225. help='Migration action to perform'
  226. )
  227. parser.add_argument(
  228. '--database-url',
  229. help='Database URL (overrides DATABASE_URL environment variable)'
  230. )
  231. args = parser.parse_args()
  232. # Get database URL
  233. database_url = args.database_url or get_database_url()
  234. print(f"Database: {database_url}")
  235. # Create engine
  236. engine = create_engine(database_url)
  237. # Perform migration
  238. if args.action == 'upgrade':
  239. print("\n=== Running upgrade migration ===\n")
  240. upgrade(engine)
  241. else:
  242. print("\n=== Running downgrade migration ===\n")
  243. downgrade(engine)
  244. if __name__ == '__main__':
  245. main()