diff options
author | David Luevano Alvarado <55825613+luevano@users.noreply.github.com> | 2020-11-18 22:43:41 -0700 |
---|---|---|
committer | David Luevano Alvarado <55825613+luevano@users.noreply.github.com> | 2020-11-18 22:43:41 -0700 |
commit | 15e353a18b1c7c70000aaecc41953bd82f5df92c (patch) | |
tree | 28f6b80e3b4a80d9ba4574a40ee801d0d2c80838 /other_exp | |
parent | 3c29f2a6782f1e7b9d0ef2856c6ef6cb6c06c76a (diff) |
First exploration, following tutorial
Diffstat (limited to 'other_exp')
-rw-r--r-- | other_exp/gcn_classification.py | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/other_exp/gcn_classification.py b/other_exp/gcn_classification.py new file mode 100644 index 000000000..dc15cf4fd --- /dev/null +++ b/other_exp/gcn_classification.py @@ -0,0 +1,49 @@ +import tensorflow as tf +from tensorflow.compat.v1 import placeholder +from tensorflow.compat.v1.layers import dense + + +# Disable eager execution (evaluate tf operations instantly instead of having +# to build a graph) so placeholder() can work. +tf.compat.v1.disable_eager_execution() + + +n_nodes = 50 +n_features = 50 +n_labels = 10 + +X = placeholder(tf.float64, shape=(None, n_nodes, n_features)) +A = placeholder(tf.float64, shape=(None, n_nodes, n_nodes)) +Y_truth = placeholder(tf.float64, shape=(None, n_labels)) + +# Function for implementation of H⁽l+1)=sigma(A(AH^lW^l)+ b^l). +# With the bias term given by the tf dense layer. +def graph_conv(_X, _A, O): + """ + Equation of graph convolution. + _X: vector X. Nodes. + _A: adjacency matrix. Edges or path. + """ + out = dense(_X, units=O, use_bias=True) + out = tf.matmul(_A, out) + out = tf.nn.relu(out) + + return out + +X_new = graph_conv(X, A, 32) +print(X_new) + +gconv1 = graph_conv(X, A, 32) +gconv2 = graph_conv(gconv1, A, 32) +gconv3 = graph_conv(gconv2, A, 32) + +Y_pred = tf.nn.softmax(dense(gconv3, units=n_labels, use_bias=True), axis=2) +print(Y_pred) + +Y_pred = tf.reshape(Y_pred, [-1]) +loss = tf.reduce_mean(Y_truth*tf.math.log(Y_pred + 1.0 ** -5)) + +print(loss) + + +print(tf.__version__) |