@@ -2046,10 +2046,10 @@ def predict(
20462046
20472047 def compute_contrast (
20482048 self ,
2049- covariates_0 : Union [np .array , pd .DataFrame ],
2050- covariates_1 : Union [np .array , pd .DataFrame ],
2051- basis_0 : np .array = None ,
2052- basis_1 : np .array = None ,
2049+ X_0 : Union [np .array , pd .DataFrame ],
2050+ X_1 : Union [np .array , pd .DataFrame ],
2051+ leaf_basis_0 : np .array = None ,
2052+ leaf_basis_1 : np .array = None ,
20532053 rfx_group_ids_0 : np .array = None ,
20542054 rfx_group_ids_1 : np .array = None ,
20552055 rfx_basis_0 : np .array = None ,
@@ -2068,13 +2068,13 @@ def compute_contrast(
20682068
20692069 Parameters
20702070 ----------
2071- covariates_0 : np.array or pd.DataFrame
2071+ X_0 : np.array or pd.DataFrame
20722072 Covariates used for prediction in the "control" case. Must be a numpy array or dataframe.
2073- covariates_1 : np.array or pd.DataFrame
2073+ X_1 : np.array or pd.DataFrame
20742074 Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe.
2075- basis_0 : np.array, optional
2075+ leaf_basis_0 : np.array, optional
20762076 Bases used for prediction in the "control" case (by e.g. dot product with leaf values).
2077- basis_1 : np.array, optional
2077+ leaf_basis_1 : np.array, optional
20782078 Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values).
20792079 rfx_group_ids_0 : np.array, optional
20802080 Test set group labels used for prediction from an additive random effects model in the "control" case.
@@ -2135,33 +2135,33 @@ def compute_contrast(
21352135 raise NotSampledError (msg )
21362136
21372137 # Data checks
2138- if not isinstance (covariates_0 , pd .DataFrame ) and not isinstance (
2139- covariates_0 , np .ndarray
2138+ if not isinstance (X_0 , pd .DataFrame ) and not isinstance (
2139+ X_0 , np .ndarray
21402140 ):
2141- raise ValueError ("covariates_0 must be a pandas dataframe or numpy array" )
2142- if not isinstance (covariates_1 , pd .DataFrame ) and not isinstance (
2143- covariates_1 , np .ndarray
2141+ raise ValueError ("X_0 must be a pandas dataframe or numpy array" )
2142+ if not isinstance (X_1 , pd .DataFrame ) and not isinstance (
2143+ X_1 , np .ndarray
21442144 ):
2145- raise ValueError ("covariates_1 must be a pandas dataframe or numpy array" )
2146- if basis_0 is not None :
2147- if not isinstance (basis_0 , np .ndarray ):
2148- raise ValueError ("basis_0 must be a numpy array" )
2149- if basis_0 .shape [0 ] != covariates_0 .shape [0 ]:
2145+ raise ValueError ("X_1 must be a pandas dataframe or numpy array" )
2146+ if leaf_basis_0 is not None :
2147+ if not isinstance (leaf_basis_0 , np .ndarray ):
2148+ raise ValueError ("leaf_basis_0 must be a numpy array" )
2149+ if leaf_basis_0 .shape [0 ] != X_0 .shape [0 ]:
21502150 raise ValueError (
2151- "covariates_0 and basis_0 must have the same number of rows"
2151+ "X_0 and leaf_basis_0 must have the same number of rows"
21522152 )
2153- if basis_1 is not None :
2154- if not isinstance (basis_1 , np .ndarray ):
2155- raise ValueError ("basis_1 must be a numpy array" )
2156- if basis_1 .shape [0 ] != covariates_1 .shape [0 ]:
2153+ if leaf_basis_1 is not None :
2154+ if not isinstance (leaf_basis_1 , np .ndarray ):
2155+ raise ValueError ("leaf_basis_1 must be a numpy array" )
2156+ if leaf_basis_1 .shape [0 ] != X_1 .shape [0 ]:
21572157 raise ValueError (
2158- "covariates_1 and basis_1 must have the same number of rows"
2158+ "X_1 and leaf_basis_1 must have the same number of rows"
21592159 )
21602160
21612161 # Predict for the control arm
21622162 control_preds = self .predict (
2163- covariates = covariates_0 ,
2164- basis = basis_0 ,
2163+ X = X_0 ,
2164+ leaf_basis = leaf_basis_0 ,
21652165 rfx_group_ids = rfx_group_ids_0 ,
21662166 rfx_basis = rfx_basis_0 ,
21672167 type = "posterior" ,
@@ -2171,8 +2171,8 @@ def compute_contrast(
21712171
21722172 # Predict for the treatment arm
21732173 treatment_preds = self .predict (
2174- covariates = covariates_1 ,
2175- basis = basis_1 ,
2174+ X = X_1 ,
2175+ leaf_basis = leaf_basis_1 ,
21762176 rfx_group_ids = rfx_group_ids_1 ,
21772177 rfx_basis = rfx_basis_1 ,
21782178 type = "posterior" ,
0 commit comments