spacepaste

  1.  
  2. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ==============================================================================
  16. """configure script to get build parameters from user."""
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import argparse
  21. import errno
  22. import os
  23. import platform
  24. import re
  25. import subprocess
  26. import sys
  27. # pylint: disable=g-import-not-at-top
  28. try:
  29. from shutil import which
  30. except ImportError:
  31. from distutils.spawn import find_executable as which
  32. # pylint: enable=g-import-not-at-top
  33. _DEFAULT_CUDA_VERSION = '9.0'
  34. _DEFAULT_CUDNN_VERSION = '7'
  35. _DEFAULT_NCCL_VERSION = '1.3'
  36. _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
  37. _DEFAULT_CUDA_PATH = '/usr/local/cuda'
  38. _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
  39. _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
  40. 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
  41. _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
  42. _TF_OPENCL_VERSION = '1.2'
  43. _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
  44. _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
  45. _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15]
  46. _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
  47. _TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__))
  48. _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
  49. _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
  50. _TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE')
  51. class UserInputError(Exception):
  52. pass
  53. def is_windows():
  54. return platform.system() == 'Windows'
  55. def is_linux():
  56. return platform.system() == 'Linux'
  57. def is_macos():
  58. return platform.system() == 'Darwin'
  59. def is_ppc64le():
  60. return platform.machine() == 'ppc64le'
  61. def is_cygwin():
  62. return platform.system().startswith('CYGWIN_NT')
  63. def get_input(question):
  64. try:
  65. try:
  66. answer = raw_input(question)
  67. except NameError:
  68. answer = input(question) # pylint: disable=bad-builtin
  69. except EOFError:
  70. answer = ''
  71. return answer
  72. def symlink_force(target, link_name):
  73. """Force symlink, equivalent of 'ln -sf'.
  74. Args:
  75. target: items to link to.
  76. link_name: name of the link.
  77. """
  78. try:
  79. os.symlink(target, link_name)
  80. except OSError as e:
  81. if e.errno == errno.EEXIST:
  82. os.remove(link_name)
  83. os.symlink(target, link_name)
  84. else:
  85. raise e
  86. def sed_in_place(filename, old, new):
  87. """Replace old string with new string in file.
  88. Args:
  89. filename: string for filename.
  90. old: string to replace.
  91. new: new string to replace to.
  92. """
  93. with open(filename, 'r') as f:
  94. filedata = f.read()
  95. newdata = filedata.replace(old, new)
  96. with open(filename, 'w') as f:
  97. f.write(newdata)
  98. def write_to_bazelrc(line):
  99. with open(_TF_BAZELRC, 'a') as f:
  100. f.write(line + '\n')
  101. def write_action_env_to_bazelrc(var_name, var):
  102. write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
  103. def run_shell(cmd, allow_non_zero=False):
  104. if allow_non_zero:
  105. try:
  106. output = subprocess.check_output(cmd)
  107. except subprocess.CalledProcessError as e:
  108. output = e.output
  109. else:
  110. output = subprocess.check_output(cmd)
  111. return output.decode('UTF-8').strip()
  112. def cygpath(path):
  113. """Convert path from posix to windows."""
  114. return os.path.abspath(path).replace('\\', '/')
  115. def get_python_path(environ_cp, python_bin_path):
  116. """Get the python site package paths."""
  117. python_paths = []
  118. if environ_cp.get('PYTHONPATH'):
  119. python_paths = environ_cp.get('PYTHONPATH').split(':')
  120. try:
  121. library_paths = run_shell(
  122. [python_bin_path, '-c',
  123. 'import site; print("\\n".join(site.getsitepackages()))']).split('\n')
  124. except subprocess.CalledProcessError:
  125. library_paths = [run_shell(
  126. [python_bin_path, '-c',
  127. 'from distutils.sysconfig import get_python_lib;'
  128. 'print(get_python_lib())'])]
  129. all_paths = set(python_paths + library_paths)
  130. paths = []
  131. for path in all_paths:
  132. if os.path.isdir(path):
  133. paths.append(path)
  134. return paths
  135. def get_python_major_version(python_bin_path):
  136. """Get the python major version."""
  137. return run_shell([python_bin_path, '-c', 'import sys; print(sys.version[0])'])
  138. def setup_python(environ_cp):
  139. """Setup python related env variables."""
  140. # Get PYTHON_BIN_PATH, default is the current running python.
  141. default_python_bin_path = sys.executable
  142. ask_python_bin_path = ('Please specify the location of python. [Default is '
  143. '%s]: ') % default_python_bin_path
  144. while True:
  145. python_bin_path = get_from_env_or_user_or_default(
  146. environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
  147. default_python_bin_path)
  148. # Check if the path is valid
  149. if os.path.isfile(python_bin_path) and os.access(
  150. python_bin_path, os.X_OK):
  151. break
  152. elif not os.path.exists(python_bin_path):
  153. print('Invalid python path: %s cannot be found.' % python_bin_path)
  154. else:
  155. print('%s is not executable. Is it the python binary?' % python_bin_path)
  156. environ_cp['PYTHON_BIN_PATH'] = ''
  157. # Convert python path to Windows style before checking lib and version
  158. if is_windows() or is_cygwin():
  159. python_bin_path = cygpath(python_bin_path)
  160. # Get PYTHON_LIB_PATH
  161. python_lib_path = environ_cp.get('PYTHON_LIB_PATH')
  162. if not python_lib_path:
  163. python_lib_paths = get_python_path(environ_cp, python_bin_path)
  164. if environ_cp.get('USE_DEFAULT_PYTHON_LIB_PATH') == '1':
  165. python_lib_path = python_lib_paths[0]
  166. else:
  167. print('Found possible Python library paths:\n %s' %
  168. '\n '.join(python_lib_paths))
  169. default_python_lib_path = python_lib_paths[0]
  170. python_lib_path = get_input(
  171. 'Please input the desired Python library path to use. '
  172. 'Default is [%s]\n' % python_lib_paths[0])
  173. if not python_lib_path:
  174. python_lib_path = default_python_lib_path
  175. environ_cp['PYTHON_LIB_PATH'] = python_lib_path
  176. python_major_version = get_python_major_version(python_bin_path)
  177. # Convert python path to Windows style before writing into bazel.rc
  178. if is_windows() or is_cygwin():
  179. python_lib_path = cygpath(python_lib_path)
  180. # Set-up env variables used by python_configure.bzl
  181. write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
  182. write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
  183. write_to_bazelrc('build --force_python=py%s' % python_major_version)
  184. write_to_bazelrc('build --host_force_python=py%s' % python_major_version)
  185. write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
  186. environ_cp['PYTHON_BIN_PATH'] = python_bin_path
  187. # Write tools/python_bin_path.sh
  188. with open(os.path.join(
  189. _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f:
  190. f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
  191. def reset_tf_configure_bazelrc(workspace_path):
  192. """Reset file that contains customized config settings."""
  193. open(_TF_BAZELRC, 'w').close()
  194. bazelrc_path = os.path.join(workspace_path, '.bazelrc')
  195. data = []
  196. if os.path.exists(bazelrc_path):
  197. with open(bazelrc_path, 'r') as f:
  198. data = f.read().splitlines()
  199. with open(bazelrc_path, 'w') as f:
  200. for l in data:
  201. if _TF_BAZELRC_FILENAME in l:
  202. continue
  203. f.write('%s\n' % l)
  204. if is_windows():
  205. tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/")
  206. else:
  207. tf_bazelrc_path = _TF_BAZELRC
  208. f.write('import %s\n' % tf_bazelrc_path)
  209. def cleanup_makefile():
  210. """Delete any leftover BUILD files from the Makefile build.
  211. These files could interfere with Bazel parsing.
  212. """
  213. makefile_download_dir = os.path.join(
  214. _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads')
  215. if os.path.isdir(makefile_download_dir):
  216. for root, _, filenames in os.walk(makefile_download_dir):
  217. for f in filenames:
  218. if f.endswith('BUILD'):
  219. os.remove(os.path.join(root, f))
  220. def get_var(environ_cp,
  221. var_name,
  222. query_item,
  223. enabled_by_default,
  224. question=None,
  225. yes_reply=None,
  226. no_reply=None):
  227. """Get boolean input from user.
  228. If var_name is not set in env, ask user to enable query_item or not. If the
  229. response is empty, use the default.
  230. Args:
  231. environ_cp: copy of the os.environ.
  232. var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
  233. query_item: string for feature related to the variable, e.g. "Hadoop File
  234. System".
  235. enabled_by_default: boolean for default behavior.
  236. question: optional string for how to ask for user input.
  237. yes_reply: optional string for reply when feature is enabled.
  238. no_reply: optional string for reply when feature is disabled.
  239. Returns:
  240. boolean value of the variable.
  241. Raises:
  242. UserInputError: if an environment variable is set, but it cannot be
  243. interpreted as a boolean indicator, assume that the user has made a
  244. scripting error, and will continue to provide invalid input.
  245. Raise the error to avoid infinitely looping.
  246. """
  247. if not question:
  248. question = 'Do you wish to build TensorFlow with %s support?' % query_item
  249. if not yes_reply:
  250. yes_reply = '%s support will be enabled for TensorFlow.' % query_item
  251. if not no_reply:
  252. no_reply = 'No %s' % yes_reply
  253. yes_reply += '\n'
  254. no_reply += '\n'
  255. if enabled_by_default:
  256. question += ' [Y/n]: '
  257. else:
  258. question += ' [y/N]: '
  259. var = environ_cp.get(var_name)
  260. if var is not None:
  261. var_content = var.strip().lower()
  262. true_strings = ('1', 't', 'true', 'y', 'yes')
  263. false_strings = ('0', 'f', 'false', 'n', 'no')
  264. if var_content in true_strings:
  265. var = True
  266. elif var_content in false_strings:
  267. var = False
  268. else:
  269. raise UserInputError(
  270. 'Environment variable %s must be set as a boolean indicator.\n'
  271. 'The following are accepted as TRUE : %s.\n'
  272. 'The following are accepted as FALSE: %s.\n'
  273. 'Current value is %s.' % (
  274. var_name, ', '.join(true_strings), ', '.join(false_strings),
  275. var))
  276. while var is None:
  277. user_input_origin = get_input(question)
  278. user_input = user_input_origin.strip().lower()
  279. if user_input == 'y':
  280. print(yes_reply)
  281. var = True
  282. elif user_input == 'n':
  283. print(no_reply)
  284. var = False
  285. elif not user_input:
  286. if enabled_by_default:
  287. print(yes_reply)
  288. var = True
  289. else:
  290. print(no_reply)
  291. var = False
  292. else:
  293. print('Invalid selection: %s' % user_input_origin)
  294. return var
  295. def set_build_var(environ_cp, var_name, query_item, option_name,
  296. enabled_by_default, bazel_config_name=None):
  297. """Set if query_item will be enabled for the build.
  298. Ask user if query_item will be enabled. Default is used if no input is given.
  299. Set subprocess environment variable and write to .bazelrc if enabled.
  300. Args:
  301. environ_cp: copy of the os.environ.
  302. var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
  303. query_item: string for feature related to the variable, e.g. "Hadoop File
  304. System".
  305. option_name: string for option to define in .bazelrc.
  306. enabled_by_default: boolean for default behavior.
  307. bazel_config_name: Name for Bazel --config argument to enable build feature.
  308. """
  309. var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
  310. environ_cp[var_name] = var
  311. if var == '1':
  312. write_to_bazelrc('build --define %s=true' % option_name)
  313. elif bazel_config_name is not None:
  314. # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
  315. # options and not to set build configs through environment variables.
  316. write_to_bazelrc('build:%s --define %s=true'
  317. % (bazel_config_name, option_name))
  318. def set_action_env_var(environ_cp,
  319. var_name,
  320. query_item,
  321. enabled_by_default,
  322. question=None,
  323. yes_reply=None,
  324. no_reply=None):
  325. """Set boolean action_env variable.
  326. Ask user if query_item will be enabled. Default is used if no input is given.
  327. Set environment variable and write to .bazelrc.
  328. Args:
  329. environ_cp: copy of the os.environ.
  330. var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
  331. query_item: string for feature related to the variable, e.g. "Hadoop File
  332. System".
  333. enabled_by_default: boolean for default behavior.
  334. question: optional string for how to ask for user input.
  335. yes_reply: optional string for reply when feature is enabled.
  336. no_reply: optional string for reply when feature is disabled.
  337. """
  338. var = int(
  339. get_var(environ_cp, var_name, query_item, enabled_by_default, question,
  340. yes_reply, no_reply))
  341. write_action_env_to_bazelrc(var_name, var)
  342. environ_cp[var_name] = str(var)
  343. def convert_version_to_int(version):
  344. """Convert a version number to a integer that can be used to compare.
  345. Version strings of the form X.YZ and X.Y.Z-xxxxx are supported. The
  346. 'xxxxx' part, for instance 'homebrew' on OS/X, is ignored.
  347. Args:
  348. version: a version to be converted
  349. Returns:
  350. An integer if converted successfully, otherwise return None.
  351. """
  352. version = version.split('-')[0]
  353. version_segments = version.split('.')
  354. for seg in version_segments:
  355. if not seg.isdigit():
  356. return None
  357. version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
  358. return int(version_str)
  359. def check_bazel_version(min_version):
  360. """Check installed bazel version is at least min_version.
  361. Args:
  362. min_version: string for minimum bazel version.
  363. Returns:
  364. The bazel version detected.
  365. """
  366. if which('bazel') is None:
  367. print('Cannot find bazel. Please install bazel.')
  368. sys.exit(0)
  369. curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
  370. for line in curr_version.split('\n'):
  371. if 'Build label: ' in line:
  372. curr_version = line.split('Build label: ')[1]
  373. break
  374. min_version_int = convert_version_to_int(min_version)
  375. curr_version_int = convert_version_to_int(curr_version)
  376. # Check if current bazel version can be detected properly.
  377. if not curr_version_int:
  378. print('WARNING: current bazel installation is not a release version.')
  379. print('Make sure you are running at least bazel %s' % min_version)
  380. return curr_version
  381. print('You have bazel %s installed.' % curr_version)
  382. if curr_version_int < min_version_int:
  383. print('Please upgrade your bazel installation to version %s or higher to '
  384. 'build TensorFlow!' % min_version)
  385. sys.exit(0)
  386. return curr_version
  387. def set_cc_opt_flags(environ_cp):
  388. """Set up architecture-dependent optimization flags.
  389. Also append CC optimization flags to bazel.rc..
  390. Args:
  391. environ_cp: copy of the os.environ.
  392. """
  393. if is_ppc64le():
  394. # gcc on ppc64le does not support -march, use mcpu instead
  395. default_cc_opt_flags = '-mcpu=native'
  396. elif is_windows():
  397. default_cc_opt_flags = '/arch:AVX'
  398. else:
  399. default_cc_opt_flags = '-march=native'
  400. question = ('Please specify optimization flags to use during compilation when'
  401. ' bazel option "--config=opt" is specified [Default is %s]: '
  402. ) % default_cc_opt_flags
  403. cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
  404. question, default_cc_opt_flags)
  405. for opt in cc_opt_flags.split():
  406. write_to_bazelrc('build:opt --copt=%s' % opt)
  407. # It should be safe on the same build host.
  408. if not is_ppc64le() and not is_windows():
  409. write_to_bazelrc('build:opt --host_copt=-march=native')
  410. write_to_bazelrc('build:opt --define with_default_optimizations=true')
  411. # TODO(mikecase): Remove these default defines once we are able to get
  412. # TF Lite targets building without them.
  413. write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
  414. write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
  415. def set_tf_cuda_clang(environ_cp):
  416. """set TF_CUDA_CLANG action_env.
  417. Args:
  418. environ_cp: copy of the os.environ.
  419. """
  420. question = 'Do you want to use clang as CUDA compiler?'
  421. yes_reply = 'Clang will be used as CUDA compiler.'
  422. no_reply = 'nvcc will be used as CUDA compiler.'
  423. set_action_env_var(
  424. environ_cp,
  425. 'TF_CUDA_CLANG',
  426. None,
  427. False,
  428. question=question,
  429. yes_reply=yes_reply,
  430. no_reply=no_reply)
  431. def set_tf_download_clang(environ_cp):
  432. """Set TF_DOWNLOAD_CLANG action_env."""
  433. question = 'Do you wish to download a fresh release of clang? (Experimental)'
  434. yes_reply = 'Clang will be downloaded and used to compile tensorflow.'
  435. no_reply = 'Clang will not be downloaded.'
  436. set_action_env_var(
  437. environ_cp,
  438. 'TF_DOWNLOAD_CLANG',
  439. None,
  440. False,
  441. question=question,
  442. yes_reply=yes_reply,
  443. no_reply=no_reply)
  444. def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
  445. var_default):
  446. """Get var_name either from env, or user or default.
  447. If var_name has been set as environment variable, use the preset value, else
  448. ask for user input. If no input is provided, the default is used.
  449. Args:
  450. environ_cp: copy of the os.environ.
  451. var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
  452. ask_for_var: string for how to ask for user input.
  453. var_default: default value string.
  454. Returns:
  455. string value for var_name
  456. """
  457. var = environ_cp.get(var_name)
  458. if not var:
  459. var = get_input(ask_for_var)
  460. print('\n')
  461. if not var:
  462. var = var_default
  463. return var
  464. def set_clang_cuda_compiler_path(environ_cp):
  465. """Set CLANG_CUDA_COMPILER_PATH."""
  466. default_clang_path = which('clang') or ''
  467. ask_clang_path = ('Please specify which clang should be used as device and '
  468. 'host compiler. [Default is %s]: ') % default_clang_path
  469. while True:
  470. clang_cuda_compiler_path = get_from_env_or_user_or_default(
  471. environ_cp, 'CLANG_CUDA_COMPILER_PATH', ask_clang_path,
  472. default_clang_path)
  473. if os.path.exists(clang_cuda_compiler_path):
  474. break
  475. # Reset and retry
  476. print('Invalid clang path: %s cannot be found.' % clang_cuda_compiler_path)
  477. environ_cp['CLANG_CUDA_COMPILER_PATH'] = ''
  478. # Set CLANG_CUDA_COMPILER_PATH
  479. environ_cp['CLANG_CUDA_COMPILER_PATH'] = clang_cuda_compiler_path
  480. write_action_env_to_bazelrc('CLANG_CUDA_COMPILER_PATH',
  481. clang_cuda_compiler_path)
  482. def prompt_loop_or_load_from_env(
  483. environ_cp,
  484. var_name,
  485. var_default,
  486. ask_for_var,
  487. check_success,
  488. error_msg,
  489. suppress_default_error=False,
  490. n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS
  491. ):
  492. """Loop over user prompts for an ENV param until receiving a valid response.
  493. For the env param var_name, read from the environment or verify user input
  494. until receiving valid input. When done, set var_name in the environ_cp to its
  495. new value.
  496. Args:
  497. environ_cp: (Dict) copy of the os.environ.
  498. var_name: (String) string for name of environment variable, e.g. "TF_MYVAR".
  499. var_default: (String) default value string.
  500. ask_for_var: (String) string for how to ask for user input.
  501. check_success: (Function) function that takes one argument and returns a
  502. boolean. Should return True if the value provided is considered valid. May
  503. contain a complex error message if error_msg does not provide enough
  504. information. In that case, set suppress_default_error to True.
  505. error_msg: (String) String with one and only one '%s'. Formatted with each
  506. invalid response upon check_success(input) failure.
  507. suppress_default_error: (Bool) Suppress the above error message in favor of
  508. one from the check_success function.
  509. n_ask_attempts: (Integer) Number of times to query for valid input before
  510. raising an error and quitting.
  511. Returns:
  512. [String] The value of var_name after querying for input.
  513. Raises:
  514. UserInputError: if a query has been attempted n_ask_attempts times without
  515. success, assume that the user has made a scripting error, and will
  516. continue to provide invalid input. Raise the error to avoid infinitely
  517. looping.
  518. """
  519. default = environ_cp.get(var_name) or var_default
  520. full_query = '%s [Default is %s]: ' % (
  521. ask_for_var,
  522. default,
  523. )
  524. for _ in range(n_ask_attempts):
  525. val = get_from_env_or_user_or_default(environ_cp,
  526. var_name,
  527. full_query,
  528. default)
  529. if check_success(val):
  530. break
  531. if not suppress_default_error:
  532. print(error_msg % val)
  533. environ_cp[var_name] = ''
  534. else:
  535. raise UserInputError('Invalid %s setting was provided %d times in a row. '
  536. 'Assuming to be a scripting mistake.' %
  537. (var_name, n_ask_attempts))
  538. environ_cp[var_name] = val
  539. return val
  540. def create_android_ndk_rule(environ_cp):
  541. """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
  542. if is_windows() or is_cygwin():
  543. default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
  544. environ_cp['APPDATA'])
  545. elif is_macos():
  546. default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
  547. else:
  548. default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME']
  549. def valid_ndk_path(path):
  550. return (os.path.exists(path) and
  551. os.path.exists(os.path.join(path, 'source.properties')))
  552. android_ndk_home_path = prompt_loop_or_load_from_env(
  553. environ_cp,
  554. var_name='ANDROID_NDK_HOME',
  555. var_default=default_ndk_path,
  556. ask_for_var='Please specify the home path of the Android NDK to use.',
  557. check_success=valid_ndk_path,
  558. error_msg=('The path %s or its child file "source.properties" '
  559. 'does not exist.')
  560. )
  561. write_android_ndk_workspace_rule(android_ndk_home_path)
  562. def create_android_sdk_rule(environ_cp):
  563. """Set Android variables and write Android SDK WORKSPACE rule."""
  564. if is_windows() or is_cygwin():
  565. default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA'])
  566. elif is_macos():
  567. default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
  568. else:
  569. default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME']
  570. def valid_sdk_path(path):
  571. return (os.path.exists(path) and
  572. os.path.exists(os.path.join(path, 'platforms')) and
  573. os.path.exists(os.path.join(path, 'build-tools')))
  574. android_sdk_home_path = prompt_loop_or_load_from_env(
  575. environ_cp,
  576. var_name='ANDROID_SDK_HOME',
  577. var_default=default_sdk_path,
  578. ask_for_var='Please specify the home path of the Android SDK to use.',
  579. check_success=valid_sdk_path,
  580. error_msg=('Either %s does not exist, or it does not contain the '
  581. 'subdirectories "platforms" and "build-tools".'))
  582. platforms = os.path.join(android_sdk_home_path, 'platforms')
  583. api_levels = sorted(os.listdir(platforms))
  584. api_levels = [x.replace('android-', '') for x in api_levels]
  585. def valid_api_level(api_level):
  586. return os.path.exists(os.path.join(android_sdk_home_path,
  587. 'platforms',
  588. 'android-' + api_level))
  589. android_api_level = prompt_loop_or_load_from_env(
  590. environ_cp,
  591. var_name='ANDROID_API_LEVEL',
  592. var_default=api_levels[-1],
  593. ask_for_var=('Please specify the Android SDK API level to use. '
  594. '[Available levels: %s]') % api_levels,
  595. check_success=valid_api_level,
  596. error_msg='Android-%s is not present in the SDK path.')
  597. build_tools = os.path.join(android_sdk_home_path, 'build-tools')
  598. versions = sorted(os.listdir(build_tools))
  599. def valid_build_tools(version):
  600. return os.path.exists(os.path.join(android_sdk_home_path,
  601. 'build-tools',
  602. version))
  603. android_build_tools_version = prompt_loop_or_load_from_env(
  604. environ_cp,
  605. var_name='ANDROID_BUILD_TOOLS_VERSION',
  606. var_default=versions[-1],
  607. ask_for_var=('Please specify an Android build tools version to use. '
  608. '[Available versions: %s]') % versions,
  609. check_success=valid_build_tools,
  610. error_msg=('The selected SDK does not have build-tools version %s '
  611. 'available.'))
  612. write_android_sdk_workspace_rule(android_sdk_home_path,
  613. android_build_tools_version,
  614. android_api_level)
  615. def write_android_sdk_workspace_rule(android_sdk_home_path,
  616. android_build_tools_version,
  617. android_api_level):
  618. print('Writing android_sdk_workspace rule.\n')
  619. with open(_TF_WORKSPACE, 'a') as f:
  620. f.write("""
  621. android_sdk_repository(
  622. name="androidsdk",
  623. api_level=%s,
  624. path="%s",
  625. build_tools_version="%s")\n
  626. """ % (android_api_level, android_sdk_home_path, android_build_tools_version))
  627. def write_android_ndk_workspace_rule(android_ndk_home_path):
  628. print('Writing android_ndk_workspace rule.')
  629. ndk_api_level = check_ndk_level(android_ndk_home_path)
  630. if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
  631. print('WARNING: The API level of the NDK in %s is %s, which is not '
  632. 'supported by Bazel (officially supported versions: %s). Please use '
  633. 'another version. Compiling Android targets may result in confusing '
  634. 'errors.\n' % (android_ndk_home_path, ndk_api_level,
  635. _SUPPORTED_ANDROID_NDK_VERSIONS))
  636. with open(_TF_WORKSPACE, 'a') as f:
  637. f.write("""
  638. android_ndk_repository(
  639. name="androidndk",
  640. path="%s",
  641. api_level=%s)\n
  642. """ % (android_ndk_home_path, ndk_api_level))
  643. def check_ndk_level(android_ndk_home_path):
  644. """Check the revision number of an Android NDK path."""
  645. properties_path = '%s/source.properties' % android_ndk_home_path
  646. if is_windows() or is_cygwin():
  647. properties_path = cygpath(properties_path)
  648. with open(properties_path, 'r') as f:
  649. filedata = f.read()
  650. revision = re.search(r'Pkg.Revision = (\d+)', filedata)
  651. if revision:
  652. return revision.group(1)
  653. return None
  654. def workspace_has_any_android_rule():
  655. """Check the WORKSPACE for existing android_*_repository rules."""
  656. with open(_TF_WORKSPACE, 'r') as f:
  657. workspace = f.read()
  658. has_any_rule = re.search(r'^android_[ns]dk_repository',
  659. workspace,
  660. re.MULTILINE)
  661. return has_any_rule
  662. def set_gcc_host_compiler_path(environ_cp):
  663. """Set GCC_HOST_COMPILER_PATH."""
  664. default_gcc_host_compiler_path = which('gcc') or ''
  665. cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH')
  666. if os.path.islink(cuda_bin_symlink):
  667. # os.readlink is only available in linux
  668. default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink)
  669. gcc_host_compiler_path = prompt_loop_or_load_from_env(
  670. environ_cp,
  671. var_name='GCC_HOST_COMPILER_PATH',
  672. var_default=default_gcc_host_compiler_path,
  673. ask_for_var=
  674. 'Please specify which gcc should be used by nvcc as the host compiler.',
  675. check_success=os.path.exists,
  676. error_msg='Invalid gcc path. %s cannot be found.',
  677. )
  678. write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path)
  679. def reformat_version_sequence(version_str, sequence_count):
  680. """Reformat the version string to have the given number of sequences.
  681. For example:
  682. Given (7, 2) -> 7.0
  683. (7.0.1, 2) -> 7.0
  684. (5, 1) -> 5
  685. (5.0.3.2, 1) -> 5
  686. Args:
  687. version_str: String, the version string.
  688. sequence_count: int, an integer.
  689. Returns:
  690. string, reformatted version string.
  691. """
  692. v = version_str.split('.')
  693. if len(v) < sequence_count:
  694. v = v + (['0'] * (sequence_count - len(v)))
  695. return '.'.join(v[:sequence_count])
  696. def set_tf_cuda_version(environ_cp):
  697. """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
  698. ask_cuda_version = (
  699. 'Please specify the CUDA SDK version you want to use, '
  700. 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
  701. for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
  702. # Configure the Cuda SDK version to use.
  703. tf_cuda_version = get_from_env_or_user_or_default(
  704. environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
  705. tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
  706. # Find out where the CUDA toolkit is installed
  707. default_cuda_path = _DEFAULT_CUDA_PATH
  708. if is_windows() or is_cygwin():
  709. default_cuda_path = cygpath(
  710. environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
  711. elif is_linux():
  712. # If the default doesn't exist, try an alternative default.
  713. if (not os.path.exists(default_cuda_path)
  714. ) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX):
  715. default_cuda_path = _DEFAULT_CUDA_PATH_LINUX
  716. ask_cuda_path = ('Please specify the location where CUDA %s toolkit is'
  717. ' installed. Refer to README.md for more details. '
  718. '[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
  719. cuda_toolkit_path = get_from_env_or_user_or_default(
  720. environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
  721. if is_windows():
  722. cuda_rt_lib_path = 'lib/x64/cudart.lib'
  723. elif is_linux():
  724. cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version
  725. elif is_macos():
  726. cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version
  727. cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path)
  728. if os.path.exists(cuda_toolkit_path_full):
  729. break
  730. # Reset and retry
  731. print('Invalid path to CUDA %s toolkit. %s cannot be found' %
  732. (tf_cuda_version, cuda_toolkit_path_full))
  733. environ_cp['TF_CUDA_VERSION'] = ''
  734. environ_cp['CUDA_TOOLKIT_PATH'] = ''
  735. else:
  736. raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d '
  737. 'times in a row. Assuming to be a scripting mistake.' %
  738. _DEFAULT_PROMPT_ASK_ATTEMPTS)
  739. # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION
  740. environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path
  741. write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path)
  742. environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
  743. write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
  744. def set_tf_cudnn_version(environ_cp):
  745. """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
  746. ask_cudnn_version = (
  747. 'Please specify the cuDNN version you want to use. '
  748. '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION
  749. for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
  750. tf_cudnn_version = get_from_env_or_user_or_default(
  751. environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
  752. _DEFAULT_CUDNN_VERSION)
  753. tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1)
  754. default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
  755. ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
  756. 'installed. Refer to README.md for more details. [Default'
  757. ' is %s]:') % (tf_cudnn_version, default_cudnn_path)
  758. cudnn_install_path = get_from_env_or_user_or_default(
  759. environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
  760. # Result returned from "read" will be used unexpanded. That make "~"
  761. # unusable. Going through one more level of expansion to handle that.
  762. cudnn_install_path = os.path.realpath(
  763. os.path.expanduser(cudnn_install_path))
  764. if is_windows() or is_cygwin():
  765. cudnn_install_path = cygpath(cudnn_install_path)
  766. if is_windows():
  767. cuda_dnn_lib_path = 'lib/x64/cudnn.lib'
  768. cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib'
  769. elif is_linux():
  770. cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version
  771. cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version
  772. elif is_macos():
  773. cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version
  774. cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version
  775. cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path)
  776. cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path,
  777. cuda_dnn_lib_alt_path)
  778. if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists(
  779. cuda_dnn_lib_alt_path_full):
  780. break
  781. # Try another alternative for Linux
  782. if is_linux():
  783. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
  784. cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
  785. cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
  786. cudnn_path_from_ldconfig)
  787. if cudnn_path_from_ldconfig:
  788. cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
  789. if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig,
  790. tf_cudnn_version)):
  791. cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
  792. break
  793. # Reset and Retry
  794. print(
  795. 'Invalid path to cuDNN %s toolkit. None of the following files can be '
  796. 'found:' % tf_cudnn_version)
  797. print(cuda_dnn_lib_path_full)
  798. print(cuda_dnn_lib_alt_path_full)
  799. if is_linux():
  800. print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version))
  801. environ_cp['TF_CUDNN_VERSION'] = ''
  802. else:
  803. raise UserInputError('Invalid TF_CUDNN setting was provided %d '
  804. 'times in a row. Assuming to be a scripting mistake.' %
  805. _DEFAULT_PROMPT_ASK_ATTEMPTS)
  806. # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION
  807. environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path
  808. write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path)
  809. environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
  810. write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
  811. def set_tf_tensorrt_install_path(environ_cp):
  812. """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
  813. Adapted from code contributed by Sami Kama (https://github.com/samikama).
  814. Args:
  815. environ_cp: copy of the os.environ.
  816. Raises:
  817. ValueError: if this method was called under non-Linux platform.
  818. UserInputError: if user has provided invalid input multiple times.
  819. """
  820. if not is_linux():
  821. raise ValueError('Currently TensorRT is only supported on Linux platform.')
  822. # Ask user whether to add TensorRT support.
  823. if str(int(get_var(
  824. environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
  825. return
  826. for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
  827. ask_tensorrt_path = (r'Please specify the location where TensorRT is '
  828. 'installed. [Default is %s]:') % (
  829. _DEFAULT_TENSORRT_PATH_LINUX)
  830. trt_install_path = get_from_env_or_user_or_default(
  831. environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
  832. _DEFAULT_TENSORRT_PATH_LINUX)
  833. # Result returned from "read" will be used unexpanded. That make "~"
  834. # unusable. Going through one more level of expansion to handle that.
  835. trt_install_path = os.path.realpath(
  836. os.path.expanduser(trt_install_path))
  837. def find_libs(search_path):
  838. """Search for libnvinfer.so in "search_path"."""
  839. fl = set()
  840. if os.path.exists(search_path) and os.path.isdir(search_path):
  841. fl.update([os.path.realpath(os.path.join(search_path, x))
  842. for x in os.listdir(search_path) if 'libnvinfer.so' in x])
  843. return fl
  844. possible_files = find_libs(trt_install_path)
  845. possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
  846. possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
  847. def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
  848. """Check the compatibility between tensorrt and cudnn/cudart libraries."""
  849. ldd_bin = which('ldd') or '/usr/bin/ldd'
  850. ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
  851. cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
  852. cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
  853. cudnn = None
  854. cudart = None
  855. for line in ldd_out:
  856. if 'libcudnn.so' in line:
  857. cudnn = cudnn_pattern.search(line)
  858. elif 'libcudart.so' in line:
  859. cudart = cuda_pattern.search(line)
  860. if cudnn and len(cudnn.group(1)):
  861. cudnn = convert_version_to_int(cudnn.group(1))
  862. if cudart and len(cudart.group(1)):
  863. cudart = convert_version_to_int(cudart.group(1))
  864. return (cudnn == cudnn_ver) and (cudart == cuda_ver)
  865. cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
  866. cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
  867. nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
  868. highest_ver = [0, None, None]
  869. for lib_file in possible_files:
  870. if is_compatible(lib_file, cuda_ver, cudnn_ver):
  871. matches = nvinfer_pattern.search(lib_file)
  872. if len(matches.groups()) == 0:
  873. continue
  874. ver_str = matches.group(1)
  875. ver = convert_version_to_int(ver_str) if len(ver_str) else 0
  876. if ver > highest_ver[0]:
  877. highest_ver = [ver, ver_str, lib_file]
  878. if highest_ver[1] is not None:
  879. trt_install_path = os.path.dirname(highest_ver[2])
  880. tf_tensorrt_version = highest_ver[1]
  881. break
  882. # Try another alternative from ldconfig.
  883. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
  884. ldconfig_output = run_shell([ldconfig_bin, '-p'])
  885. search_result = re.search(
  886. '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
  887. if search_result:
  888. libnvinfer_path_from_ldconfig = search_result.group(2)
  889. if os.path.exists(libnvinfer_path_from_ldconfig):
  890. if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
  891. trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
  892. tf_tensorrt_version = search_result.group(1)
  893. break
  894. # Reset and Retry
  895. if possible_files:
  896. print('TensorRT libraries found in one the following directories',
  897. 'are not compatible with selected cuda and cudnn installations')
  898. print(trt_install_path)
  899. print(os.path.join(trt_install_path, 'lib'))
  900. print(os.path.join(trt_install_path, 'lib64'))
  901. if search_result:
  902. print(libnvinfer_path_from_ldconfig)
  903. else:
  904. print(
  905. 'Invalid path to TensorRT. None of the following files can be found:')
  906. print(trt_install_path)
  907. print(os.path.join(trt_install_path, 'lib'))
  908. print(os.path.join(trt_install_path, 'lib64'))
  909. if search_result:
  910. print(libnvinfer_path_from_ldconfig)
  911. else:
  912. raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
  913. 'times in a row. Assuming to be a scripting mistake.' %
  914. _DEFAULT_PROMPT_ASK_ATTEMPTS)
  915. # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
  916. environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
  917. write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
  918. environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
  919. write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
  920. def set_tf_nccl_install_path(environ_cp):
  921. """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION.
  922. Args:
  923. environ_cp: copy of the os.environ.
  924. Raises:
  925. ValueError: if this method was called under non-Linux platform.
  926. UserInputError: if user has provided invalid input multiple times.
  927. """
  928. if not is_linux():
  929. raise ValueError('Currently NCCL is only supported on Linux platforms.')
  930. ask_nccl_version = (
  931. 'Please specify the NCCL version you want to use. '
  932. '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION
  933. for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
  934. tf_nccl_version = get_from_env_or_user_or_default(
  935. environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION)
  936. tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
  937. if tf_nccl_version == '1':
  938. break # No need to get install path, NCCL 1 is a GitHub repo.
  939. # TODO(csigg): Look with ldconfig first if we can find the library in paths
  940. # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
  941. # include directory. This is where the NCCL .deb packages install them.
  942. # Then ask the user if we should use that. Instead of a single
  943. # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to
  944. # nccl_configure.bzl
  945. default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
  946. ask_nccl_path = (r'Please specify the location where NCCL %s library is '
  947. 'installed. Refer to README.md for more details. [Default '
  948. 'is %s]:') % (tf_nccl_version, default_nccl_path)
  949. nccl_install_path = get_from_env_or_user_or_default(
  950. environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
  951. # Result returned from "read" will be used unexpanded. That make "~"
  952. # unusable. Going through one more level of expansion to handle that.
  953. nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path))
  954. if is_windows() or is_cygwin():
  955. nccl_install_path = cygpath(nccl_install_path)
  956. if is_windows():
  957. nccl_lib_path = 'lib/x64/nccl.lib'
  958. elif is_linux():
  959. nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version
  960. elif is_macos():
  961. nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
  962. nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
  963. nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
  964. if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
  965. # Set NCCL_INSTALL_PATH
  966. environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
  967. write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
  968. break
  969. # Reset and Retry
  970. print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
  971. 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
  972. nccl_hdr_path))
  973. environ_cp['TF_NCCL_VERSION'] = ''
  974. else:
  975. raise UserInputError('Invalid TF_NCCL setting was provided %d '
  976. 'times in a row. Assuming to be a scripting mistake.' %
  977. _DEFAULT_PROMPT_ASK_ATTEMPTS)
  978. # Set TF_NCCL_VERSION
  979. environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
  980. write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
  981. def get_native_cuda_compute_capabilities(environ_cp):
  982. """Get native cuda compute capabilities.
  983. Args:
  984. environ_cp: copy of the os.environ.
  985. Returns:
  986. string of native cuda compute capabilities, separated by comma.
  987. """
  988. device_query_bin = os.path.join(
  989. environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery')
  990. if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK):
  991. try:
  992. output = run_shell(device_query_bin).split('\n')
  993. pattern = re.compile('[0-9]*\\.[0-9]*')
  994. output = [pattern.search(x) for x in output if 'Capability' in x]
  995. output = ','.join(x.group() for x in output if x is not None)
  996. except subprocess.CalledProcessError:
  997. output = ''
  998. else:
  999. output = ''
  1000. return output
  1001. def set_tf_cuda_compute_capabilities(environ_cp):
  1002. """Set TF_CUDA_COMPUTE_CAPABILITIES."""
  1003. while True:
  1004. native_cuda_compute_capabilities = get_native_cuda_compute_capabilities(
  1005. environ_cp)
  1006. if not native_cuda_compute_capabilities:
  1007. default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES
  1008. else:
  1009. default_cuda_compute_capabilities = native_cuda_compute_capabilities
  1010. ask_cuda_compute_capabilities = (
  1011. 'Please specify a list of comma-separated '
  1012. 'Cuda compute capabilities you want to '
  1013. 'build with.\nYou can find the compute '
  1014. 'capability of your device at: '
  1015. 'https://developer.nvidia.com/cuda-gpus.\nPlease'
  1016. ' note that each additional compute '
  1017. 'capability significantly increases your '
  1018. 'build time and binary size. [Default is: %s]' %
  1019. default_cuda_compute_capabilities)
  1020. tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
  1021. environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
  1022. ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
  1023. # Check whether all capabilities from the input is valid
  1024. all_valid = True
  1025. for compute_capability in tf_cuda_compute_capabilities.split(','):
  1026. m = re.match('[0-9]+.[0-9]+', compute_capability)
  1027. if not m:
  1028. print('Invalid compute capability: ' % compute_capability)
  1029. all_valid = False
  1030. else:
  1031. ver = int(m.group(0).split('.')[0])
  1032. if ver < 3:
  1033. print('Only compute capabilities 3.0 or higher are supported.')
  1034. all_valid = False
  1035. if all_valid:
  1036. break
  1037. # Reset and Retry
  1038. environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = ''
  1039. # Set TF_CUDA_COMPUTE_CAPABILITIES
  1040. environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities
  1041. write_action_env_to_bazelrc('TF_CUDA_COMPUTE_CAPABILITIES',
  1042. tf_cuda_compute_capabilities)
  1043. def set_other_cuda_vars(environ_cp):
  1044. """Set other CUDA related variables."""
  1045. if is_windows():
  1046. # The following three variables are needed for MSVC toolchain configuration
  1047. # in Bazel
  1048. environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH')
  1049. environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get(
  1050. 'TF_CUDA_COMPUTE_CAPABILITIES')
  1051. environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1
  1052. write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH'))
  1053. write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE',
  1054. environ_cp.get('CUDA_COMPUTE_CAPABILITIE'))
  1055. write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION',
  1056. environ_cp.get('NO_WHOLE_ARCHIVE_OPTION'))
  1057. write_to_bazelrc('build --config=win-cuda')
  1058. write_to_bazelrc('test --config=win-cuda')
  1059. else:
  1060. # If CUDA is enabled, always use GPU during build and test.
  1061. if environ_cp.get('TF_CUDA_CLANG') == '1':
  1062. write_to_bazelrc('build --config=cuda_clang')
  1063. write_to_bazelrc('test --config=cuda_clang')
  1064. else:
  1065. write_to_bazelrc('build --config=cuda')
  1066. write_to_bazelrc('test --config=cuda')
  1067. def set_host_cxx_compiler(environ_cp):
  1068. """Set HOST_CXX_COMPILER."""
  1069. default_cxx_host_compiler = which('g++') or ''
  1070. host_cxx_compiler = prompt_loop_or_load_from_env(
  1071. environ_cp,
  1072. var_name='HOST_CXX_COMPILER',
  1073. var_default=default_cxx_host_compiler,
  1074. ask_for_var=('Please specify which C++ compiler should be used as the '
  1075. 'host C++ compiler.'),
  1076. check_success=os.path.exists,
  1077. error_msg='Invalid C++ compiler path. %s cannot be found.',
  1078. )
  1079. write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler)
  1080. def set_host_c_compiler(environ_cp):
  1081. """Set HOST_C_COMPILER."""
  1082. default_c_host_compiler = which('gcc') or ''
  1083. host_c_compiler = prompt_loop_or_load_from_env(
  1084. environ_cp,
  1085. var_name='HOST_C_COMPILER',
  1086. var_default=default_c_host_compiler,
  1087. ask_for_var=('Please specify which C compiler should be used as the host '
  1088. 'C compiler.'),
  1089. check_success=os.path.exists,
  1090. error_msg='Invalid C compiler path. %s cannot be found.',
  1091. )
  1092. write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
  1093. def set_computecpp_toolkit_path(environ_cp):
  1094. """Set COMPUTECPP_TOOLKIT_PATH."""
  1095. def toolkit_exists(toolkit_path):
  1096. """Check if a computecpp toolkit path is valid."""
  1097. if is_linux():
  1098. sycl_rt_lib_path = 'lib/libComputeCpp.so'
  1099. else:
  1100. sycl_rt_lib_path = ''
  1101. sycl_rt_lib_path_full = os.path.join(toolkit_path,
  1102. sycl_rt_lib_path)
  1103. exists = os.path.exists(sycl_rt_lib_path_full)
  1104. if not exists:
  1105. print('Invalid SYCL %s library path. %s cannot be found' %
  1106. (_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
  1107. return exists
  1108. computecpp_toolkit_path = prompt_loop_or_load_from_env(
  1109. environ_cp,
  1110. var_name='COMPUTECPP_TOOLKIT_PATH',
  1111. var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
  1112. ask_for_var=(
  1113. 'Please specify the location where ComputeCpp for SYCL %s is '
  1114. 'installed.' % _TF_OPENCL_VERSION),
  1115. check_success=toolkit_exists,
  1116. error_msg='Invalid SYCL compiler path. %s cannot be found.',
  1117. suppress_default_error=True)
  1118. write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
  1119. computecpp_toolkit_path)
  1120. def set_trisycl_include_dir(environ_cp):
  1121. """Set TRISYCL_INCLUDE_DIR."""
  1122. ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
  1123. 'include directory. (Use --config=sycl_trisycl '
  1124. 'when building with Bazel) '
  1125. '[Default is %s]: '
  1126. ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
  1127. while True:
  1128. trisycl_include_dir = get_from_env_or_user_or_default(
  1129. environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
  1130. _DEFAULT_TRISYCL_INCLUDE_DIR)
  1131. if os.path.exists(trisycl_include_dir):
  1132. break
  1133. print('Invalid triSYCL include directory, %s cannot be found'
  1134. % (trisycl_include_dir))
  1135. # Set TRISYCL_INCLUDE_DIR
  1136. environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
  1137. write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
  1138. trisycl_include_dir)
  1139. def set_mpi_home(environ_cp):
  1140. """Set MPI_HOME."""
  1141. default_mpi_home = which('mpirun') or which('mpiexec') or ''
  1142. default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
  1143. def valid_mpi_path(mpi_home):
  1144. exists = (os.path.exists(os.path.join(mpi_home, 'include')) and
  1145. os.path.exists(os.path.join(mpi_home, 'lib')))
  1146. if not exists:
  1147. print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
  1148. (os.path.join(mpi_home, 'include'),
  1149. os.path.exists(os.path.join(mpi_home, 'lib'))))
  1150. return exists
  1151. _ = prompt_loop_or_load_from_env(
  1152. environ_cp,
  1153. var_name='MPI_HOME',
  1154. var_default=default_mpi_home,
  1155. ask_for_var='Please specify the MPI toolkit folder.',
  1156. check_success=valid_mpi_path,
  1157. error_msg='',
  1158. suppress_default_error=True)
  1159. def set_other_mpi_vars(environ_cp):
  1160. """Set other MPI related variables."""
  1161. # Link the MPI header files
  1162. mpi_home = environ_cp.get('MPI_HOME')
  1163. symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h')
  1164. # Determine if we use OpenMPI or MVAPICH, these require different header files
  1165. # to be included here to make bazel dependency checker happy
  1166. if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')):
  1167. symlink_force(
  1168. os.path.join(mpi_home, 'include/mpi_portable_platform.h'),
  1169. 'third_party/mpi/mpi_portable_platform.h')
  1170. # TODO(gunan): avoid editing files in configure
  1171. sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=False',
  1172. 'MPI_LIB_IS_OPENMPI=True')
  1173. else:
  1174. # MVAPICH / MPICH
  1175. symlink_force(
  1176. os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h')
  1177. symlink_force(
  1178. os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h')
  1179. # TODO(gunan): avoid editing files in configure
  1180. sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI=True',
  1181. 'MPI_LIB_IS_OPENMPI=False')
  1182. if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
  1183. symlink_force(
  1184. os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
  1185. else:
  1186. raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
  1187. def set_grpc_build_flags():
  1188. write_to_bazelrc('build --define grpc_no_ares=true')
  1189. def set_windows_build_flags():
  1190. if is_windows():
  1191. # The non-monolithic build is not supported yet
  1192. write_to_bazelrc('build --config monolithic')
  1193. # Suppress warning messages
  1194. write_to_bazelrc('build --copt=-w --host_copt=-w')
  1195. # Output more verbose information when something goes wrong
  1196. write_to_bazelrc('build --verbose_failures')
  1197. def config_info_line(name, help_text):
  1198. """Helper function to print formatted help text for Bazel config options."""
  1199. print('\t--config=%-12s\t# %s' % (name, help_text))
  1200. def main():
  1201. parser = argparse.ArgumentParser()
  1202. parser.add_argument("--workspace",
  1203. type=str,
  1204. default=_TF_WORKSPACE_ROOT,
  1205. help="The absolute path to your active Bazel workspace.")
  1206. args = parser.parse_args()
  1207. # Make a copy of os.environ to be clear when functions and getting and setting
  1208. # environment variables.
  1209. environ_cp = dict(os.environ)
  1210. check_bazel_version('0.10.0')
  1211. reset_tf_configure_bazelrc(args.workspace)
  1212. cleanup_makefile()
  1213. setup_python(environ_cp)
  1214. if is_windows():
  1215. environ_cp['TF_NEED_S3'] = '0'
  1216. environ_cp['TF_NEED_GCP'] = '0'
  1217. environ_cp['TF_NEED_HDFS'] = '0'
  1218. environ_cp['TF_NEED_JEMALLOC'] = '0'
  1219. environ_cp['TF_NEED_KAFKA'] = '0'
  1220. environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
  1221. environ_cp['TF_NEED_COMPUTECPP'] = '0'
  1222. environ_cp['TF_NEED_OPENCL'] = '0'
  1223. environ_cp['TF_CUDA_CLANG'] = '0'
  1224. environ_cp['TF_NEED_TENSORRT'] = '0'
  1225. # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on
  1226. # Windows.
  1227. environ_cp['TF_DOWNLOAD_CLANG'] = '0'
  1228. if is_macos():
  1229. environ_cp['TF_NEED_JEMALLOC'] = '0'
  1230. environ_cp['TF_NEED_TENSORRT'] = '0'
  1231. set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
  1232. 'with_jemalloc', True)
  1233. set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
  1234. 'with_gcp_support', True, 'gcp')
  1235. set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
  1236. 'with_hdfs_support', True, 'hdfs')
  1237. set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
  1238. 'with_s3_support', True, 's3')
  1239. set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
  1240. 'with_kafka_support', True, 'kafka')
  1241. set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
  1242. False, 'xla')
  1243. set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
  1244. False, 'gdr')
  1245. set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
  1246. False, 'verbs')
  1247. set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
  1248. if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
  1249. set_host_cxx_compiler(environ_cp)
  1250. set_host_c_compiler(environ_cp)
  1251. set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True)
  1252. if environ_cp.get('TF_NEED_COMPUTECPP') == '1':
  1253. set_computecpp_toolkit_path(environ_cp)
  1254. else:
  1255. set_trisycl_include_dir(environ_cp)
  1256. set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
  1257. if (environ_cp.get('TF_NEED_CUDA') == '1' and
  1258. 'TF_CUDA_CONFIG_REPO' not in environ_cp):
  1259. set_tf_cuda_version(environ_cp)
  1260. set_tf_cudnn_version(environ_cp)
  1261. if is_linux():
  1262. set_tf_tensorrt_install_path(environ_cp)
  1263. set_tf_nccl_install_path(environ_cp)
  1264. set_tf_cuda_compute_capabilities(environ_cp)
  1265. if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
  1266. 'LD_LIBRARY_PATH') != '1':
  1267. write_action_env_to_bazelrc('LD_LIBRARY_PATH',
  1268. environ_cp.get('LD_LIBRARY_PATH'))
  1269. set_tf_cuda_clang(environ_cp)
  1270. if environ_cp.get('TF_CUDA_CLANG') == '1':
  1271. # Ask whether we should download the clang toolchain.
  1272. set_tf_download_clang(environ_cp)
  1273. if environ_cp.get('TF_DOWNLOAD_CLANG') != '1':
  1274. # Set up which clang we should use as the cuda / host compiler.
  1275. set_clang_cuda_compiler_path(environ_cp)
  1276. else:
  1277. # Set up which gcc nvcc should use as the host compiler
  1278. # No need to set this on Windows
  1279. if not is_windows():
  1280. set_gcc_host_compiler_path(environ_cp)
  1281. set_other_cuda_vars(environ_cp)
  1282. else:
  1283. # CUDA not required. Ask whether we should download the clang toolchain and
  1284. # use it for the CPU build.
  1285. set_tf_download_clang(environ_cp)
  1286. if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
  1287. write_to_bazelrc('build --config=download_clang')
  1288. write_to_bazelrc('test --config=download_clang')
  1289. set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
  1290. if environ_cp.get('TF_NEED_MPI') == '1':
  1291. set_mpi_home(environ_cp)
  1292. set_other_mpi_vars(environ_cp)
  1293. set_grpc_build_flags()
  1294. set_cc_opt_flags(environ_cp)
  1295. set_windows_build_flags()
  1296. if workspace_has_any_android_rule():
  1297. print('The WORKSPACE file has at least one of ["android_sdk_repository", '
  1298. '"android_ndk_repository"] already set. Will not ask to help '
  1299. 'configure the WORKSPACE. Please delete the existing rules to '
  1300. 'activate the helper.\n')
  1301. else:
  1302. if get_var(
  1303. environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
  1304. False,
  1305. ('Would you like to interactively configure ./WORKSPACE for '
  1306. 'Android builds?'),
  1307. 'Searching for NDK and SDK installations.',
  1308. 'Not configuring the WORKSPACE for Android builds.'):
  1309. create_android_ndk_rule(environ_cp)
  1310. create_android_sdk_rule(environ_cp)
  1311. print('Preconfigured Bazel build configs. You can use any of the below by '
  1312. 'adding "--config=<>" to your build command. See tools/bazel.rc for '
  1313. 'more details.')
  1314. config_info_line('mkl', 'Build with MKL support.')
  1315. config_info_line('monolithic', 'Config for mostly static monolithic build.')
  1316. if __name__ == '__main__':
  1317. main()
  1318.