Matrix Docker Ansible Deploy
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

942 lines
29 KiB

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2015, 2016 OpenMarket Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from twisted.internet import defer, reactor
  17. from twisted.enterprise import adbapi
  18. from synapse.storage._base import LoggingTransaction, SQLBaseStore
  19. from synapse.storage.engines import create_engine
  20. from synapse.storage.prepare_database import prepare_database
  21. import argparse
  22. import curses
  23. import logging
  24. import sys
  25. import time
  26. import traceback
  27. import yaml
  28. logger = logging.getLogger("synapse_port_db")
# Maps table name -> columns that SQLite stores as 0/1 integers but which
# must be converted to real booleans before inserting into PostgreSQL
# (see Porter._convert_rows).
BOOLEAN_COLUMNS = {
    "events": ["processed", "outlier", "contains_url"],
    "rooms": ["is_public"],
    "event_edges": ["is_state"],
    "presence_list": ["accepted"],
    "presence_stream": ["currently_active"],
    "public_room_list_stream": ["visibility"],
    "device_lists_outbound_pokes": ["sent"],
    "users_who_share_rooms": ["share_private"],
}
# Tables that are only ever appended to (rows are never updated or deleted),
# so a previously-interrupted port can safely resume inserting where it left
# off.  Any table NOT listed here is truncated and re-ported from scratch
# (see Porter.setup_table).
APPEND_ONLY_TABLES = [
    "event_content_hashes",
    "event_reference_hashes",
    "event_signatures",
    "event_edge_hashes",
    "events",
    "event_json",
    "state_events",
    "room_memberships",
    "feedback",
    "topics",
    "room_names",
    "rooms",
    "local_media_repository",
    "local_media_repository_thumbnails",
    "remote_media_cache",
    "remote_media_cache_thumbnails",
    "redactions",
    "event_edges",
    "event_auth",
    "received_transactions",
    "sent_transactions",
    "transaction_id_to_pdu",
    "users",
    "state_groups",
    "state_groups_state",
    "event_to_state_groups",
    "rejections",
    "event_search",
    "presence_stream",
    "push_rules_stream",
    "current_state_resets",
    "ex_outlier_stream",
    "cache_invalidation_stream",
    "public_room_list_stream",
    "state_group_edges",
    "stream_ordering_to_exterm",
]


# Holds sys.exc_info() captured when Porter.run() fails, so the traceback can
# be printed after the Twisted reactor has been stopped (see __main__ block).
end_error_exec_info = None
  78. class Store(object):
  79. """This object is used to pull out some of the convenience API from the
  80. Storage layer.
  81. *All* database interactions should go through this object.
  82. """
  83. def __init__(self, db_pool, engine):
  84. self.db_pool = db_pool
  85. self.database_engine = engine
  86. _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
  87. _simple_insert = SQLBaseStore.__dict__["_simple_insert"]
  88. _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
  89. _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
  90. _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
  91. _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
  92. _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
  93. _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
  94. "_simple_select_one_onecol_txn"
  95. ]
  96. _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
  97. _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
  98. def runInteraction(self, desc, func, *args, **kwargs):
  99. def r(conn):
  100. try:
  101. i = 0
  102. N = 5
  103. while True:
  104. try:
  105. txn = conn.cursor()
  106. return func(
  107. LoggingTransaction(txn, desc, self.database_engine, [], []),
  108. *args, **kwargs
  109. )
  110. except self.database_engine.module.DatabaseError as e:
  111. if self.database_engine.is_deadlock(e):
  112. logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
  113. if i < N:
  114. i += 1
  115. conn.rollback()
  116. continue
  117. raise
  118. except Exception as e:
  119. logger.debug("[TXN FAIL] {%s} %s", desc, e)
  120. raise
  121. return self.db_pool.runWithConnection(r)
  122. def execute(self, f, *args, **kwargs):
  123. return self.runInteraction(f.__name__, f, *args, **kwargs)
  124. def execute_sql(self, sql, *args):
  125. def r(txn):
  126. txn.execute(sql, args)
  127. return txn.fetchall()
  128. return self.runInteraction("execute_sql", r)
  129. def insert_many_txn(self, txn, table, headers, rows):
  130. sql = "INSERT INTO %s (%s) VALUES (%s)" % (
  131. table,
  132. ", ".join(k for k in headers),
  133. ", ".join("%s" for _ in headers)
  134. )
  135. try:
  136. txn.executemany(sql, rows)
  137. except:
  138. logger.exception(
  139. "Failed to insert: %s",
  140. table,
  141. )
  142. raise
class Porter(object):
    """Drives the port of a SQLite synapse database to PostgreSQL.

    Constructed with keyword arguments which become attributes directly:
    ``sqlite_config``, ``postgres_config``, ``progress`` (a Progress
    instance) and ``batch_size`` (see the __main__ block).
    """
    def __init__(self, **kwargs):
        # Stash all supplied config straight onto the instance.
        self.__dict__.update(kwargs)

    @defer.inlineCallbacks
    def setup_table(self, table):
        """Work out where to (re)start porting ``table`` from.

        Returns (via Deferred) a tuple
        ``(table, already_ported, total_to_port, forward_chunk,
        backward_chunk)`` suitable for passing to handle_table().
        """
        if table in APPEND_ONLY_TABLES:
            # It's safe to just carry on inserting.
            row = yield self.postgres_store._simple_select_one(
                table="port_from_sqlite3",
                keyvalues={"table_name": table},
                retcols=("forward_rowid", "backward_rowid"),
                allow_none=True,
            )

            total_to_port = None
            if row is None:
                if table == "sent_transactions":
                    # sent_transactions gets special treatment: only recent
                    # rows are ported (see _setup_sent_transactions).
                    forward_chunk, already_ported, total_to_port = (
                        yield self._setup_sent_transactions()
                    )
                    backward_chunk = 0
                else:
                    yield self.postgres_store._simple_insert(
                        table="port_from_sqlite3",
                        values={
                            "table_name": table,
                            "forward_rowid": 1,
                            "backward_rowid": 0,
                        }
                    )

                    forward_chunk = 1
                    backward_chunk = 0
                    already_ported = 0
            else:
                # Resume from where a previous run got to.
                forward_chunk = row["forward_rowid"]
                backward_chunk = row["backward_rowid"]

            if total_to_port is None:
                already_ported, total_to_port = yield self._get_total_count_to_port(
                    table, forward_chunk, backward_chunk
                )
        else:
            # Not append-only: wipe the postgres copy and re-port from scratch.
            def delete_all(txn):
                txn.execute(
                    "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
                    (table,)
                )
                txn.execute("TRUNCATE %s CASCADE" % (table,))

            yield self.postgres_store.execute(delete_all)

            yield self.postgres_store._simple_insert(
                table="port_from_sqlite3",
                values={
                    "table_name": table,
                    "forward_rowid": 1,
                    "backward_rowid": 0,
                }
            )

            forward_chunk = 1
            backward_chunk = 0

            already_ported, total_to_port = yield self._get_total_count_to_port(
                table, forward_chunk, backward_chunk
            )

        defer.returnValue(
            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
        )

    @defer.inlineCallbacks
    def handle_table(self, table, postgres_size, table_size, forward_chunk,
                     backward_chunk):
        """Copy ``table`` from SQLite to PostgreSQL in batches.

        Reads forwards from ``forward_chunk`` and backwards from
        ``backward_chunk`` simultaneously, updating port_from_sqlite3 after
        each batch so the port can be resumed.
        """
        if not table_size:
            return

        self.progress.add_table(table, postgres_size, table_size)

        # Patch from: https://github.com/matrix-org/synapse/issues/2287
        if table == "user_directory_search":
            # FIXME: actually port it, but for now we can leave it blank
            # and have the server regenerate it.
            # you will need to set the values of user_directory_stream_pos
            # to be ('X', null) to force a regen
            return

        if table == "event_search":
            # event_search has a different schema on postgres (tsvector).
            yield self.handle_search_table(
                postgres_size, table_size, forward_chunk, backward_chunk
            )
            return

        forward_select = (
            "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
            % (table,)
        )

        backward_select = (
            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
            % (table,)
        )

        # Single-element lists so the nested r() can flip the flags.
        do_forward = [True]
        do_backward = [True]

        while True:
            def r(txn):
                forward_rows = []
                backward_rows = []
                if do_forward[0]:
                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
                    forward_rows = txn.fetchall()
                    if not forward_rows:
                        do_forward[0] = False

                if do_backward[0]:
                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
                    backward_rows = txn.fetchall()
                    if not backward_rows:
                        do_backward[0] = False

                if forward_rows or backward_rows:
                    headers = [column[0] for column in txn.description]
                else:
                    headers = None

                return headers, forward_rows, backward_rows

            headers, frows, brows = yield self.sqlite_store.runInteraction(
                "select", r
            )

            if frows or brows:
                if frows:
                    forward_chunk = max(row[0] for row in frows) + 1
                if brows:
                    backward_chunk = min(row[0] for row in brows) - 1

                rows = frows + brows
                self._convert_rows(table, headers, rows)

                def insert(txn):
                    # headers[1:] / converted rows both exclude the rowid.
                    self.postgres_store.insert_many_txn(
                        txn, table, headers[1:], rows
                    )

                    self.postgres_store._simple_update_one_txn(
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": table},
                        updatevalues={
                            "forward_rowid": forward_chunk,
                            "backward_rowid": backward_chunk,
                        },
                    )

                yield self.postgres_store.execute(insert)

                postgres_size += len(rows)

                self.progress.update(table, postgres_size)
            else:
                return

    @defer.inlineCallbacks
    def handle_search_table(self, postgres_size, table_size, forward_chunk,
                            backward_chunk):
        """Port event_search, rebuilding the tsvector column on the postgres
        side via to_tsvector() since the two databases use different
        full-text-search representations.
        """
        select = (
            "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
            " FROM event_search as es"
            " INNER JOIN events AS e USING (event_id, room_id)"
            " WHERE es.rowid >= ?"
            " ORDER BY es.rowid LIMIT ?"
        )

        while True:
            def r(txn):
                txn.execute(select, (forward_chunk, self.batch_size,))
                rows = txn.fetchall()
                headers = [column[0] for column in txn.description]

                return headers, rows

            headers, rows = yield self.sqlite_store.runInteraction("select", r)

            if rows:
                forward_chunk = rows[-1][0] + 1

                # We have to treat event_search differently since it has a
                # different structure in the two different databases.
                def insert(txn):
                    sql = (
                        "INSERT INTO event_search (event_id, room_id, key,"
                        " sender, vector, origin_server_ts, stream_ordering)"
                        " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
                    )

                    rows_dict = [
                        dict(zip(headers, row))
                        for row in rows
                    ]

                    txn.executemany(sql, [
                        (
                            row["event_id"],
                            row["room_id"],
                            row["key"],
                            row["sender"],
                            row["value"],
                            row["origin_server_ts"],
                            row["stream_ordering"],
                        )
                        for row in rows_dict
                    ])

                    self.postgres_store._simple_update_one_txn(
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": "event_search"},
                        updatevalues={
                            "forward_rowid": forward_chunk,
                            "backward_rowid": backward_chunk,
                        },
                    )

                yield self.postgres_store.execute(insert)

                postgres_size += len(rows)

                self.progress.update("event_search", postgres_size)

            else:
                return

    def setup_db(self, db_config, database_engine):
        """Connect to the database described by ``db_config`` and run the
        synapse schema preparation against it.  Connection-pool-only args
        (cp_*) are stripped before connecting.
        """
        db_conn = database_engine.module.connect(
            **{
                k: v for k, v in db_config.get("args", {}).items()
                if not k.startswith("cp_")
            }
        )

        prepare_database(db_conn, database_engine, config=None)

        db_conn.commit()

    @defer.inlineCallbacks
    def run(self):
        """Top-level entry point: set up both databases, discover the common
        tables, then port each one.  Always stops the reactor when done; any
        error is stashed in the module-global end_error_exec_info for
        printing after reactor shutdown.
        """
        try:
            sqlite_db_pool = adbapi.ConnectionPool(
                self.sqlite_config["name"],
                **self.sqlite_config["args"]
            )

            postgres_db_pool = adbapi.ConnectionPool(
                self.postgres_config["name"],
                **self.postgres_config["args"]
            )

            # NOTE(review): these refer to the module-level sqlite_config /
            # postgres_config globals (defined in the __main__ block), not
            # self.* -- works when run as a script, but worth confirming.
            sqlite_engine = create_engine(sqlite_config)
            postgres_engine = create_engine(postgres_config)

            self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
            self.postgres_store = Store(postgres_db_pool, postgres_engine)

            yield self.postgres_store.execute(
                postgres_engine.check_database
            )

            # Step 1. Set up databases.
            self.progress.set_state("Preparing SQLite3")
            self.setup_db(sqlite_config, sqlite_engine)

            self.progress.set_state("Preparing PostgreSQL")
            self.setup_db(postgres_config, postgres_engine)

            # Step 2. Get tables.
            self.progress.set_state("Fetching tables")
            sqlite_tables = yield self.sqlite_store._simple_select_onecol(
                table="sqlite_master",
                keyvalues={
                    "type": "table",
                },
                retcol="name",
            )

            postgres_tables = yield self.postgres_store._simple_select_onecol(
                table="information_schema.tables",
                keyvalues={},
                retcol="distinct table_name",
            )

            # Only port tables that exist on both sides.
            tables = set(sqlite_tables) & set(postgres_tables)

            self.progress.set_state("Creating tables")

            logger.info("Found %d tables", len(tables))

            def create_port_table(txn):
                txn.execute(
                    "CREATE TABLE port_from_sqlite3 ("
                    " table_name varchar(100) NOT NULL UNIQUE,"
                    " forward_rowid bigint NOT NULL,"
                    " backward_rowid bigint NOT NULL"
                    ")"
                )

            # The old port script created a table with just a "rowid" column.
            # We want people to be able to rerun this script from an old port
            # so that they can pick up any missing events that were not
            # ported across.
            def alter_table(txn):
                txn.execute(
                    "ALTER TABLE IF EXISTS port_from_sqlite3"
                    " RENAME rowid TO forward_rowid"
                )
                txn.execute(
                    "ALTER TABLE IF EXISTS port_from_sqlite3"
                    " ADD backward_rowid bigint NOT NULL DEFAULT 0"
                )

            # Both of these may legitimately fail (e.g. table already in the
            # new format / already exists); errors are logged and ignored.
            try:
                yield self.postgres_store.runInteraction(
                    "alter_table", alter_table
                )
            except Exception as e:
                logger.info("Failed to create port table: %s", e)

            try:
                yield self.postgres_store.runInteraction(
                    "create_port_table", create_port_table
                )
            except Exception as e:
                logger.info("Failed to create port table: %s", e)

            self.progress.set_state("Setting up")

            # Set up tables.
            setup_res = yield defer.gatherResults(
                [
                    self.setup_table(table)
                    for table in tables
                    if table not in ["schema_version", "applied_schema_deltas"]
                    and not table.startswith("sqlite_")
                ],
                consumeErrors=True,
            )

            # Process tables.
            yield defer.gatherResults(
                [
                    self.handle_table(*res)
                    for res in setup_res
                ],
                consumeErrors=True,
            )

            self.progress.done()
        except:
            global end_error_exec_info
            end_error_exec_info = sys.exc_info()
            logger.exception("")
        finally:
            reactor.stop()

    def _convert_rows(self, table, headers, rows):
        """Convert rows fetched from SQLite in-place for postgres insertion:
        drops the leading rowid column (j > 0) and converts the table's
        BOOLEAN_COLUMNS from 0/1 ints to real booleans.
        """
        bool_col_names = BOOLEAN_COLUMNS.get(table, [])

        bool_cols = [
            i for i, h in enumerate(headers) if h in bool_col_names
        ]

        def conv(j, col):
            if j in bool_cols:
                return bool(col)
            return col

        for i, row in enumerate(rows):
            rows[i] = tuple(
                conv(j, col)
                for j, col in enumerate(row)
                if j > 0
            )

    @defer.inlineCallbacks
    def _setup_sent_transactions(self):
        # Only save things from the last day
        yesterday = int(time.time() * 1000) - 86400000

        # And save the max transaction id from each destination
        select = (
            "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
            "SELECT max(rowid) FROM sent_transactions"
            " GROUP BY destination"
            ")"
        )

        def r(txn):
            txn.execute(select)
            rows = txn.fetchall()
            headers = [column[0] for column in txn.description]

            ts_ind = headers.index('ts')

            return headers, [r for r in rows if r[ts_ind] < yesterday]

        headers, rows = yield self.sqlite_store.runInteraction(
            "select", r,
        )

        self._convert_rows("sent_transactions", headers, rows)

        inserted_rows = len(rows)
        if inserted_rows:
            max_inserted_rowid = max(r[0] for r in rows)

            def insert(txn):
                self.postgres_store.insert_many_txn(
                    txn, "sent_transactions", headers[1:], rows
                )

            yield self.postgres_store.execute(insert)
        else:
            max_inserted_rowid = 0

        def get_start_id(txn):
            txn.execute(
                "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                " ORDER BY rowid ASC LIMIT 1",
                (yesterday,)
            )

            rows = txn.fetchall()
            if rows:
                return rows[0][0]
            else:
                return 1

        next_chunk = yield self.sqlite_store.execute(get_start_id)
        next_chunk = max(max_inserted_rowid + 1, next_chunk)

        yield self.postgres_store._simple_insert(
            table="port_from_sqlite3",
            values={
                "table_name": "sent_transactions",
                "forward_rowid": next_chunk,
                "backward_rowid": 0,
            }
        )

        def get_sent_table_size(txn):
            txn.execute(
                "SELECT count(*) FROM sent_transactions"
                " WHERE ts >= ?",
                (yesterday,)
            )
            size, = txn.fetchone()
            return int(size)

        remaining_count = yield self.sqlite_store.execute(
            get_sent_table_size
        )

        total_count = remaining_count + inserted_rows

        defer.returnValue((next_chunk, inserted_rows, total_count))

    @defer.inlineCallbacks
    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
        """Count the SQLite rows not yet ported (outside the chunk window)."""
        frows = yield self.sqlite_store.execute_sql(
            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
            forward_chunk,
        )

        brows = yield self.sqlite_store.execute_sql(
            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
            backward_chunk,
        )

        defer.returnValue(frows[0][0] + brows[0][0])

    @defer.inlineCallbacks
    def _get_already_ported_count(self, table):
        """Count the rows already present in the postgres copy of ``table``."""
        rows = yield self.postgres_store.execute_sql(
            "SELECT count(*) FROM %s" % (table,),
        )

        defer.returnValue(rows[0][0])

    @defer.inlineCallbacks
    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
        """Return (done, total) row counts for progress reporting."""
        remaining, done = yield defer.gatherResults(
            [
                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
                self._get_already_ported_count(table),
            ],
            consumeErrors=True,
        )

        remaining = int(remaining) if remaining else 0
        done = int(done) if done else 0

        defer.returnValue((done, remaining + done))
  555. ##############################################
  556. ###### The following is simply UI stuff ######
  557. ##############################################
  558. class Progress(object):
  559. """Used to report progress of the port
  560. """
  561. def __init__(self):
  562. self.tables = {}
  563. self.start_time = int(time.time())
  564. def add_table(self, table, cur, size):
  565. self.tables[table] = {
  566. "start": cur,
  567. "num_done": cur,
  568. "total": size,
  569. "perc": int(cur * 100 / size),
  570. }
  571. def update(self, table, num_done):
  572. data = self.tables[table]
  573. data["num_done"] = num_done
  574. data["perc"] = int(num_done * 100 / data["total"])
  575. def done(self):
  576. pass
class CursesProgress(Progress):
    """Reports progress to a curses window
    """
    def __init__(self, stdscr):
        self.stdscr = stdscr

        curses.use_default_colors()
        curses.curs_set(0)

        # Pair 1 = red (in progress), pair 2 = green (table complete).
        curses.init_pair(1, curses.COLOR_RED, -1)
        curses.init_pair(2, curses.COLOR_GREEN, -1)

        self.last_update = 0

        self.finished = False

        self.total_processed = 0
        self.total_remaining = 0

        super(CursesProgress, self).__init__()

    def update(self, table, num_done):
        super(CursesProgress, self).update(table, num_done)

        # Recompute the global processed/remaining totals from scratch.
        self.total_processed = 0
        self.total_remaining = 0
        for table, data in self.tables.items():
            self.total_processed += data["num_done"] - data["start"]
            self.total_remaining += data["total"] - data["num_done"]

        self.render()

    def render(self, force=False):
        now = time.time()

        # Throttle redraws to at most ~5 per second unless forced.
        if not force and now - self.last_update < 0.2:
            # reactor.callLater(1, self.render)
            return

        self.stdscr.clear()

        rows, cols = self.stdscr.getmaxyx()

        duration = int(now) - int(self.start_time)

        minutes, seconds = divmod(duration, 60)
        duration_str = '%02dm %02ds' % (minutes, seconds,)

        if self.finished:
            status = "Time spent: %s (Done!)" % (duration_str,)
        else:
            if self.total_processed > 0:
                # Linear extrapolation of remaining time from throughput so far.
                left = float(self.total_remaining) / self.total_processed

                est_remaining = (int(now) - self.start_time) * left
                est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
            else:
                est_remaining_str = "Unknown"
            status = (
                "Time spent: %s (est. remaining: %s)"
                % (duration_str, est_remaining_str,)
            )

        self.stdscr.addstr(
            0, 0,
            status,
            curses.A_BOLD,
        )

        max_len = max([len(t) for t in self.tables.keys()])

        left_margin = 5
        middle_space = 1

        # Python 2: dict.items() returns a list, which is sorted in place
        # so the least-complete tables are listed first.
        items = self.tables.items()
        items.sort(
            key=lambda i: (i[1]["perc"], i[0]),
        )

        for i, (table, data) in enumerate(items):
            if i + 2 >= rows:
                break

            perc = data["perc"]

            color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

            self.stdscr.addstr(
                i + 2, left_margin + max_len - len(table),
                table,
                curses.A_BOLD | color,
            )

            size = 20

            progress = "[%s%s]" % (
                "#" * int(perc * size / 100),
                " " * (size - int(perc * size / 100)),
            )

            self.stdscr.addstr(
                i + 2, left_margin + max_len + middle_space,
                "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
            )

        if self.finished:
            self.stdscr.addstr(
                rows - 1, 0,
                "Press any key to exit...",
            )

        self.stdscr.refresh()
        self.last_update = time.time()

    def done(self):
        self.finished = True
        self.render(True)
        # Block until a keypress so the final screen can be read.
        self.stdscr.getch()

    def set_state(self, state):
        self.stdscr.clear()
        self.stdscr.addstr(
            0, 0,
            state + "...",
            curses.A_BOLD,
        )
        self.stdscr.refresh()
class TerminalProgress(Progress):
    """Just prints progress to the terminal
    """
    def update(self, table, num_done):
        super(TerminalProgress, self).update(table, num_done)

        data = self.tables[table]

        # Python 2 print statement (this script predates Python 3 support).
        print "%s: %d%% (%d/%d)" % (
            table, data["perc"],
            data["num_done"], data["total"],
        )

    def set_state(self, state):
        print state + "..."
  684. ##############################################
  685. ##############################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A script to port an existing synapse SQLite database to"
                    " a new PostgreSQL database."
    )
    parser.add_argument("-v", action='store_true')
    parser.add_argument(
        "--sqlite-database", required=True,
        help="The snapshot of the SQLite database file. This must not be"
             " currently used by a running synapse server"
    )
    parser.add_argument(
        "--postgres-config", type=argparse.FileType('r'), required=True,
        help="The database config file for the PostgreSQL database"
    )
    parser.add_argument(
        "--curses", action='store_true',
        help="display a curses based progress UI"
    )

    parser.add_argument(
        "--batch-size", type=int, default=1000,
        help="The number of rows to select from the SQLite table each"
             " iteration [default=1000]",
    )

    args = parser.parse_args()

    logging_config = {
        "level": logging.DEBUG if args.v else logging.INFO,
        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
    }

    if args.curses:
        # Can't log to stdout while curses owns the terminal.
        logging_config["filename"] = "port-synapse.log"

    logging.basicConfig(**logging_config)

    # Module-level on purpose: Porter.run() reads these globals too.
    sqlite_config = {
        "name": "sqlite3",
        "args": {
            "database": args.sqlite_database,
            "cp_min": 1,
            "cp_max": 1,
            "check_same_thread": False,
        },
    }

    postgres_config = yaml.safe_load(args.postgres_config)

    if "database" in postgres_config:
        postgres_config = postgres_config["database"]

    if "name" not in postgres_config:
        sys.stderr.write("Malformed database config: no 'name'")
        sys.exit(2)
    if postgres_config["name"] != "psycopg2":
        sys.stderr.write("Database must use 'psycopg2' connector.")
        sys.exit(3)

    def start(stdscr=None):
        # Pick the UI implementation based on whether we're inside
        # curses.wrapper (stdscr supplied) or on a plain terminal.
        if stdscr:
            progress = CursesProgress(stdscr)
        else:
            progress = TerminalProgress()

        porter = Porter(
            sqlite_config=sqlite_config,
            postgres_config=postgres_config,
            progress=progress,
            batch_size=args.batch_size,
        )

        reactor.callWhenRunning(porter.run)

        reactor.run()

    if args.curses:
        curses.wrapper(start)
    else:
        start()

    # Porter.run() stashes any failure here before stopping the reactor;
    # print it now that the terminal is usable again.
    if end_error_exec_info:
        exc_type, exc_value, exc_traceback = end_error_exec_info
        traceback.print_exception(exc_type, exc_value, exc_traceback)