diff --git a/src/google/adk/code_executors/unsafe_local_code_executor.py b/src/google/adk/code_executors/unsafe_local_code_executor.py index e1e8004..f7b592d 100644 --- a/src/google/adk/code_executors/unsafe_local_code_executor.py +++ b/src/google/adk/code_executors/unsafe_local_code_executor.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from contextlib import redirect_stdout import io +import re +from typing import Any from pydantic import Field from typing_extensions import override @@ -24,6 +28,12 @@ from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult +def _prepare_globals(code: str, globals_: dict[str, Any]) -> None: + """Prepare globals for code execution, injecting __name__ if needed.""" + if re.search(r"if\s+__name__\s*==\s*['\"]__main__['\"]", code): + globals_['__name__'] = '__main__' + + class UnsafeLocalCodeExecutor(BaseCodeExecutor): """A code executor that unsafely execute code in the current local context.""" @@ -55,6 +65,7 @@ class UnsafeLocalCodeExecutor(BaseCodeExecutor): error = '' try: globals_ = {} + _prepare_globals(code_execution_input.code, globals_) locals_ = {} stdout = io.StringIO() with redirect_stdout(stdout):