`tensorflow_graphics.shape.check_static(...)` throwing error in TF2

Issue #625 · open · posted by m4ttr4ymond 3 months ago

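I was writing a data augmentation layer for a PointNet implementation and ran into what appears to be a bug in `tensorflow_graphics.shape.check_static(...)` (defined in `tensorflow_graphics/util/shape.py`). It is triggered by the `check_static` call inside `from_euler` (`rotation_matrix_3d.py:201` in the traceback).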

Offending layer:

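```python
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow_graphics.geometry.transformation.rotation_matrix_3d import from_euler


class RandomRot(Layer):
  """Randomly rotates point clouds of shape [batch, num_points, 3] during training."""

  def __init__(self, **kwargs):
    super(RandomRot, self).__init__(**kwargs)

  def build(self, input_shape):
    # Shape of the Euler-angle vector; input_shape[-1] is 3 for xyz points.
    self.s = tf.constant([input_shape[-1],])

  def call(self, inputs, training=None):
    # Only augment during training.
    if not training:
      return inputs

    # Three random Euler angles in [0, 6.28), i.e. roughly [0, 2*pi).
    r = tf.random.uniform(
      shape=self.s,
      minval=0,
      maxval=6.28,
    )

    # from_euler(r) builds a [3, 3] rotation matrix; this is the call that fails.
    return tf.linalg.matmul(inputs, from_euler(r))
```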

Error message:

```
AttributeError: in user code:

    <ipython-input-135-d11754641da6>:81 call  *
        self.x = self.r(self.x,training)
    <ipython-input-130-07bfe7ac5ab9>:25 call  *
        return tf.linalg.matmul(inputs,from_euler(r))
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py:201 from_euler  *
        shape.check_static(
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:206 check_static  *
        if _get_dim(tensor, axis) != value:
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:135 _get_dim  *
        return tensor.shape[axis].value

    AttributeError: 'int' object has no attribute 'value'
```

It appears that `check_static` (via `_get_dim`) expects each element of `.shape` to be a TF1-style `Dimension` object with a `.value` attribute, but in TF2 the elements of a `TensorShape` are plain Python `int`s (or `None`). If I comment out the `check_static` call in `from_euler`, the function works fine. Oddly, it does work on tensors in eager execution; the error only appears when the layer is traced into a graph, e.g. when training on a `tf.data.Dataset`.
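The mismatch can be reproduced without TensorFlow. This is a minimal sketch, assuming the TF1 vs TF2 difference described above: `Dimension` is a hypothetical stand-in for TF1's `tf.compat.v1.Dimension`, and `get_dim_tf1_style` / `get_dim_any_version` are illustrative helpers, not tensorflow_graphics functions. In real TensorFlow code, `tf.compat.v1.dimension_value(tensor.shape[axis])` performs the same "Dimension or int" normalization.

```python
class Dimension:
    """Hypothetical stand-in for TF1's Dimension: wraps a size in .value."""

    def __init__(self, value):
        self.value = value


def get_dim_tf1_style(shape, axis):
    # Mirrors the failing line in the traceback: assumes every entry has .value.
    return shape[axis].value


def get_dim_any_version(shape, axis):
    # Accepts both a Dimension-like object (TF1) and a plain int or None (TF2).
    dim = shape[axis]
    return dim.value if hasattr(dim, "value") else dim


# TF1-style shape entries are Dimension objects; TF2-style entries are ints/None.
tf1_shape = (Dimension(None), Dimension(1024), Dimension(3))
tf2_shape = (None, 1024, 3)

print(get_dim_any_version(tf1_shape, -1))
print(get_dim_any_version(tf2_shape, -1))

try:
    get_dim_tf1_style(tf2_shape, -1)
except AttributeError as err:
    # Same error as in the traceback above.
    print(err)
```

Routing `_get_dim` through a version-agnostic lookup like this (or `tf.compat.v1.dimension_value`) would avoid the `AttributeError` in both eager and graph mode.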