Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion alibi/explainers/cfproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.python.keras.backend as K

from alibi.api.defaults import DEFAULT_DATA_CFP, DEFAULT_META_CFP
from alibi.api.interfaces import Explainer, Explanation, FitMixin
Expand Down Expand Up @@ -120,7 +121,7 @@ def __init__(self,

# check if the passed object is a model and get session
is_model = isinstance(predict, tf.keras.Model)
model_sess = tf.compat.v1.keras.backend.get_session()
model_sess = K.get_session()
is_ae = isinstance(ae_model, tf.keras.Model)
is_enc = isinstance(enc_model, tf.keras.Model)
self.meta['params'].update(is_model=is_model, is_ae=is_ae, is_enc=is_enc)
Expand Down
3 changes: 2 additions & 1 deletion alibi/explainers/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.python.keras.backend as K

from alibi.api.defaults import DEFAULT_DATA_CF, DEFAULT_META_CF
from alibi.api.interfaces import Explainer, Explanation
Expand Down Expand Up @@ -167,7 +168,7 @@ def __init__(self,

# check if the passed object is a model and get session
is_model = isinstance(predict_fn, tf.keras.Model)
model_sess = tf.compat.v1.keras.backend.get_session()
model_sess = K.get_session()

self.meta['params'].update(is_model=is_model)

Expand Down
56 changes: 56 additions & 0 deletions alibi/tests/test_counterfactual_proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# alibi/tests/test_counterfactual_proto.py
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow.python.keras.backend as K

from alibi.api.interfaces import Explanation
from alibi.explainers.cfproto import CounterfactualProto


def test_cfproto_uses_k_session_blackbox():
tf.reset_default_graph()
sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)

with sess.as_default():
# Simple TF1 graph: softmax(Wx+b)
x_ph = tf.placeholder(tf.float32, shape=(None, 4), name="x")
W = tf.get_variable("W", shape=(4, 2),
initializer=tf.random_normal_initializer(stddev=0.1))
b = tf.get_variable("b", shape=(2,),
initializer=tf.zeros_initializer())
logits = tf.matmul(x_ph, W) + b
probs = tf.nn.softmax(logits)

sess.run(tf.global_variables_initializer())

def predict_fn(x: np.ndarray) -> np.ndarray:
return sess.run(probs, feed_dict={x_ph: x})

explainer = CounterfactualProto(
predict=predict_fn,
shape=(1, 4),
max_iterations=5,
c_steps=1,
c_init=0.0,
kappa=0.0,
beta=0.1,
gamma=0.0,
use_kdtree=False,
)

assert explainer.sess is K.get_session()

x0 = np.zeros((1, 4), dtype="float32")
explanation = explainer.explain(x0)

assert isinstance(explanation, Explanation)
assert "orig_class" in explanation.data
if explanation.data.get("cf") and "X" in explanation.data["cf"]:
assert explanation.data["cf"]["X"].shape == x0.shape