1818from typing import Callable , TypeAlias
1919
2020import jaxlib
21+ from catalyst import QJIT
2122from jaxlib .mlir .dialects import stablehlo as jstablehlo # pylint: disable=no-name-in-module
2223from jaxlib .mlir .ir import Context as jContext # pylint: disable=no-name-in-module
2324from jaxlib .mlir .ir import Module as jModule # pylint: disable=no-name-in-module
3031from xdsl .dialects import tensor as xtensor
3132from xdsl .dialects import transform as xtransform
3233from xdsl .parser import Parser as xParser
34+ from xdsl .traits import SymbolTable as xSymbolTable
3335
3436from pennylane .compiler .python_compiler .quantum_dialect import QuantumDialect
3537
@@ -42,9 +44,7 @@ def _module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule:
4244
4345
4446def module (func : JaxJittedFunction ) -> Callable [..., jModule ]:
45- """
46- Decorator for _module_inline
47- """
47+ """Decorator for _module_inline"""
4848
4949 @wraps (func )
5050 def wrapper (* args , ** kwargs ) -> jModule :
@@ -54,18 +54,14 @@ def wrapper(*args, **kwargs) -> jModule:
5454
5555
5656def _generic_inline (func : JaxJittedFunction , * args , ** kwargs ) -> str : # pragma: no cover
57- """
58- Create the generic textual representation for the jax.jit'ed function
59- """
57+ """Create the generic textual representation for the jax.jit'ed function"""
6058 lowered = func .lower (* args , ** kwargs )
6159 mod = lowered .compiler_ir ()
6260 return mod .operation .get_asm (binary = False , print_generic_op_form = True , assume_verified = True )
6361
6462
6563def generic (func : JaxJittedFunction ) -> Callable [..., str ]: # pragma: no cover
66- """
67- Decorator for _generic_inline.
68- """
64+ """Decorator for _generic_inline."""
6965
7066 @wraps (func )
7167 def wrapper (* args , ** kwargs ) -> str :
@@ -125,12 +121,47 @@ def wrapper(*_, **__):
125121
126122
127123def xdsl_module (func : JaxJittedFunction ) -> Callable [..., xbuiltin .ModuleOp ]: # pragma: no cover
128- """
129- Decorator for _xdsl_module_inline
130- """
124+ """Decorator for _xdsl_module_inline"""
131125
132126 @wraps (func )
133127 def wrapper (* args , ** kwargs ) -> xbuiltin .ModuleOp :
134128 return _xdsl_module_inline (func , * args , ** kwargs )
135129
136130 return wrapper
131+
132+
133+ def inline_module (
134+ from_mod : xbuiltin .ModuleOp , to_mod : xbuiltin .ModuleOp , change_main_to : str = None
135+ ) -> None :
136+ """Inline the contents of one xDSL module into another xDSL module. The inlined body is appended
137+ to the end of ``to_mod``."""
138+ if change_main_to :
139+ main = xSymbolTable .lookup_symbol (from_mod , "main" )
140+ if main is not None :
141+ assert isinstance (main , xfunc .FuncOp )
142+ main .properties ["sym_name" ] = xbuiltin .StringAttr (change_main_to )
143+
144+ for op in from_mod .body .ops :
145+ xSymbolTable .insert_or_update (to_mod , op .clone ())
146+
147+
148+ def inline_jit_to_module (func : JaxJittedFunction , mod : xbuiltin .ModuleOp , * args , ** kwargs ) -> None :
149+ """Inline a ``jax.jit``-ed Python function to an xDSL module. The inlined body is appended
150+ to the end of ``mod``."""
151+ func_mod = _xdsl_module_inline (func , * args , ** kwargs )
152+ inline_module (func_mod , mod , change_main_to = func .__name__ )
153+
154+
155+ def xdsl_from_qjit (func : QJIT ) -> Callable [..., xbuiltin .ModuleOp ]:
156+ """Decorator to convert QJIT-ed functions into xDSL modules."""
157+
158+ @wraps (func )
159+ def wrapper (* args , ** kwargs ):
160+ func .jaxpr , * _ = func .capture (args , ** kwargs )
161+ mlir_module = func .generate_ir ()
162+ generic_str = mlir_module .operation .get_asm (
163+ binary = False , print_generic_op_form = True , assume_verified = True
164+ )
165+ return parse_generic_to_xdsl_module (generic_str )
166+
167+ return wrapper
0 commit comments